Simplify implementation: use inline acquire/release with shared memory counter

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2025-12-17 11:09:17 +00:00
parent 6f9b25902a
commit 53f4a9ad27

View File

@@ -301,41 +301,8 @@ if tokens := [key for key in (args.api_key or env_tokens) if key]:
app.add_middleware(AuthenticationMiddleware, tokens)
def acquire_connection():
"""
Acquire a connection slot using shared memory for global concurrency control across workers.
This function is thread-safe and uses a file-based lock for synchronization across processes.
It will block while acquiring the lock, which may impact performance under high load.
Raises:
HTTPException: With status 429 if the global connection limit is reached.
"""
with connection_counter_lock:
current_count = connection_counter_shm.value[0]
if current_count >= MAX_CONCURRENT_CONNECTIONS:
api_server_logger.info(
f"Reached max request concurrency: {current_count}/{MAX_CONCURRENT_CONNECTIONS}"
)
raise HTTPException(
status_code=429,
detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
)
# Increment the counter
connection_counter_shm.value[0] = current_count + 1
def release_connection():
"""
Release a connection slot by decrementing the shared counter.
Includes bounds checking to prevent counter from going negative.
"""
with connection_counter_lock:
if connection_counter_shm.value[0] > 0:
connection_counter_shm.value[0] -= 1
else:
api_server_logger.warning("Attempted to release connection when counter is already 0")
@@ -393,9 +360,13 @@ def ping(raw_request: Request) -> Response:
return health(raw_request)
def wrap_streaming_generator(original_generator: AsyncGenerator):
def wrap_streaming_generator(original_generator: AsyncGenerator, release_callback):
"""
Wrap an async generator to release the connection semaphore when the generator is finished.
Wrap an async generator to add tracing and ensure connection is released when streaming completes.
Args:
original_generator: The async generator producing the stream
release_callback: Function to call when streaming completes to release the connection
"""
async def wrapped_generator():
@@ -423,13 +394,15 @@ def wrap_streaming_generator(original_generator: AsyncGenerator):
# 尾包捕获
if span is not None and span.is_recording() and count > 0:
span.add_event("last_chunk", {"time": last_time, "total_chunk": count})
release_connection()
# Release the connection when streaming completes
release_callback()
else:
try:
async for chunk in original_generator:
yield chunk
finally:
release_connection()
# Release the connection when streaming completes
release_callback()
return wrapped_generator
@@ -450,34 +423,38 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request):
if not status:
return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
try:
acquire_connection()
connection_acquired = True
except HTTPException as e:
# If acquire fails with 429, connection was not acquired
api_server_logger.error(f"Failed to acquire connection slot for chat completion: {str(e)}")
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
try:
try:
tracing.label_span(request)
generator = await app.state.chat_handler.create_chat_completion(request)
if isinstance(generator, ErrorResponse):
release_connection()
connection_acquired = False
return JSONResponse(content=generator.model_dump(), status_code=500)
elif isinstance(generator, ChatCompletionResponse):
release_connection()
connection_acquired = False
return JSONResponse(content=generator.model_dump())
else:
# For streaming, release happens in wrap_streaming_generator
connection_acquired = False
wrapped_generator = wrap_streaming_generator(generator)
return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
except Exception:
# Release connection only if it was acquired
if connection_acquired:
release_connection()
raise
# Acquire connection using shared memory counter
with connection_counter_lock:
current_count = connection_counter_shm.value[0]
if current_count >= MAX_CONCURRENT_CONNECTIONS:
api_server_logger.info(
f"Reached max request concurrency: {current_count}/{MAX_CONCURRENT_CONNECTIONS}"
)
raise HTTPException(
status_code=429,
detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
)
connection_counter_shm.value[0] = current_count + 1
tracing.label_span(request)
generator = await app.state.chat_handler.create_chat_completion(request)
# Define release callback
def release_connection():
with connection_counter_lock:
if connection_counter_shm.value[0] > 0:
connection_counter_shm.value[0] -= 1
if isinstance(generator, ErrorResponse):
release_connection()
return JSONResponse(content=generator.model_dump(), status_code=500)
elif isinstance(generator, ChatCompletionResponse):
release_connection()
return JSONResponse(content=generator.model_dump())
else:
# For streaming, release happens when generator completes
wrapped_generator = wrap_streaming_generator(generator, release_connection)
return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
except HTTPException as e:
api_server_logger.error(f"Error in chat completion: {str(e)}")
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
@@ -499,33 +476,38 @@ async def create_completion(request: CompletionRequest, req: Request):
if not status:
return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
try:
acquire_connection()
connection_acquired = True
except HTTPException as e:
# If acquire fails with 429, connection was not acquired
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
try:
try:
tracing.label_span(request)
generator = await app.state.completion_handler.create_completion(request)
if isinstance(generator, ErrorResponse):
release_connection()
connection_acquired = False
return JSONResponse(content=generator.model_dump(), status_code=500)
elif isinstance(generator, CompletionResponse):
release_connection()
connection_acquired = False
return JSONResponse(content=generator.model_dump())
else:
# For streaming, release happens in wrap_streaming_generator
connection_acquired = False
wrapped_generator = wrap_streaming_generator(generator)
return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
except Exception:
# Release connection only if it was acquired
if connection_acquired:
release_connection()
raise
# Acquire connection using shared memory counter
with connection_counter_lock:
current_count = connection_counter_shm.value[0]
if current_count >= MAX_CONCURRENT_CONNECTIONS:
api_server_logger.info(
f"Reached max request concurrency: {current_count}/{MAX_CONCURRENT_CONNECTIONS}"
)
raise HTTPException(
status_code=429,
detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
)
connection_counter_shm.value[0] = current_count + 1
tracing.label_span(request)
generator = await app.state.completion_handler.create_completion(request)
# Define release callback
def release_connection():
with connection_counter_lock:
if connection_counter_shm.value[0] > 0:
connection_counter_shm.value[0] -= 1
if isinstance(generator, ErrorResponse):
release_connection()
return JSONResponse(content=generator.model_dump(), status_code=500)
elif isinstance(generator, CompletionResponse):
release_connection()
return JSONResponse(content=generator.model_dump())
else:
# For streaming, release happens when generator completes
wrapped_generator = wrap_streaming_generator(generator, release_connection)
return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
except HTTPException as e:
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})