Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00.
Use shared memory to enforce global concurrency limit across workers
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -61,6 +61,9 @@ from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from fastdeploy.entrypoints.openai.utils import UVICORN_CONFIG, make_arg_parser
|
||||
from fastdeploy.envs import environment_variables
|
||||
from fastdeploy.metrics.metrics import get_filtered_metrics
|
||||
from filelock import FileLock
|
||||
|
||||
from fastdeploy.inter_communicator import IPCSignal, shared_memory_exists
|
||||
from fastdeploy.utils import (
|
||||
ExceptionHandler,
|
||||
FlexibleArgumentParser,
|
||||
@@ -85,6 +88,26 @@ if args.tool_parser_plugin:
|
||||
llm_engine = None

# Upper bound on simultaneously served requests, shared by ALL workers.
MAX_CONCURRENT_CONNECTIONS = args.max_concurrency

# Use shared memory for concurrency control across multiple workers.
# Create or connect to a shared counter that tracks active connections globally.
import numpy as np

# The segment name is port-scoped so multiple server instances on one host
# do not collide with each other.
_shm_name = f"fd_api_server_connections_{args.port}"
# First worker to start creates the segment; later workers attach to it.
# NOTE(review): there is a window between this exists-check and the creation
# below in which two workers could race — assumed acceptable or handled
# inside IPCSignal; confirm.
_create_shm = not shared_memory_exists(_shm_name)
connection_counter_shm = IPCSignal(
    name=_shm_name,
    # A single int32 slot holding the global active-connection count.
    array=np.array([0], dtype=np.int32) if _create_shm else None,
    dtype=np.int32 if _create_shm else None,
    create=_create_shm,
    # When attaching, only the raw size (4 bytes = one int32) is supplied.
    shm_size=4 if not _create_shm else None,
)
if not _create_shm:
    # Attach to existing shared memory.
    # NOTE(review): this reaches into IPCSignal internals (.shm.buf) to map
    # the counter view; it assumes IPCSignal does not populate .value itself
    # when attached this way — verify against IPCSignal's implementation.
    connection_counter_shm.value = np.ndarray((1,), dtype=np.int32, buffer=connection_counter_shm.shm.buf)

# File-based lock for atomic operations on the shared counter — a
# cross-process lock is required here; an in-process threading lock would
# not span the worker processes.
connection_counter_lock = FileLock(f"/tmp/fd_api_server_conn_lock_{args.port}.lock")

# Per-worker semaphore kept alongside the global shared-memory counter;
# its status() string is used for logging when the worker is saturated.
connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)
||||
@@ -257,16 +280,28 @@ if tokens := [key for key in (args.api_key or env_tokens) if key]:
|
||||
@asynccontextmanager
async def connection_manager():
    """
    Async context manager guarding one request slot.

    Enforces the global concurrency limit across all worker processes via
    the shared-memory counter (mutated only under the cross-process file
    lock), then additionally acquires the per-worker semaphore.

    Raises:
        HTTPException: 429 when the global counter has reached
            MAX_CONCURRENT_CONNECTIONS, or when the local semaphore cannot
            be acquired almost immediately.
    """
    # Atomically check-and-increment the shared counter under the file lock.
    with connection_counter_lock:
        current_count = connection_counter_shm.value[0]
        if current_count >= MAX_CONCURRENT_CONNECTIONS:
            api_server_logger.info(
                f"Reach max request concurrency: {current_count}/{MAX_CONCURRENT_CONNECTIONS}"
            )
            # Counter was NOT incremented, so the finally-decrement below
            # must not run — hence this raise happens before the try block.
            raise HTTPException(
                status_code=429,
                detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
            )
        # Increment the counter: this request now occupies a global slot.
        connection_counter_shm.value[0] = current_count + 1

    try:
        # The local semaphore should have a free slot whenever the global
        # counter admitted us; the tiny timeout only catches divergence
        # between the two mechanisms.
        # NOTE(review): the permit acquired here is not released in this
        # manager — presumably released elsewhere when the request
        # completes; confirm against the request handlers.
        await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
        yield
    except asyncio.TimeoutError:
        api_server_logger.info(f"Reach max request concurrency, semaphore status: {connection_semaphore.status()}")
        # Message fixed: added the missing space after the comma
        # ("requests,current" -> "requests, current") so both 429 responses
        # carry an identically formatted detail string.
        raise HTTPException(
            status_code=429, detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
        )
    finally:
        # Decrement the counter on exit (success, timeout, or body error).
        with connection_counter_lock:
            connection_counter_shm.value[0] -= 1


# TODO: pass the real engine values; obtain the status via pid
||||
Reference in New Issue
Block a user