diff --git a/docs/parameters.md b/docs/parameters.md
index 245eec83f..28a66b72c 100644
--- a/docs/parameters.md
+++ b/docs/parameters.md
@@ -8,6 +8,8 @@ When using FastDeploy to deploy models (including offline inference and service
 |:--------------|:----|:-----------|
 | ```port``` | `int` | Only required for service deployment, HTTP service port number, default: 8000 |
 | ```metrics_port``` | `int` | Only required for service deployment, metrics monitoring port number, default: 8001 |
+| ```max_waiting_time``` | `int` | Only required for service deployment, maximum time (in seconds) a request waits to acquire a connection, default: -1 (no wait time limit) |
+| ```max_concurrency``` | `int` | Only required for service deployment, maximum number of concurrent connections accepted by the service, default: 512 |
 | ```engine_worker_queue_port``` | `int` | FastDeploy internal engine communication port, default: 8002 |
 | ```cache_queue_port``` | `int` | FastDeploy internal KVCache process communication port, default: 8003 |
 | ```max_model_len``` | `int` | Default maximum supported context length for inference, default: 2048 |
diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md
index 244d78ab7..177a2d97b 100644
--- a/docs/zh/parameters.md
+++ b/docs/zh/parameters.md
@@ -6,6 +6,8 @@
 |:-----------------------------------|:----------| :----- |
 | ```port``` | `int` | 仅服务化部署需配置,服务HTTP请求端口号,默认8000 |
 | ```metrics_port``` | `int` | 仅服务化部署需配置,服务监控Metrics端口号,默认8001 |
+| ```max_waiting_time``` | `int` | 仅服务化部署需配置,服务请求建立连接最大等待时间,默认-1,表示无等待时间限制 |
+| ```max_concurrency``` | `int` | 仅服务化部署需配置,服务允许的最大并发连接数,默认512 |
 | ```engine_worker_queue_port``` | `int` | FastDeploy内部引擎进程通信端口, 默认8002 |
 | ```cache_queue_port``` | `int` | FastDeploy内部KVCache进程通信端口, 默认8003 |
 | ```max_model_len``` | `int` | 推理默认最大支持上下文长度,默认2048 |
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index d5d051c8c..92d7ac007 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -21,12 +21,13 @@ import numpy as np
 
 from fastdeploy import envs
 from fastdeploy.engine.config import ModelConfig
+from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
 from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import EngineError, api_server_logger
+from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger
 
 
 class EngineClient:
@@ -47,6 +48,7 @@ class EngineClient:
         reasoning_parser=None,
         data_parallel_size=1,
         enable_logprob=False,
+        workers=1,
     ):
         import fastdeploy.model_executor.models  # noqa: F401
 
@@ -77,7 +79,7 @@ class EngineClient:
             suffix=pid,
             create=False,
         )
-
+        self.semaphore = StatefulSemaphore((FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)
         model_weights_status = np.zeros([1], dtype=np.int32)
         self.model_weights_status_signal = IPCSignal(
             name="model_weights_status",
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 0f9467c54..23f4a061a 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -14,15 +14,17 @@
 # limitations under the License.
""" +import asyncio import os import threading import time +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from multiprocessing import current_process import uvicorn import zmq -from fastapi import FastAPI, Request +from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import CONTENT_TYPE_LATEST @@ -49,6 +51,7 @@ from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, ins from fastdeploy.plugins.model_register import load_model_register_plugins from fastdeploy.utils import ( FlexibleArgumentParser, + StatefulSemaphore, api_server_logger, console_logger, is_port_available, @@ -61,6 +64,13 @@ parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the htt parser.add_argument("--workers", default=1, type=int, help="number of workers") parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server") parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server") +parser.add_argument( + "--max-waiting-time", + default=-1, + type=int, + help="max waiting time for connection, if set value -1 means no waiting time limit", +) +parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency") parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() args.model = retrive_model_from_server(args.model, args.revision) @@ -92,6 +102,12 @@ def load_engine(): return engine +app = FastAPI() + +MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers +connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS) + + @asynccontextmanager async def lifespan(app: FastAPI): """ @@ -117,10 +133,11 @@ async def lifespan(app: FastAPI): args.reasoning_parser, args.data_parallel_size, args.enable_logprob, + args.workers, ) app.state.dynamic_load_weight = args.dynamic_load_weight - chat_handler = OpenAIServingChat(engine_client, pid, args.ips) - completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips) + chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time) + completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips, args.max_waiting_time) engine_client.create_zmq_client(model=pid, mode=zmq.PUSH) engine_client.pid = pid app.state.engine_client = engine_client @@ -142,6 +159,21 @@ app = FastAPI(lifespan=lifespan) instrument(app) +@asynccontextmanager +async def connection_manager(): + """ + async context manager for connection manager + """ + try: + await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001) + yield + except asyncio.TimeoutError: + api_server_logger.info(f"Reach max request release: {connection_semaphore.status()}") + if connection_semaphore.locked(): + connection_semaphore.release() + raise HTTPException(status_code=429, detail="Too many requests") + + # TODO 传递真实引擎值 通过pid 获取状态 @app.get("/health") def health(request: Request) -> Response: @@ -195,6 +227,22 @@ def ping(raw_request: Request) -> Response: return health(raw_request) +def wrap_streaming_generator(original_generator: AsyncGenerator): + """ + Wrap an async generator to release the connection semaphore when the generator is finished. 
+ """ + + async def wrapped_generator(): + try: + async for chunk in original_generator: + yield chunk + finally: + api_server_logger.debug(f"release: {connection_semaphore.status()}") + connection_semaphore.release() + + return wrapped_generator + + @app.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest): """ @@ -204,16 +252,23 @@ async def create_chat_completion(request: ChatCompletionRequest): status, msg = app.state.engine_client.is_workers_alive() if not status: return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304) - inject_to_metadata(request) - generator = await app.state.chat_handler.create_chat_completion(request) + try: + async with connection_manager(): + inject_to_metadata(request) + generator = await app.state.chat_handler.create_chat_completion(request) + if isinstance(generator, ErrorResponse): + connection_semaphore.release() + return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code) + elif isinstance(generator, ChatCompletionResponse): + connection_semaphore.release() + return JSONResponse(content=generator.model_dump()) + else: + wrapped_generator = wrap_streaming_generator(generator) + return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream") - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), status_code=generator.code) - - elif isinstance(generator, ChatCompletionResponse): - return JSONResponse(content=generator.model_dump()) - - return StreamingResponse(content=generator, media_type="text/event-stream") + except HTTPException as e: + api_server_logger.error(f"Error in chat completion: {str(e)}") + return JSONResponse(status_code=e.status_code, content={"detail": e.detail}) @app.post("/v1/completions") @@ -225,14 +280,20 @@ async def create_completion(request: CompletionRequest): status, msg = app.state.engine_client.is_workers_alive() if not status: return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304) - - generator = await app.state.completion_handler.create_completion(request) - if isinstance(generator, ErrorResponse): - return JSONResponse(content=generator.model_dump(), status_code=generator.code) - elif isinstance(generator, CompletionResponse): - return JSONResponse(content=generator.model_dump()) - - return StreamingResponse(content=generator, media_type="text/event-stream") + try: + async with connection_manager(): + generator = await app.state.completion_handler.create_completion(request) + if isinstance(generator, ErrorResponse): + connection_semaphore.release() + return JSONResponse(content=generator.model_dump(), status_code=generator.code) + elif isinstance(generator, CompletionResponse): + connection_semaphore.release() + return JSONResponse(content=generator.model_dump()) + else: + wrapped_generator = wrap_streaming_generator(generator) + return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream") + except HTTPException as e: + return JSONResponse(status_code=e.status_code, content={"detail": e.detail}) @app.get("/update_model_weight") diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 3e74c89df..8d8d4b98d 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -49,10 +49,11 @@ class OpenAIServingChat: OpenAI-style chat completions serving """ - def __init__(self, engine_client, pid, ips): + def 
+    def __init__(self, engine_client, pid, ips, max_waiting_time):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
+        self.max_waiting_time = max_waiting_time
         self.host_ip = get_host_ip()
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
@@ -93,6 +94,14 @@ class OpenAIServingChat:
                 return ErrorResponse(code=400, message=str(e))
         del current_req_dict
 
+        try:
+            api_server_logger.debug(f"{self.engine_client.semaphore.status()}")
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request waiting time exceeded {self.max_waiting_time} seconds")
 
         if request.stream:
             return self.chat_completion_stream_generator(request, request_id, request.model, prompt_token_ids)
@@ -310,6 +319,8 @@ class OpenAIServingChat:
                 yield f"data: {error_data}\n\n"
         finally:
             dealer.close()
+            self.engine_client.semaphore.release()
+            api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
             yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
@@ -383,6 +394,7 @@ class OpenAIServingChat:
                     if task_is_finished:
                         break
         finally:
+            self.engine_client.semaphore.release()
             dealer.close()
 
         choices = []
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index 2a0e5e9ec..43766f27f 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -40,11 +40,12 @@ from fastdeploy.worker.output import LogprobsLists
 
 
 class OpenAIServingCompletion:
-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, ips, max_waiting_time):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
         self.host_ip = get_host_ip()
+        self.max_waiting_time = max_waiting_time
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
                 self.master_ip = self.master_ip[0]
@@ -114,6 +115,14 @@ class OpenAIServingCompletion:
 
         del current_req_dict
 
+        try:
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request waiting time exceeded {self.max_waiting_time} seconds")
+
         if request.stream:
             return self.completion_stream_generator(
                 request=request,
@@ -221,6 +230,7 @@ class OpenAIServingCompletion:
             api_server_logger.error(f"Error in completion_full_generator: {e}", exc_info=True)
             raise
         finally:
+            self.engine_client.semaphore.release()
             if dealer is not None:
                 dealer.close()
 
@@ -371,6 +381,7 @@ class OpenAIServingCompletion:
             yield f"data: {ErrorResponse(message=str(e), code=400).model_dump_json(exclude_unset=True)}\n\n"
         finally:
             del request
+            self.engine_client.semaphore.release()
             if dealer is not None:
                 dealer.close()
             yield "data: [DONE]\n\n"
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 6d17ac005..d9f5beb9c 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -84,6 +84,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","),
     # set trace attribute job_id.
"FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"), + # support max connections + "FD_SUPPORT_MAX_CONNECTIONS": lambda: 768, } diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 0eed87c55..27783b428 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -15,6 +15,7 @@ """ import argparse +import asyncio import codecs import importlib import logging @@ -304,6 +305,16 @@ def set_random_seed(seed: int) -> None: paddle.seed(seed) +def get_limited_max_value(max_value): + def validator(value): + value = float(value) + if value > max_value: + raise argparse.ArgumentTypeError(f"The value cannot exceed {max_value}") + return value + + return validator + + def download_model(url, output_dir, temp_tar): """ 下载模型,并将其解压到指定目录。 @@ -653,6 +664,61 @@ def deprecated_kwargs_warning(**kwargs): console_logger.warning(f"Deprecated argument is detected: {arg}, which may be removed later") +class StatefulSemaphore: + __slots__ = ("_semaphore", "_max_value", "_acquired_count", "_last_reset") + + """ + StatefulSemaphore is a class that wraps an asyncio.Semaphore and provides additional stateful information. + """ + + def __init__(self, value: int): + """ + StatefulSemaphore constructor + """ + if value < 0: + raise ValueError("Value must be non-negative.") + self._semaphore = asyncio.Semaphore(value) + self._max_value = value + self._acquired_count = 0 + self._last_reset = time.monotonic() + + async def acquire(self): + await self._semaphore.acquire() + self._acquired_count += 1 + + def release(self): + self._semaphore.release() + + self._acquired_count = max(0, self._acquired_count - 1) + + def locked(self) -> bool: + return self._semaphore.locked() + + @property + def available(self) -> int: + return self._max_value - self._acquired_count + + @property + def acquired(self) -> int: + return self._acquired_count + + @property + def max_value(self) -> int: + return self._max_value + + @property + def uptime(self) -> float: + return time.monotonic() - self._last_reset + + def status(self) -> dict: + return { + "available": self.available, + "acquired": self.acquired, + "max_value": self.max_value, + "uptime": round(self.uptime, 2), + } + + llm_logger = get_logger("fastdeploy", "fastdeploy.log") data_processor_logger = get_logger("data_processor", "data_processor.log") scheduler_logger = get_logger("scheduler", "scheduler.log")