diff --git a/fastdeploy/entrypoints/api_server.py b/fastdeploy/entrypoints/api_server.py index 4f4d7f225..a8fe7da9d 100644 --- a/fastdeploy/entrypoints/api_server.py +++ b/fastdeploy/entrypoints/api_server.py @@ -15,6 +15,8 @@ """ import json +import os +import signal import traceback import uvicorn @@ -34,6 +36,30 @@ app = FastAPI() llm_engine = None +def cleanup_engine(): + """Clean up engine resources.""" + global llm_engine + if llm_engine is not None: + try: + if hasattr(llm_engine, "worker_proc") and llm_engine.worker_proc is not None: + try: + pgid = os.getpgid(llm_engine.worker_proc.pid) + api_server_logger.info(f"Terminating worker process group {pgid}") + os.killpg(pgid, signal.SIGTERM) + except Exception as e: + api_server_logger.error(f"Error terminating worker process: {e}") + except Exception as e: + api_server_logger.error(f"Error during cleanup: {e}") + + +def signal_handler(signum, frame): + """Handle SIGINT and SIGTERM signals.""" + sig_name = "SIGINT" if signum == signal.SIGINT else "SIGTERM" + api_server_logger.info(f"Received {sig_name}, initiating graceful shutdown...") + cleanup_engine() + # Let uvicorn handle the actual exit + + def init_app(args): """ init LLMEngine """ @@ -96,6 +122,10 @@ def launch_api_server(args) -> None: """ 启动http服务 """ + # Install signal handlers + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + if not is_port_available(args.host, args.port): raise Exception(f"The parameter `port`:{args.port} is already in use.") @@ -116,6 +146,8 @@ def launch_api_server(args) -> None: ) # set log level to error to avoid log except Exception as e: api_server_logger.error(f"launch sync http server error, {e}, {str(traceback.format_exc())}") + finally: + cleanup_engine() def main(): diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index d6e03eb95..35e9c836f 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -17,6 +17,7 @@ import asyncio import json 
import os import signal +import subprocess import threading import time import traceback @@ -83,11 +84,54 @@ chat_template = load_chat_template(args.chat_template, args.model) if args.tool_parser_plugin: ToolParserManager.import_tool_parser(args.tool_parser_plugin) llm_engine = None +shutdown_event = threading.Event() MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS) +def cleanup_processes(): + """Clean up all child processes.""" + global llm_engine + if llm_engine is not None: + try: + if hasattr(llm_engine, "worker_proc") and llm_engine.worker_proc is not None: + try: + pgid = os.getpgid(llm_engine.worker_proc.pid) + api_server_logger.info(f"Terminating worker process group {pgid}") + os.killpg(pgid, signal.SIGTERM) + # Wait for the process to exit; force-kill it on timeout + try: + llm_engine.worker_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + api_server_logger.warning("Worker process did not terminate in time, sending SIGKILL") + os.killpg(pgid, signal.SIGKILL) + llm_engine.worker_proc.wait(timeout=2) + except Exception as e: + api_server_logger.error(f"Error terminating worker process: {e}") + + if hasattr(llm_engine, "cache_manager_processes") and llm_engine.cache_manager_processes: + for proc in llm_engine.cache_manager_processes: + try: + proc.terminate() + proc.join(timeout=2) + if proc.is_alive(): + proc.kill() + except Exception as e: + api_server_logger.error(f"Error terminating cache manager process: {e}") + except Exception as e: + api_server_logger.error(f"Error during cleanup: {e}") + + +def signal_handler(signum, frame): + """Handle SIGINT and SIGTERM signals.""" + sig_name = "SIGINT" if signum == signal.SIGINT else "SIGTERM" + api_server_logger.info(f"Received {sig_name}, initiating graceful shutdown...") + shutdown_event.set() + cleanup_processes() + # Do not exit here; let gunicorn handle it + + class StandaloneApplication(BaseApplication): def __init__(self, app, options=None): self.application = app @@ -98,9 +142,31 @@ 
class StandaloneApplication: config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None} for key, value in config.items(): self.cfg.set(key.lower(), value) - + def load(self): return self.application + + def run(self): + """Override run to add signal handling.""" + # Install signal handlers in the master process + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + super().run() + finally: + # Make sure all resources are cleaned up on exit + cleanup_processes() + + # Gunicorn hooks + def worker_exit(self, server, worker): + """Called when a worker process exits.""" + api_server_logger.info(f"Worker {worker.pid} exiting") + + def on_exit(self, server): + """Called when the Arbiter exits.""" + api_server_logger.info("Gunicorn master process exiting, cleaning up...") + cleanup_processes() def load_engine(): @@ -713,8 +779,8 @@ def launch_worker_monitor(): """ def _monitor(): - global llm_engine - while True: + global llm_engine, shutdown_event + while not shutdown_event.is_set(): if hasattr(llm_engine, "worker_proc") and llm_engine.worker_proc.poll() is not None: console_logger.error( f"Worker process has died in the background (code={llm_engine.worker_proc.returncode}). API server is forced to stop." diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index c3a3b5076..37bd7b44e 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -17,6 +17,8 @@ import argparse import json import os +import signal +import sys import time from typing import Tuple @@ -1032,6 +1034,15 @@ def run_worker_proc() -> None: """ start worker process """ + # Install signal handlers for graceful shutdown + def signal_handler(signum, frame): + sig_name = "SIGINT" if signum == signal.SIGINT else "SIGTERM" + logger.info(f"Worker process received {sig_name}, shutting down gracefully...") + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + # Get args form Engine args = parse_args()