mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Simplify implementation: use inline acquire/release with shared memory counter
Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -301,40 +301,7 @@ if tokens := [key for key in (args.api_key or env_tokens) if key]:
|
||||
app.add_middleware(AuthenticationMiddleware, tokens)
|
||||
|
||||
|
||||
def acquire_connection():
    """
    Claim one slot in the shared connection counter, or reject the request.

    The counter lives in ``connection_counter_shm.value[0]`` and is read and
    updated while holding ``connection_counter_lock``. Both objects are
    defined outside this excerpt; the previous docstring described the lock
    as a file-based, cross-process lock that blocks while being acquired --
    NOTE(review): confirm against the lock's definition.

    Raises:
        HTTPException: With status 429 when the counter has already reached
            MAX_CONCURRENT_CONNECTIONS.
    """
    with connection_counter_lock:
        in_use = connection_counter_shm.value[0]

        # Happy path first: a slot is free, so take it and leave.
        if in_use < MAX_CONCURRENT_CONNECTIONS:
            connection_counter_shm.value[0] = in_use + 1
            return

        # Limit reached: log the saturation and refuse the request. The
        # counter is left untouched, so callers must not release afterwards.
        api_server_logger.info(f"Reached max request concurrency: {in_use}/{MAX_CONCURRENT_CONNECTIONS}")
        raise HTTPException(
            status_code=429,
            detail=f"Too many requests, current max concurrency is {args.max_concurrency}",
        )
|
||||
|
||||
|
||||
def release_connection():
    """
    Give back one connection slot by decrementing the shared counter.

    The decrement happens under ``connection_counter_lock``. If the counter
    is already at zero (or below), nothing is decremented and a warning is
    logged instead, so the counter never goes negative.
    """
    with connection_counter_lock:
        # Guard against an unmatched release: warn and leave the counter alone.
        if connection_counter_shm.value[0] <= 0:
            api_server_logger.warning("Attempted to release connection when counter is already 0")
            return

        connection_counter_shm.value[0] -= 1
|
||||
|
||||
|
||||
|
||||
@@ -393,9 +360,13 @@ def ping(raw_request: Request) -> Response:
|
||||
return health(raw_request)
|
||||
|
||||
|
||||
def wrap_streaming_generator(original_generator: AsyncGenerator):
|
||||
def wrap_streaming_generator(original_generator: AsyncGenerator, release_callback):
|
||||
"""
|
||||
Wrap an async generator to release the connection semaphore when the generator is finished.
|
||||
Wrap an async generator to add tracing and ensure connection is released when streaming completes.
|
||||
|
||||
Args:
|
||||
original_generator: The async generator producing the stream
|
||||
release_callback: Function to call when streaming completes to release the connection
|
||||
"""
|
||||
|
||||
async def wrapped_generator():
|
||||
@@ -423,13 +394,15 @@ def wrap_streaming_generator(original_generator: AsyncGenerator):
|
||||
# 尾包捕获
|
||||
if span is not None and span.is_recording() and count > 0:
|
||||
span.add_event("last_chunk", {"time": last_time, "total_chunk": count})
|
||||
release_connection()
|
||||
# Release the connection when streaming completes
|
||||
release_callback()
|
||||
else:
|
||||
try:
|
||||
async for chunk in original_generator:
|
||||
yield chunk
|
||||
finally:
|
||||
release_connection()
|
||||
# Release the connection when streaming completes
|
||||
release_callback()
|
||||
|
||||
return wrapped_generator
|
||||
|
||||
@@ -450,34 +423,38 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request):
|
||||
if not status:
|
||||
return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
|
||||
try:
|
||||
acquire_connection()
|
||||
connection_acquired = True
|
||||
except HTTPException as e:
|
||||
# If acquire fails with 429, connection was not acquired
|
||||
api_server_logger.error(f"Failed to acquire connection slot for chat completion: {str(e)}")
|
||||
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
|
||||
try:
|
||||
try:
|
||||
tracing.label_span(request)
|
||||
generator = await app.state.chat_handler.create_chat_completion(request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
release_connection()
|
||||
connection_acquired = False
|
||||
return JSONResponse(content=generator.model_dump(), status_code=500)
|
||||
elif isinstance(generator, ChatCompletionResponse):
|
||||
release_connection()
|
||||
connection_acquired = False
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
else:
|
||||
# For streaming, release happens in wrap_streaming_generator
|
||||
connection_acquired = False
|
||||
wrapped_generator = wrap_streaming_generator(generator)
|
||||
return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
|
||||
except Exception:
|
||||
# Release connection only if it was acquired
|
||||
if connection_acquired:
|
||||
release_connection()
|
||||
raise
|
||||
# Acquire connection using shared memory counter
|
||||
with connection_counter_lock:
|
||||
current_count = connection_counter_shm.value[0]
|
||||
if current_count >= MAX_CONCURRENT_CONNECTIONS:
|
||||
api_server_logger.info(
|
||||
f"Reached max request concurrency: {current_count}/{MAX_CONCURRENT_CONNECTIONS}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
|
||||
)
|
||||
connection_counter_shm.value[0] = current_count + 1
|
||||
|
||||
tracing.label_span(request)
|
||||
generator = await app.state.chat_handler.create_chat_completion(request)
|
||||
|
||||
# Define release callback
|
||||
def release_connection():
|
||||
with connection_counter_lock:
|
||||
if connection_counter_shm.value[0] > 0:
|
||||
connection_counter_shm.value[0] -= 1
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
release_connection()
|
||||
return JSONResponse(content=generator.model_dump(), status_code=500)
|
||||
elif isinstance(generator, ChatCompletionResponse):
|
||||
release_connection()
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
else:
|
||||
# For streaming, release happens when generator completes
|
||||
wrapped_generator = wrap_streaming_generator(generator, release_connection)
|
||||
return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
|
||||
except HTTPException as e:
|
||||
api_server_logger.error(f"Error in chat completion: {str(e)}")
|
||||
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
|
||||
@@ -499,33 +476,38 @@ async def create_completion(request: CompletionRequest, req: Request):
|
||||
if not status:
|
||||
return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
|
||||
try:
|
||||
acquire_connection()
|
||||
connection_acquired = True
|
||||
except HTTPException as e:
|
||||
# If acquire fails with 429, connection was not acquired
|
||||
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
|
||||
try:
|
||||
try:
|
||||
tracing.label_span(request)
|
||||
generator = await app.state.completion_handler.create_completion(request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
release_connection()
|
||||
connection_acquired = False
|
||||
return JSONResponse(content=generator.model_dump(), status_code=500)
|
||||
elif isinstance(generator, CompletionResponse):
|
||||
release_connection()
|
||||
connection_acquired = False
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
else:
|
||||
# For streaming, release happens in wrap_streaming_generator
|
||||
connection_acquired = False
|
||||
wrapped_generator = wrap_streaming_generator(generator)
|
||||
return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
|
||||
except Exception:
|
||||
# Release connection only if it was acquired
|
||||
if connection_acquired:
|
||||
release_connection()
|
||||
raise
|
||||
# Acquire connection using shared memory counter
|
||||
with connection_counter_lock:
|
||||
current_count = connection_counter_shm.value[0]
|
||||
if current_count >= MAX_CONCURRENT_CONNECTIONS:
|
||||
api_server_logger.info(
|
||||
f"Reached max request concurrency: {current_count}/{MAX_CONCURRENT_CONNECTIONS}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
|
||||
)
|
||||
connection_counter_shm.value[0] = current_count + 1
|
||||
|
||||
tracing.label_span(request)
|
||||
generator = await app.state.completion_handler.create_completion(request)
|
||||
|
||||
# Define release callback
|
||||
def release_connection():
|
||||
with connection_counter_lock:
|
||||
if connection_counter_shm.value[0] > 0:
|
||||
connection_counter_shm.value[0] -= 1
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
release_connection()
|
||||
return JSONResponse(content=generator.model_dump(), status_code=500)
|
||||
elif isinstance(generator, CompletionResponse):
|
||||
release_connection()
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
else:
|
||||
# For streaming, release happens when generator completes
|
||||
wrapped_generator = wrap_streaming_generator(generator, release_connection)
|
||||
return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
|
||||
except HTTPException as e:
|
||||
return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user