diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 663806d6b..36262c439 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -301,41 +301,8 @@ if tokens := [key for key in (args.api_key or env_tokens) if key]:
     app.add_middleware(AuthenticationMiddleware, tokens)
 
 
-def acquire_connection():
-    """
-    Acquire a connection slot using shared memory for global concurrency control across workers.
-
-    This function is thread-safe and uses a file-based lock for synchronization across processes.
-    It will block while acquiring the lock, which may impact performance under high load.
-
-    Raises:
-        HTTPException: With status 429 if the global connection limit is reached.
-    """
-    with connection_counter_lock:
-        current_count = connection_counter_shm.value[0]
-        if current_count >= MAX_CONCURRENT_CONNECTIONS:
-            api_server_logger.info(
-                f"Reached max request concurrency: {current_count}/{MAX_CONCURRENT_CONNECTIONS}"
-            )
-            raise HTTPException(
-                status_code=429,
-                detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
-            )
-        # Increment the counter
-        connection_counter_shm.value[0] = current_count + 1
-
-
-def release_connection():
-    """
-    Release a connection slot by decrementing the shared counter.
-    Includes bounds checking to prevent counter from going negative.
-    """
-    with connection_counter_lock:
-        if connection_counter_shm.value[0] > 0:
-            connection_counter_shm.value[0] -= 1
-        else:
-            api_server_logger.warning("Attempted to release connection when counter is already 0")
-
-
@@ -393,9 +360,13 @@ def ping(raw_request: Request) -> Response:
     return health(raw_request)
 
 
-def wrap_streaming_generator(original_generator: AsyncGenerator):
+def wrap_streaming_generator(original_generator: AsyncGenerator, release_callback):
     """
-    Wrap an async generator to release the connection semaphore when the generator is finished.
+    Wrap an async generator to add tracing and ensure connection is released when streaming completes.
+
+    Args:
+        original_generator: The async generator producing the stream
+        release_callback: Function to call when streaming completes to release the connection
     """
 
     async def wrapped_generator():
@@ -423,13 +394,15 @@ def wrap_streaming_generator(original_generator: AsyncGenerator):
             # 尾包捕获
             if span is not None and span.is_recording() and count > 0:
                 span.add_event("last_chunk", {"time": last_time, "total_chunk": count})
-            release_connection()
+            # Release the connection when streaming completes
+            release_callback()
         else:
             try:
                 async for chunk in original_generator:
                     yield chunk
             finally:
-                release_connection()
+                # Release the connection when streaming completes
+                release_callback()
 
     return wrapped_generator
 
@@ -450,34 +423,38 @@ async def create_chat_completion(request: ChatCompletionRequest, req: Request):
     if not status:
         return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
     try:
-        acquire_connection()
-        connection_acquired = True
-    except HTTPException as e:
-        # If acquire fails with 429, connection was not acquired
-        api_server_logger.error(f"Failed to acquire connection slot for chat completion: {str(e)}")
-        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
-    try:
-        try:
-            tracing.label_span(request)
-            generator = await app.state.chat_handler.create_chat_completion(request)
-            if isinstance(generator, ErrorResponse):
-                release_connection()
-                connection_acquired = False
-                return JSONResponse(content=generator.model_dump(), status_code=500)
-            elif isinstance(generator, ChatCompletionResponse):
-                release_connection()
-                connection_acquired = False
-                return JSONResponse(content=generator.model_dump())
-            else:
-                # For streaming, release happens in wrap_streaming_generator
-                connection_acquired = False
-                wrapped_generator = wrap_streaming_generator(generator)
-                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
-        except Exception:
-            # Release connection only if it was acquired
-            if connection_acquired:
-                release_connection()
-            raise
+        # Acquire connection using shared memory counter
+        with connection_counter_lock:
+            current_count = connection_counter_shm.value[0]
+            if current_count >= MAX_CONCURRENT_CONNECTIONS:
+                api_server_logger.info(
+                    f"Reached max request concurrency: {current_count}/{MAX_CONCURRENT_CONNECTIONS}"
+                )
+                raise HTTPException(
+                    status_code=429,
+                    detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
+                )
+            connection_counter_shm.value[0] = current_count + 1
+
+        tracing.label_span(request)
+        generator = await app.state.chat_handler.create_chat_completion(request)
+
+        # Define release callback
+        def release_connection():
+            with connection_counter_lock:
+                if connection_counter_shm.value[0] > 0:
+                    connection_counter_shm.value[0] -= 1
+
+        if isinstance(generator, ErrorResponse):
+            release_connection()
+            return JSONResponse(content=generator.model_dump(), status_code=500)
+        elif isinstance(generator, ChatCompletionResponse):
+            release_connection()
+            return JSONResponse(content=generator.model_dump())
+        else:
+            # For streaming, release happens when generator completes
+            wrapped_generator = wrap_streaming_generator(generator, release_connection)
+            return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
     except HTTPException as e:
         api_server_logger.error(f"Error in chat completion: {str(e)}")
         return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
@@ -499,33 +476,38 @@ async def create_completion(request: CompletionRequest, req: Request):
     if not status:
         return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
     try:
-        acquire_connection()
-        connection_acquired = True
-    except HTTPException as e:
-        # If acquire fails with 429, connection was not acquired
-        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
-    try:
-        try:
-            tracing.label_span(request)
-            generator = await app.state.completion_handler.create_completion(request)
-            if isinstance(generator, ErrorResponse):
-                release_connection()
-                connection_acquired = False
-                return JSONResponse(content=generator.model_dump(), status_code=500)
-            elif isinstance(generator, CompletionResponse):
-                release_connection()
-                connection_acquired = False
-                return JSONResponse(content=generator.model_dump())
-            else:
-                # For streaming, release happens in wrap_streaming_generator
-                connection_acquired = False
-                wrapped_generator = wrap_streaming_generator(generator)
-                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
-        except Exception:
-            # Release connection only if it was acquired
-            if connection_acquired:
-                release_connection()
-            raise
+        # Acquire connection using shared memory counter
+        with connection_counter_lock:
+            current_count = connection_counter_shm.value[0]
+            if current_count >= MAX_CONCURRENT_CONNECTIONS:
+                api_server_logger.info(
+                    f"Reached max request concurrency: {current_count}/{MAX_CONCURRENT_CONNECTIONS}"
+                )
+                raise HTTPException(
+                    status_code=429,
+                    detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
+                )
+            connection_counter_shm.value[0] = current_count + 1
+
+        tracing.label_span(request)
+        generator = await app.state.completion_handler.create_completion(request)
+
+        # Define release callback
+        def release_connection():
+            with connection_counter_lock:
+                if connection_counter_shm.value[0] > 0:
+                    connection_counter_shm.value[0] -= 1
+
+        if isinstance(generator, ErrorResponse):
+            release_connection()
+            return JSONResponse(content=generator.model_dump(), status_code=500)
+        elif isinstance(generator, CompletionResponse):
+            release_connection()
+            return JSONResponse(content=generator.model_dump())
+        else:
+            # For streaming, release happens when generator completes
+            wrapped_generator = wrap_streaming_generator(generator, release_connection)
+            return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
     except HTTPException as e:
         return JSONResponse(status_code=e.status_code, content={"detail": e.detail})