Files
FastDeploy/fastdeploy/entrypoints/openai/api_server.py
2025-08-14 21:08:49 +08:00

473 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import asyncio
import os
import threading
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from multiprocessing import current_process
import uvicorn
import zmq
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import CONTENT_TYPE_LATEST
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
ControlSchedulerRequest,
ErrorResponse,
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.metrics.metrics import (
EXCLUDE_LABELS,
cleanup_prometheus_files,
get_filtered_metrics,
main_process_metrics,
)
from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
from fastdeploy.utils import (
FlexibleArgumentParser,
StatefulSemaphore,
api_server_logger,
console_logger,
is_port_available,
retrive_model_from_server,
)
def _build_server_parser():
    """Create the argument parser pre-populated with the HTTP-server options.

    Options are registered in the listed order, which is also the order in
    which they were originally declared.
    """
    # (flag, default, type, help)
    server_options = (
        ("--port", 8000, int, "port to the http server"),
        ("--host", "0.0.0.0", str, "host to the http server"),
        ("--workers", 1, int, "number of workers"),
        ("--metrics-port", 8001, int, "port for metrics server"),
        ("--controller-port", -1, int, "port for controller server"),
        (
            "--max-waiting-time",
            -1,
            int,
            "max waiting time for connection, if set value -1 means no waiting time limit",
        ),
        ("--max-concurrency", 512, int, "max concurrency"),
    )
    cli = FlexibleArgumentParser()
    for flag, default, value_type, description in server_options:
        cli.add_argument(flag, default=default, type=value_type, help=description)
    return cli


# Engine options are appended after the server options, then the command line
# is parsed at import time.
parser = EngineArgs.add_cli_args(_build_server_parser())
args = parser.parse_args()

# Replace the model argument with whatever retrive_model_from_server returns
# for the given (model, revision) pair.
args.model = retrive_model_from_server(args.model, args.revision)

# Optional user-supplied tool-parser plugin.
if args.tool_parser_plugin:
    ToolParserManager.import_tool_parser(args.tool_parser_plugin)

# Engine instance cached for this process; assigned by load_engine().
llm_engine = None
def load_engine():
    """
    Create and start the LLM engine, caching it in the module-level ``llm_engine``.

    A previously loaded engine is returned as-is without being started again.

    Returns:
        The started engine, or ``None`` when ``engine.start`` reports failure
        (in which case nothing is cached).
    """
    global llm_engine
    if llm_engine is None:
        api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}")
        candidate = LLMEngine.from_engine_args(EngineArgs.from_cli_args(args))
        started = candidate.start(api_server_pid=os.getpid())
        if not started:
            api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!")
            return None
        api_server_logger.info("FastDeploy LLM engine initialized!\n")
        # Announce the service endpoints on the console.
        for banner in (
            f"Launching metrics service at http://{args.host}:{args.metrics_port}/metrics",
            f"Launching chat completion service at http://{args.host}:{args.port}/v1/chat/completions",
            f"Launching completion service at http://{args.host}:{args.port}/v1/completions",
        ):
            console_logger.info(banner)
        llm_engine = candidate
    return llm_engine
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan hook.

    Startup: builds the ``EngineClient`` and the chat/completion handlers for this
    process and stores them on ``app.state`` (``engine_client``, ``chat_handler``,
    ``completion_handler``, ``dynamic_load_weight``).

    Shutdown: closes the ZMQ client and marks this process dead for the
    prometheus multiprocess collector. Both steps are best-effort (failures are
    logged as warnings) and run even when an exception or cancellation is
    delivered at the ``yield``.

    Args:
        app: The FastAPI application whose ``state`` is populated.
    """
    if args.tokenizer is None:
        args.tokenizer = args.model
    # Outside the main process the parent's pid is used instead of our own;
    # presumably so workers share the pid the engine was started with — confirm.
    if current_process().name != "MainProcess":
        pid = os.getppid()
    else:
        pid = os.getpid()
    api_server_logger.info(f"{pid}")
    engine_client = EngineClient(
        args.tokenizer,
        args.max_model_len,
        args.tensor_parallel_size,
        pid,
        args.limit_mm_per_prompt,
        args.mm_processor_kwargs,
        args.enable_mm,
        args.reasoning_parser,
        args.data_parallel_size,
        args.enable_logprob,
        args.workers,
        args.tool_call_parser,
    )
    app.state.dynamic_load_weight = args.dynamic_load_weight
    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips, args.max_waiting_time)
    engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
    engine_client.pid = pid
    app.state.engine_client = engine_client
    app.state.chat_handler = chat_handler
    app.state.completion_handler = completion_handler
    try:
        yield
    finally:
        # close zmq; kept separate so a failure here does not skip the metrics cleanup
        try:
            engine_client.zmq_client.close()
        except Exception as e:
            api_server_logger.warning(e)
        try:
            from prometheus_client import multiprocess

            multiprocess.mark_process_dead(os.getpid())
            api_server_logger.info(f"Closing metrics client pid: {pid}")
        except Exception as e:
            api_server_logger.warning(e)
# Main application serving the OpenAI-compatible endpoints; per-process state
# is set up and torn down by `lifespan`.
app = FastAPI(lifespan=lifespan)
# Tracing hook from fastdeploy.metrics.trace_util (presumably attaches request
# instrumentation to the app — confirm against trace_util).
instrument(app)
# Per-worker share of --max-concurrency, rounded up: ceil(max_concurrency / workers).
MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers
# Limits the number of in-flight requests handled by this process.
connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)
@asynccontextmanager
async def connection_manager():
    """
    Async context manager that admits a request only if a connection slot is free.

    A slot is taken from ``connection_semaphore`` on entry, waiting at most 1 ms.
    The slot is NOT released on exit: callers release it explicitly, or hand it
    to ``wrap_streaming_generator`` which releases it when streaming ends.

    Raises:
        HTTPException: 429 when no slot becomes available within the timeout.
    """
    # Only the acquisition is guarded: a TimeoutError raised by the request body
    # must propagate as-is instead of being misreported as "Too many requests".
    try:
        await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
    except asyncio.TimeoutError:
        api_server_logger.info(f"Reach max request release: {connection_semaphore.status()}")
        raise HTTPException(
            status_code=429, detail=f"Too many requests, current max concurrency is {args.max_concurrency}"
        )
    yield
def wrap_streaming_generator(original_generator: AsyncGenerator):
    """
    Return an async-generator function that relays ``original_generator`` and
    gives the connection slot back once the relay finishes.

    Args:
        original_generator: The async generator whose chunks are forwarded unchanged.

    Returns:
        A zero-argument async generator function. Calling it yields every chunk
        of ``original_generator``; when iteration ends (normally or not) it logs
        the semaphore status and releases ``connection_semaphore``.
    """

    async def _relay():
        try:
            async for piece in original_generator:
                yield piece
        finally:
            # Status is logged before the release, so it reflects the slot still being held.
            api_server_logger.debug(f"current concurrency status: {connection_semaphore.status()}")
            connection_semaphore.release()

    return _relay
# TODO: pass the real engine value and obtain the status via pid
@app.get("/health")
def health(request: Request) -> Response:
    """
    Health check.

    Returns:
        200 with an empty body when both the engine health check and the worker
        liveness check pass; 404 with the engine's message when ``check_health``
        fails; 304 with the workers' message when ``is_workers_alive`` fails.
    """
    client = app.state.engine_client
    engine_ok, engine_msg = client.check_health()
    if not engine_ok:
        return Response(content=engine_msg, status_code=404)
    workers_ok, workers_msg = client.is_workers_alive()
    if workers_ok:
        return Response(status_code=200)
    # NOTE(review): 304 (Not Modified) is defined as a body-less status; sending
    # content with it may be rejected by some HTTP stacks — confirm this is intended.
    return Response(content=workers_msg, status_code=304)
@app.get("/load")
async def list_all_routes():
    """
    List every registered route whose path starts with ``/v1``.

    Takes no arguments.

    Returns:
        dict: All matching routes, in the form::

            {
                "routes": [
                    {
                        "path": str,      # route path
                        "methods": list,  # supported HTTP methods, sorted
                        "tags": list      # route tags, empty list by default
                    },
                    ...
                ]
            }
    """
    routes_info = []
    for route in app.routes:
        # Not every route type exposes `path`; such routes are skipped.
        path = getattr(route, "path", "")
        # Check directly whether the path starts with /v1
        if not path.startswith("/v1"):
            continue
        # Some route types carry no HTTP methods (attribute absent or None);
        # report them with an empty method list instead of failing the request.
        methods = sorted(getattr(route, "methods", None) or ())
        tags = getattr(route, "tags", []) or []
        routes_info.append({"path": path, "methods": methods, "tags": tags})
    return {"routes": routes_info}
@app.api_route("/ping", methods=["GET", "POST"])
def ping(raw_request: Request) -> Response:
    """
    Ping check. Endpoint required for SageMaker.

    Accepts GET and POST and answers with exactly what ``health`` returns for
    the same request.
    """
    health_response = health(raw_request)
    return health_response
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """
    Create a chat completion for the provided prompt and parameters.

    Args:
        request: The parsed chat completion request body.

    Returns:
        - ``JSONResponse`` with status 304 when dynamic weight loading is enabled
          and the workers are reported as not alive.
        - ``JSONResponse`` ``{"detail": <error>}`` with the error's ``code`` when
          the handler returns an ``ErrorResponse``.
        - ``JSONResponse`` with the dumped ``ChatCompletionResponse``.
        - ``StreamingResponse`` (``text/event-stream``) for any other handler result.
        - ``JSONResponse`` built from an ``HTTPException`` (e.g. the 429 raised
          by ``connection_manager`` when no connection slot is free).
    """
    api_server_logger.info(f"Chat Received request: {request.model_dump_json()}")
    if app.state.dynamic_load_weight:
        status, _ = app.state.engine_client.is_workers_alive()
        if not status:
            # NOTE(review): 304 is a body-less status by definition; confirm it is intended here.
            return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
    try:
        async with connection_manager():
            try:
                inject_to_metadata(request)
                generator = await app.state.chat_handler.create_chat_completion(request)
            except BaseException:
                # The slot is held, but none of the release paths below will run;
                # give it back (also on cancellation) so capacity is not lost for good.
                connection_semaphore.release()
                raise
            if isinstance(generator, ErrorResponse):
                connection_semaphore.release()
                api_server_logger.debug(f"current concurrency status: {connection_semaphore.status()}")
                return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
            elif isinstance(generator, ChatCompletionResponse):
                connection_semaphore.release()
                api_server_logger.debug(f"current concurrency status: {connection_semaphore.status()}")
                return JSONResponse(content=generator.model_dump())
            else:
                # Streaming: the wrapper releases the slot once the stream is finished.
                wrapped_generator = wrap_streaming_generator(generator)
                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
    except HTTPException as e:
        api_server_logger.error(f"Error in chat completion: {str(e)}")
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    """
    Create a completion for the provided prompt and parameters.

    Args:
        request: The parsed completion request body.

    Returns:
        - ``JSONResponse`` with status 304 when dynamic weight loading is enabled
          and the workers are reported as not alive.
        - ``JSONResponse`` with the dumped ``ErrorResponse`` and its ``code`` when
          the handler returns an error.
        - ``JSONResponse`` with the dumped ``CompletionResponse``.
        - ``StreamingResponse`` (``text/event-stream``) for any other handler result.
        - ``JSONResponse`` built from an ``HTTPException`` (e.g. the 429 raised
          by ``connection_manager`` when no connection slot is free).
    """
    api_server_logger.info(f"Completion Received request: {request.model_dump_json()}")
    if app.state.dynamic_load_weight:
        status, _ = app.state.engine_client.is_workers_alive()
        if not status:
            # NOTE(review): 304 is a body-less status by definition; confirm it is intended here.
            return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
    try:
        async with connection_manager():
            try:
                generator = await app.state.completion_handler.create_completion(request)
            except BaseException:
                # The slot is held, but none of the release paths below will run;
                # give it back (also on cancellation) so capacity is not lost for good.
                connection_semaphore.release()
                raise
            if isinstance(generator, ErrorResponse):
                connection_semaphore.release()
                return JSONResponse(content=generator.model_dump(), status_code=generator.code)
            elif isinstance(generator, CompletionResponse):
                connection_semaphore.release()
                return JSONResponse(content=generator.model_dump())
            else:
                # Streaming: the wrapper releases the slot once the stream is finished.
                wrapped_generator = wrap_streaming_generator(generator)
                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
    except HTTPException as e:
        # Logged for parity with the chat completion endpoint.
        api_server_logger.error(f"Error in completion: {str(e)}")
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
@app.get("/update_model_weight")
def update_model_weight(request: Request) -> Response:
    """
    Ask the engine client to update the model weights.

    Returns:
        200 on success; 404 with the engine client's message when the update
        fails; 404 with a fixed message when dynamic weight loading is disabled.
    """
    if not app.state.dynamic_load_weight:
        return Response(content="Dynamic Load Weight Disabled.", status_code=404)
    ok, detail = app.state.engine_client.update_model_weight()
    if ok:
        return Response(status_code=200)
    return Response(content=detail, status_code=404)
@app.get("/clear_load_weight")
def clear_load_weight(request: Request) -> Response:
    """
    Ask the engine client to clear the loaded model weights.

    Returns:
        200 on success; 404 with the engine client's message when clearing
        fails; 404 with a fixed message when dynamic weight loading is disabled.
    """
    if not app.state.dynamic_load_weight:
        return Response(content="Dynamic Load Weight Disabled.", status_code=404)
    ok, detail = app.state.engine_client.clear_load_weight()
    if ok:
        return Response(status_code=200)
    return Response(content=detail, status_code=404)
def launch_api_server() -> None:
    """
    Start the main HTTP service.

    The application is handed to uvicorn as an import string together with the
    configured host, port and worker count. Errors raised by ``uvicorn.run`` are
    logged rather than propagated.

    Raises:
        Exception: If ``args.port`` is not available on ``args.host``.
    """
    port_free = is_port_available(args.host, args.port)
    if not port_free:
        raise Exception(f"The parameter `port`:{args.port} is already in use.")

    api_server_logger.info(f"launch Fastdeploy api server... port: {args.port}")
    api_server_logger.info(f"args: {args.__dict__}")
    fd_start_span("FD_START")

    try:
        server_options = {
            "app": "fastdeploy.entrypoints.openai.api_server:app",
            "host": args.host,
            "port": args.port,
            "workers": args.workers,
            "log_level": "info",  # uvicorn's own logging runs at info level
        }
        uvicorn.run(**server_options)
    except Exception as err:
        api_server_logger.error(f"launch sync http server error, {err}")
# Separate application that only serves the metrics endpoint.
metrics_app = FastAPI()


@metrics_app.get("/metrics")
async def metrics():
    """
    Serve the filtered metrics text.

    The main-process metrics are registered on the registry handed to the
    callback (with the configured worker count) before the text is produced.

    Returns:
        Response: The metrics text with the prometheus ``CONTENT_TYPE_LATEST`` media type.
    """

    def register_main_process_metrics(registry):
        return main_process_metrics.register_all(registry, workers=args.workers)

    payload = get_filtered_metrics(EXCLUDE_LABELS, extra_register_func=register_main_process_metrics)
    return Response(payload, media_type=CONTENT_TYPE_LATEST)
def run_metrics_server():
    """
    Run the metrics application with uvicorn on ``args.metrics_port`` (blocking).
    """
    # NOTE(review): binds to all interfaces, while launch_metrics_server checks
    # port availability on args.host — confirm the mismatch is intended.
    uvicorn.run(app=metrics_app, host="0.0.0.0", port=args.metrics_port, log_level="error")
def launch_metrics_server():
    """
    Start the metrics server in a daemon thread.

    Sets ``PROMETHEUS_MULTIPROC_DIR`` to the directory returned by
    ``cleanup_prometheus_files`` before the thread is started, then sleeps for
    one second.

    Raises:
        Exception: If ``args.metrics_port`` is not available on ``args.host``.
    """
    port_free = is_port_available(args.host, args.metrics_port)
    if not port_free:
        raise Exception(f"The parameter `metrics_port`:{args.metrics_port} is already in use.")

    os.environ["PROMETHEUS_MULTIPROC_DIR"] = cleanup_prometheus_files(True)
    threading.Thread(target=run_metrics_server, daemon=True).start()
    # Fixed pause after starting the server thread.
    time.sleep(1)
# Separate application exposing the scheduler control endpoints.
controller_app = FastAPI()


@controller_app.post("/controller/reset_scheduler")
def reset_scheduler():
    """
    Reset the scheduler of the engine loaded in this process.

    Returns:
        Response: 200 after ``scheduler.reset()`` has been called, or 500 when
        no engine is loaded.
    """
    engine = llm_engine
    if engine is not None:
        engine.scheduler.reset()
        return Response("Scheduler Reset Successfully", status_code=200)
    return Response("Engine not loaded", status_code=500)
@controller_app.post("/controller/scheduler")
def control_scheduler(request: ControlSchedulerRequest):
    """
    Control the scheduler behavior with the given parameters.

    Args:
        request: May ask for a scheduler reset (``reset``) and/or a configuration
            update (``load_shards_num`` / ``reallocate_shard``).

    Returns:
        JSONResponse: The dumped ``ErrorResponse`` payload with status 200 on
        success, 500 when no engine is loaded, or 400 when an update is
        requested but the scheduler has no callable ``update_config``.
    """
    content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)

    def reply(http_status):
        # Serialise the (possibly amended) payload with the given HTTP status.
        return JSONResponse(content=content.model_dump(), status_code=http_status)

    if llm_engine is None:
        content.message = "Engine is not loaded"
        content.code = 500
        return reply(500)

    if request.reset:
        llm_engine.scheduler.reset()

    wants_update = request.load_shards_num or request.reallocate_shard
    if wants_update:
        update_config = getattr(llm_engine.scheduler, "update_config", None)
        if not callable(update_config):
            content.message = "This scheduler doesn't support the `update_config()` method."
            content.code = 400
            return reply(400)
        update_config(
            load_shards_num=request.load_shards_num,
            reallocate=request.reallocate_shard,
        )

    return reply(200)
def run_controller_server():
    """
    Run the controller application with uvicorn on ``args.controller_port`` (blocking).
    """
    # NOTE(review): binds to all interfaces, while launch_controller_server checks
    # port availability on args.host — confirm the mismatch is intended.
    uvicorn.run(app=controller_app, host="0.0.0.0", port=args.controller_port, log_level="error")
def launch_controller_server():
    """
    Start the controller server in a daemon thread.

    Does nothing when ``args.controller_port`` is negative (the default is -1).
    Otherwise the server thread is started and the function sleeps for one second.

    Raises:
        Exception: If ``args.controller_port`` is not available on ``args.host``.
    """
    if args.controller_port < 0:
        return

    port_free = is_port_available(args.host, args.controller_port)
    if not port_free:
        raise Exception(f"The parameter `controller_port`:{args.controller_port} is already in use.")

    threading.Thread(target=run_controller_server, daemon=True).start()
    # Fixed pause after starting the server thread.
    time.sleep(1)
def main():
    """
    Entry point: load the engine, then start the controller, metrics and API servers.

    Returns without starting any server when the engine fails to load.
    """
    engine = load_engine()
    if engine is not None:
        launch_controller_server()
        launch_metrics_server()
        launch_api_server()


if __name__ == "__main__":
    main()