mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
[Feature] add dealer manager to reuse the connection (#3471)
* [BugFix] fix control signal release failed
* [BugFix] fix control signal release failed
* update
* update
* update
* [Feature] add dealer manager to reuse the connection
* fix
* fix
* fix
* fix
* fix
* fix
* Create test_dealer_connection_manager.py
* Delete test/entrypoints/openai directory
* Update test_dealer_connection_manager.py
* Update test_dealer_connection_manager.py
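For orientation, here is the request path the new manager enables, condensed from the serving_chat/serving_completion hunks below. The wrapper function itself is illustrative; engine_client, _ensure_connection_manager, get_connection, and cleanup_request are the names introduced by this commit.

import asyncio

async def stream_one_request(serving, request_id):
    # Illustrative condensation of the new flow; `serving` stands in for an
    # OpenAIServingChat / OpenAIServingCompletion instance from this commit.
    await serving._ensure_connection_manager()
    dealer, response_queue = await serving.engine_client.connection_manager.get_connection(request_id)
    dealer.write([b"", request_id.encode("utf-8")])
    try:
        while True:
            # The manager's listener task unpacks and routes replies into this queue,
            # so handlers no longer read or msgpack-decode the socket themselves.
            response = await asyncio.wait_for(response_queue.get(), timeout=10)
            if response[-1]["finished"]:
                break
    finally:
        # The shared dealer socket stays open for reuse; only the per-request
        # bookkeeping is released.
        await serving.engine_client.connection_manager.cleanup_request(request_id)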
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import os
 import time
 import traceback
 import uuid
@@ -22,6 +23,7 @@ import numpy as np
 
 from fastdeploy import envs
 from fastdeploy.engine.config import ModelConfig
+from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
 from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
 from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import IPCSignal, ZmqClient
@@ -91,6 +93,10 @@ class EngineClient:
             suffix=pid,
             create=False,
         )
+        self.connection_manager = DealerConnectionManager(
+            pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
+        )
+        self.connection_initialized = False
 
     def create_zmq_client(self, model, mode):
         """
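The pool size comes from the FD_DEALER_CONNECTIONS environment variable (default 50 in the hunk above, and DealerConnectionManager itself enforces a floor of 10 in its constructor, shown further down). A minimal sketch of constructing the manager outside EngineClient; the pid value here is hypothetical and would normally be the suffix the API server uses for its ipc router socket:

import os
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager

pid = os.getpid()  # illustrative stand-in for the server's pid/suffix
manager = DealerConnectionManager(
    pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
)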
@@ -154,6 +154,7 @@ async def lifespan(app: FastAPI):
     yield
     # close zmq
     try:
+        await engine_client.connection_manager.close()
        engine_client.zmq_client.close()
        from prometheus_client import multiprocess
 
@@ -20,10 +20,7 @@ import traceback
 import uuid
 from typing import List, Optional
 
-import aiozmq
-import msgpack
 import numpy as np
-from aiozmq import zmq
 
 from fastdeploy.entrypoints.openai.protocol import (
     ChatCompletionRequest,
@@ -62,6 +59,12 @@ class OpenAIServingChat:
         else:
             self.master_ip = self.master_ip.split(",")[0]
 
+    async def _ensure_connection_manager(self):
+        """ensure connection manager initialized"""
+        if not self.engine_client.connection_initialized:
+            await self.engine_client.connection_manager.initialize()
+            self.engine_client.connection_initialized = True
+
     def _check_master(self):
         if self.master_ip is None:
             return True
@@ -180,14 +183,16 @@ class OpenAIServingChat:
             choices=[],
             model=model_name,
         )
 
        try:
-            dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
+            await self._ensure_connection_manager()
+            dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
             dealer.write([b"", request_id.encode("utf-8")])
             choices = []
             current_waiting_time = 0
             while num_choices > 0:
                 try:
-                    raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
+                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                     current_waiting_time = 0
                 except asyncio.TimeoutError:
                     current_waiting_time += 10
@@ -202,7 +207,6 @@
                         current_waiting_time = 0
                     await asyncio.sleep(0.01)
                     continue
-                response = msgpack.unpackb(raw_data[-1])
                 for res in response:
                     if res.get("error_code", 200) != 200:
                         raise ValueError("{}".format(res["error_msg"]))
@@ -353,9 +357,9 @@ class OpenAIServingChat:
             )
             yield f"data: {error_data}\n\n"
         finally:
-            dealer.close()
+            await self.engine_client.connection_manager.cleanup_request(request_id)
             self.engine_client.semaphore.release()
-            api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
+            api_server_logger.info(f"release {request_id} {self.engine_client.semaphore.status()}")
         yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
@@ -378,7 +382,8 @@
         include_stop_str_in_output = request.include_stop_str_in_output
 
         try:
-            dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
+            await self._ensure_connection_manager()
+            dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
             dealer.write([b"", request_id.encode("utf-8")])
             final_res = None
             previous_num_tokens = 0
@@ -387,7 +392,7 @@
             completion_token_ids = []
             while True:
                 try:
-                    raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
+                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                     current_waiting_time = 0
                 except asyncio.TimeoutError:
                     current_waiting_time += 10
@@ -400,7 +405,6 @@
                     await asyncio.sleep(0.1)
                     continue
 
-                response = msgpack.unpackb(raw_data[-1])
                 task_is_finished = False
                 for data in response:
                     if data.get("error_code", 200) != 200:
@@ -430,7 +434,7 @@
                 if task_is_finished:
                     break
         finally:
-            dealer.close()
+            await self.engine_client.connection_manager.cleanup_request(request_id)
             self.engine_client.semaphore.release()
             api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
 
@@ -20,10 +20,7 @@ import traceback
 import uuid
 from typing import List, Optional
 
-import aiozmq
-import msgpack
 import numpy as np
-from aiozmq import zmq
 
 from fastdeploy.engine.request import RequestOutput
 from fastdeploy.entrypoints.openai.protocol import (
@@ -53,6 +50,12 @@ class OpenAIServingCompletion:
         else:
             self.master_ip = self.master_ip.split(",")[0]
 
+    async def _ensure_connection_manager(self):
+        """ensure connection manager initialized"""
+        if not self.engine_client.connection_initialized:
+            await self.engine_client.connection_manager.initialize()
+            self.engine_client.connection_initialized = True
+
     def _check_master(self):
         if self.master_ip is None:
             return True
@@ -185,7 +188,10 @@ class OpenAIServingCompletion:
         try:
             request_ids = [f"{request_id}-{i}" for i in range(num_choices)]
             # create dealer
-            dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
+            await self._ensure_connection_manager()
+            dealer, response_queue = await self.engine_client.connection_manager.get_connection(
+                request_id, num_choices
+            )
 
             for rid in request_ids:
                 dealer.write([b"", rid.encode("utf-8")])
@@ -198,7 +204,7 @@
             current_waiting_time = 0
             while num_choices > 0:
                 try:
-                    raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
+                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                     current_waiting_time = 0
                 except asyncio.TimeoutError:
                     current_waiting_time += 10
@@ -210,7 +216,7 @@
                         current_waiting_time = 0
                     await asyncio.sleep(0.1)
                     continue
-                response = msgpack.unpackb(raw_data[-1])
                 for data in response:
                     rid = int(data["request_id"].split("-")[-1])
                     if data.get("error_code", 200) != 200:
@@ -255,7 +261,7 @@
         finally:
             self.engine_client.semaphore.release()
             if dealer is not None:
-                dealer.close()
+                await self.engine_client.connection_manager.cleanup_request(request_id)
 
     async def _echo_back_prompt(self, request, res, idx):
         if res["outputs"].get("send_idx", -1) == 0 and request.echo:
@@ -288,7 +294,10 @@
         Process the stream completion request.
         """
         try:
-            dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=f"ipc:///dev/shm/router_{self.pid}.ipc")
+            await self._ensure_connection_manager()
+            dealer, response_queue = await self.engine_client.connection_manager.get_connection(
+                request_id, num_choices
+            )
 
             for i in range(num_choices):
                 req_id = f"{request_id}-{i}"
@@ -312,7 +321,7 @@
             current_waiting_time = 0
             while num_choices > 0:
                 try:
-                    raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
+                    response = await asyncio.wait_for(response_queue.get(), timeout=10)
                     current_waiting_time = 0
                 except asyncio.TimeoutError:
                     current_waiting_time += 10
@@ -325,7 +334,6 @@
                     await asyncio.sleep(0.1)
                     continue
 
-                response = msgpack.unpackb(raw_data[-1])
                 for res in response:
                     idx = int(res["request_id"].split("-")[-1])
                     if res.get("error_code", 200) != 200:
@@ -453,9 +461,9 @@
             yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n"
         finally:
             del request
-            self.engine_client.semaphore.release()
             if dealer is not None:
-                dealer.close()
+                await self.engine_client.connection_manager.cleanup_request(request_id)
+            self.engine_client.semaphore.release()
             yield "data: [DONE]\n\n"
 
     def request_output_to_completion_response(
fastdeploy/entrypoints/openai/utils.py (new file, 159 lines)
@@ -0,0 +1,159 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import asyncio
+import heapq
+import random
+
+import aiozmq
+import msgpack
+import zmq
+
+from fastdeploy.utils import api_server_logger
+
+
+class DealerConnectionManager:
+    """
+    Manager for dealer connections, supporting multiplexing and connection reuse
+    """
+
+    def __init__(self, pid, max_connections=10):
+        self.pid = pid
+        self.max_connections = max(max_connections, 10)
+        self.connections = []
+        self.connection_load = []
+        self.connection_heap = []
+        self.request_map = {}  # request_id -> response_queue
+        self.request_num = {}  # request_id -> num_choices
+        self.lock = asyncio.Lock()
+        self.connection_tasks = []
+        self.running = False
+
+    async def initialize(self):
+        """initialize all connections"""
+        self.running = True
+        for index in range(self.max_connections):
+            await self._add_connection(index)
+        api_server_logger.info(f"Started {self.max_connections} connections")
+
+    async def _add_connection(self, index):
+        """create a new connection and start listening task"""
+        try:
+            dealer = await aiozmq.create_zmq_stream(
+                zmq.DEALER,
+                connect=f"ipc:///dev/shm/router_{self.pid}.ipc",
+            )
+            async with self.lock:
+                self.connections.append(dealer)
+                self.connection_load.append(0)
+                heapq.heappush(self.connection_heap, (0, index))
+
+            # start listening
+            task = asyncio.create_task(self._listen_connection(dealer, index))
+            self.connection_tasks.append(task)
+            return True
+        except Exception as e:
+            api_server_logger.error(f"Failed to create dealer: {str(e)}")
+            return False
+
+    async def _listen_connection(self, dealer, conn_index):
+        """
+        listen for messages from the dealer connection
+        """
+        while self.running:
+            try:
+                raw_data = await dealer.read()
+                response = msgpack.unpackb(raw_data[-1])
+                request_id = response[-1]["request_id"]
+                if "cmpl" == request_id[:4]:
+                    request_id = request_id.rsplit("-", 1)[0]
+                async with self.lock:
+                    if request_id in self.request_map:
+                        await self.request_map[request_id].put(response)
+                        if response[-1]["finished"]:
+                            self.request_num[request_id] -= 1
+                            if self.request_num[request_id] == 0:
+                                self._update_load(conn_index, -1)
+            except Exception as e:
+                api_server_logger.error(f"Listener error: {str(e)}")
+                break
+
+    def _update_load(self, conn_index, delta):
+        """Update connection load and maintain the heap"""
+        self.connection_load[conn_index] += delta
+        heapq.heapify(self.connection_heap)
+
+        # For Debugging purposes
+        if random.random() < 0.01:
+            min_load = self.connection_heap[0][0] if self.connection_heap else 0
+            max_load = max(self.connection_load) if self.connection_load else 0
+            api_server_logger.debug(f"Connection load update: min={min_load}, max={max_load}")
+
+    def _get_least_loaded_connection(self):
+        """
+        Get the least loaded connection
+        """
+        if not self.connection_heap:
+            return None
+
+        load, conn_index = self.connection_heap[0]
+        self._update_load(conn_index, 1)
+
+        return self.connections[conn_index]
+
+    async def get_connection(self, request_id, num_choices=1):
+        """get a connection for the request"""
+
+        response_queue = asyncio.Queue()
+
+        async with self.lock:
+            self.request_map[request_id] = response_queue
+            self.request_num[request_id] = num_choices
+            dealer = self._get_least_loaded_connection()
+            if not dealer:
+                raise RuntimeError("No available connections")
+
+        return dealer, response_queue
+
+    async def cleanup_request(self, request_id):
+        """
+        clean up the request after it is finished
+        """
+        async with self.lock:
+            if request_id in self.request_map:
+                del self.request_map[request_id]
+                del self.request_num[request_id]
+
+    async def close(self):
+        """
+        close all connections and tasks
+        """
+        self.running = False
+
+        for task in self.connection_tasks:
+            task.cancel()
+
+        async with self.lock:
+            for dealer in self.connections:
+                try:
+                    dealer.close()
+                except:
+                    pass
+            self.connections.clear()
+            self.connection_load.clear()
+            self.request_map.clear()
+
+        api_server_logger.info("All connections and tasks closed")
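A small worked example of how the listener above maps replies back to a waiting request: serving_completion writes each choice as "<request_id>-<choice index>", and when the id starts with "cmpl" the manager strips the trailing index so all n choices of one request share a single response queue. The values below are made up purely for illustration.

incoming_id = "cmpl-abc123-2"          # as produced by f"{request_id}-{i}" in serving_completion
if incoming_id[:4] == "cmpl":
    queue_key = incoming_id.rsplit("-", 1)[0]
else:
    queue_key = incoming_id
assert queue_key == "cmpl-abc123"      # all choices of this request land in request_map["cmpl-abc123"]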
@@ -85,7 +85,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # set trace attribute job_id.
     "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
     # support max connections
-    "FD_SUPPORT_MAX_CONNECTIONS": lambda: 768,
+    "FD_SUPPORT_MAX_CONNECTIONS": lambda: int(os.getenv("FD_SUPPORT_MAX_CONNECTIONS", "1024")),
 }
 
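With this change FD_SUPPORT_MAX_CONNECTIONS can be raised or lowered per deployment (default 1024 when unset) instead of being pinned to 768. A minimal sketch, assuming the lambda in environment_variables is evaluated lazily when fastdeploy.envs resolves the attribute rather than earlier:

import os

# Must be set before fastdeploy.envs resolves the value (assumption about
# lazy evaluation noted above).
os.environ["FD_SUPPORT_MAX_CONNECTIONS"] = "2048"

from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS

print(FD_SUPPORT_MAX_CONNECTIONS)  # expected: 2048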
tests/entrypoints/openai/test_dealer_connection_manager.py (new file, 157 lines)
@@ -0,0 +1,157 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import asyncio
+import unittest
+from unittest.mock import AsyncMock, patch
+
+import msgpack
+
+from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
+
+
+class TestDealerConnectionManager(unittest.TestCase):
+    """Test cases for DealerConnectionManager"""
+
+    def setUp(self):
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+        self.manager = DealerConnectionManager(pid=1, max_connections=5)
+
+    def tearDown(self):
+        self.loop.run_until_complete(self.manager.close())
+        self.loop.close()
+
+    @patch("aiozmq.create_zmq_stream")
+    async def test_initialization(self, mock_create):
+        """Test manager initialization creates connections"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+
+        # Test initialization
+        await self.manager.initialize()
+
+        # Verify connections were created
+        self.assertEqual(len(self.manager.connections), 10)
+        self.assertEqual(len(self.manager.connection_load), 10)
+        self.assertEqual(len(self.manager.connection_tasks), 10)
+
+        # Verify connection tasks are running
+        for task in self.manager.connection_tasks:
+            self.assertFalse(task.done())
+
+    @patch("aiozmq.create_zmq_stream")
+    async def test_get_connection(self, mock_create):
+        """Test getting a connection with load balancing"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        # Get a connection
+        dealer, queue = await self.manager.get_connection("req1")
+
+        # Verify least loaded connection is returned
+        self.assertEqual(self.manager.connection_load[0], 1)
+        self.assertIsNotNone(dealer)
+        self.assertIsNotNone(queue)
+        self.assertIn("req1", self.manager.request_map)
+
+    @patch("aiozmq.create_zmq_stream")
+    async def test_connection_listening(self, mock_create):
+        """Test connection listener handles responses"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        # Setup test response
+        test_response = {"request_id": "req1", "finished": True}
+        mock_stream.read.return_value = [b"", msgpack.packb(test_response)]
+
+        # Simulate response
+        dealer, queue = await self.manager.get_connection("req1")
+        response = await queue.get()
+
+        # Verify response handling
+        self.assertEqual(response[-1]["request_id"], "req1")
+        self.assertEqual(self.manager.connection_load[0], 0)  # Should be decremented after finish
+
+    @patch("aiozmq.create_zmq_stream")
+    async def test_request_cleanup(self, mock_create):
+        """Test request cleanup removes request tracking"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        await self.manager.get_connection("req1")
+        self.assertIn("req1", self.manager.request_map)
+
+        await self.manager.cleanup_request("req1")
+        self.assertNotIn("req1", self.manager.request_map)
+
+    @patch("aiozmq.create_zmq_stream")
+    async def test_multiple_requests(self, mock_create):
+        """Test load balancing with multiple requests"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        # Get multiple connections
+        connections = []
+        for i in range(1, 6):
+            dealer, queue = await self.manager.get_connection(f"req{i}")
+            connections.append((dealer, queue))
+
+        # Verify load is distributed
+        load_counts = [0] * 5
+        for i in range(5):
+            load_counts[i] = self.manager.connection_load[i]
+
+        self.assertEqual(sum(load_counts), 5)
+        self.assertTrue(all(1 <= load <= 2 for load in load_counts))
+
+    @patch("aiozmq.create_zmq_stream")
+    async def test_connection_failure(self, mock_create):
+        """Test connection failure handling"""
+        mock_create.side_effect = Exception("Connection failed")
+
+        with self.assertLogs(level="ERROR") as log:
+            await self.manager._add_connection(0)
+            self.assertTrue(any("Failed to create dealer" in msg for msg in log.output))
+
+        self.assertEqual(len(self.manager.connections), 0)
+
+    @patch("aiozmq.create_zmq_stream")
+    async def test_close_manager(self, mock_create):
+        """Test manager shutdown"""
+        mock_stream = AsyncMock()
+        mock_create.return_value = mock_stream
+        await self.manager.initialize()
+
+        # Verify connections exist
+        self.assertEqual(len(self.manager.connections), 5)
+
+        # Close manager
+        await self.manager.close()
+
+        # Verify cleanup
+        self.assertEqual(len(self.manager.connections), 0)
+        self.assertEqual(len(self.manager.request_map), 0)
+        for task in self.manager.connection_tasks:
+            self.assertTrue(task.cancelled())
+
+
+if __name__ == "__main__":
+    unittest.main()
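For reference, the new tests can be driven through unittest's loader (module path taken from the diff header above); a minimal sketch, equivalent to running "python -m unittest tests.entrypoints.openai.test_dealer_connection_manager -v" from the repository root:

import unittest

# Load and run the new test module programmatically
# (assumes the repository root is on sys.path).
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.entrypoints.openai.test_dealer_connection_manager"
)
unittest.TextTestRunner(verbosity=2).run(suite)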