mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Add signal handlers for graceful process termination
- Added SIGINT/SIGTERM signal handlers in api_server.py (both OpenAI and simple versions)
- Added cleanup_processes() function to properly terminate worker processes
- Enhanced StandaloneApplication with worker exit hooks and cleanup
- Added signal handling in worker_process.py for graceful worker shutdown
- Added shutdown_event to coordinate graceful shutdown across threads
- Improved worker monitor to respect shutdown event

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import traceback
|
||||
|
||||
import uvicorn
|
||||
@@ -34,6 +36,30 @@ app = FastAPI()
|
||||
llm_engine = None
|
||||
|
||||
|
||||
def cleanup_engine():
    """Terminate the engine's worker process group, if it is still running.

    Best-effort and safe to call more than once (it is invoked from both the
    signal handler and the ``finally`` block of ``launch_api_server``):
    failures are logged and never raised.

    The worker's process group receives SIGTERM; if the worker has not exited
    after a short grace period the group receives SIGKILL.
    """
    import subprocess  # function-scope import: only needed for TimeoutExpired

    term_timeout = 5  # seconds to wait after SIGTERM
    kill_timeout = 2  # seconds to wait after SIGKILL

    engine = llm_engine
    if engine is None:
        return

    try:
        worker = getattr(engine, "worker_proc", None)
        if worker is None:
            return
        # A set returncode means the worker was already reaped (for example by
        # an earlier cleanup call); its pid may since have been reused, so it
        # must not be signalled again.
        # NOTE(review): assumes worker_proc is a subprocess.Popen - confirm.
        if getattr(worker, "returncode", None) is not None:
            return

        try:
            pgid = os.getpgid(worker.pid)
            # Signalling our own group would also hit this server and
            # re-enter its signal handler, so only target the worker then.
            shares_our_group = pgid == os.getpgrp()
            if shares_our_group:
                api_server_logger.info(f"Terminating worker process {worker.pid}")
                worker.terminate()
            else:
                api_server_logger.info(f"Terminating worker process group {pgid}")
                os.killpg(pgid, signal.SIGTERM)

            # Wait so the worker is reaped; escalate if it ignores SIGTERM.
            try:
                worker.wait(timeout=term_timeout)
            except subprocess.TimeoutExpired:
                api_server_logger.warning("Worker process did not terminate in time, sending SIGKILL")
                if shares_our_group:
                    worker.kill()
                else:
                    os.killpg(pgid, signal.SIGKILL)
                worker.wait(timeout=kill_timeout)
        except ProcessLookupError:
            # The worker exited between the check above and the signal.
            return
        except Exception as e:
            api_server_logger.error(f"Error terminating worker process: {e}")
    except Exception as e:
        api_server_logger.error(f"Error during cleanup: {e}")
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
    """Handle SIGINT and SIGTERM: log the signal and release engine resources.

    Args:
        signum: Number of the signal that was delivered.
        frame: Stack frame that was interrupted (unused).
    """
    if signum == signal.SIGINT:
        received = "SIGINT"
    else:
        received = "SIGTERM"
    api_server_logger.info(f"Received {received}, initiating graceful shutdown...")
    cleanup_engine()
    # No exit here on purpose: the actual shutdown is left to uvicorn.
|
||||
|
||||
|
||||
def init_app(args):
|
||||
"""
|
||||
init LLMEngine
|
||||
@@ -96,6 +122,10 @@ def launch_api_server(args) -> None:
|
||||
"""
|
||||
启动http服务
|
||||
"""
|
||||
# 设置信号处理器
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
if not is_port_available(args.host, args.port):
|
||||
raise Exception(f"The parameter `port`:{args.port} is already in use.")
|
||||
|
||||
@@ -116,6 +146,8 @@ def launch_api_server(args) -> None:
|
||||
) # set log level to error to avoid log
|
||||
except Exception as e:
|
||||
api_server_logger.error(f"launch sync http server error, {e}, {str(traceback.format_exc())}")
|
||||
finally:
|
||||
cleanup_engine()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -17,6 +17,7 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
@@ -83,11 +84,54 @@ chat_template = load_chat_template(args.chat_template, args.model)
|
||||
if args.tool_parser_plugin:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
llm_engine = None
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers
|
||||
connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)
|
||||
|
||||
|
||||
def _terminate_worker_group(worker, term_timeout=5, kill_timeout=2):
    """Stop the worker's process group: SIGTERM first, SIGKILL if it lingers."""
    # A set returncode means the worker was already reaped (for example by an
    # earlier cleanup call); its pid may since have been reused, so it must
    # not be signalled again.
    if worker.returncode is not None:
        return

    try:
        pgid = os.getpgid(worker.pid)
        # Signalling our own group would also hit this server and re-enter its
        # signal handler, so only target the worker in that case.
        shares_our_group = pgid == os.getpgrp()
        if shares_our_group:
            api_server_logger.info(f"Terminating worker process {worker.pid}")
            worker.terminate()
        else:
            api_server_logger.info(f"Terminating worker process group {pgid}")
            os.killpg(pgid, signal.SIGTERM)

        # Wait for the worker to exit; force-kill it when the wait times out.
        try:
            worker.wait(timeout=term_timeout)
        except subprocess.TimeoutExpired:
            api_server_logger.warning("Worker process did not terminate in time, sending SIGKILL")
            if shares_our_group:
                worker.kill()
            else:
                os.killpg(pgid, signal.SIGKILL)
            worker.wait(timeout=kill_timeout)
    except ProcessLookupError:
        # The worker exited between the check above and the signal.
        return


def _terminate_cache_manager(proc, grace_timeout=2):
    """Terminate one cache manager process, killing it if it outlives the grace period."""
    proc.terminate()
    if hasattr(proc, "join"):
        # multiprocessing.Process-style handle.
        proc.join(timeout=grace_timeout)
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=grace_timeout)  # reap the killed process
    else:
        # NOTE(review): SOURCE does not show how these processes are created;
        # this branch covers subprocess.Popen-style handles, which have no
        # join()/is_alive() - confirm against the engine.
        try:
            proc.wait(timeout=grace_timeout)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait(timeout=grace_timeout)


def cleanup_processes():
    """Clean up all child processes owned by the engine.

    Terminates the worker process group and every cache manager process.
    Best-effort and safe to call more than once (it is invoked from the signal
    handler, from ``StandaloneApplication.run`` and from ``on_exit``):
    failures are logged and never raised, and a failure on one process does
    not prevent the remaining ones from being cleaned up.
    """
    engine = llm_engine
    if engine is None:
        return

    try:
        worker = getattr(engine, "worker_proc", None)
        if worker is not None:
            try:
                _terminate_worker_group(worker)
            except Exception as e:
                api_server_logger.error(f"Error terminating worker process: {e}")

        for proc in getattr(engine, "cache_manager_processes", None) or ():
            try:
                _terminate_cache_manager(proc)
            except Exception as e:
                api_server_logger.error(f"Error terminating cache manager process: {e}")
    except Exception as e:
        api_server_logger.error(f"Error during cleanup: {e}")
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
    """Handle SIGINT and SIGTERM: flag the shutdown and clean up child processes.

    Args:
        signum: Number of the signal that was delivered.
        frame: Stack frame that was interrupted (unused).
    """
    if signum == signal.SIGINT:
        received = "SIGINT"
    else:
        received = "SIGTERM"
    api_server_logger.info(f"Received {received}, initiating graceful shutdown...")
    # Tell threads watching the event to stop before the processes go away.
    shutdown_event.set()
    cleanup_processes()
    # No exit here on purpose: the actual shutdown is left to gunicorn.
|
||||
|
||||
|
||||
class StandaloneApplication(BaseApplication):
|
||||
def __init__(self, app, options=None):
|
||||
self.application = app
|
||||
@@ -98,9 +142,31 @@ class StandaloneApplication(BaseApplication):
|
||||
config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None}
|
||||
for key, value in config.items():
|
||||
self.cfg.set(key.lower(), value)
|
||||
|
||||
|
||||
def load(self):
    """Return the application object stored on this instance."""
    wrapped_app = self.application
    return wrapped_app
|
||||
|
||||
def run(self):
    """Run the application with signal handlers, exit hooks and final cleanup.

    Overrides the base ``run`` to install the SIGINT/SIGTERM handlers in the
    master process, register this class's ``worker_exit`` / ``on_exit`` hooks
    with gunicorn, and make sure child processes are cleaned up when the
    master process leaves ``run``.
    """
    # Install the signal handlers in the master process.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # gunicorn looks server hooks up in its config settings, not as methods on
    # the application object, so the hooks defined on this class must be
    # registered explicitly. Hooks supplied by the caller through `options`
    # take precedence and are left alone.
    supplied = getattr(self, "options", None) or {}
    for hook_name in ("worker_exit", "on_exit"):
        if supplied.get(hook_name) is None:
            self.cfg.set(hook_name, getattr(self, hook_name))

    # Worker processes are forked inside super().run() and unwind through this
    # frame when they exit. Remember the owning pid so that an exiting worker
    # does not tear down the engine processes shared by the whole server.
    owner_pid = os.getpid()
    try:
        super().run()
    finally:
        # Make sure all resources are released on exit (master process only).
        if os.getpid() == owner_pid:
            cleanup_processes()
|
||||
|
||||
# Gunicorn hooks
|
||||
def worker_exit(self, server, worker):
    """Hook meant to be called when a worker process exits; logs the worker's pid.

    Args:
        server: The server object passed to the hook (unused).
        worker: The exiting worker; only its ``pid`` is read.
    """
    exiting_pid = worker.pid
    api_server_logger.info(f"Worker {exiting_pid} exiting")
|
||||
|
||||
def on_exit(self, server):
    """Hook meant to be called when the Arbiter exits; cleans up child processes.

    Args:
        server: The server object passed to the hook (unused).
    """
    exit_notice = "Gunicorn master process exiting, cleaning up..."
    api_server_logger.info(exit_notice)
    cleanup_processes()
|
||||
|
||||
|
||||
def load_engine():
|
||||
@@ -713,8 +779,8 @@ def launch_worker_monitor():
|
||||
"""
|
||||
|
||||
def _monitor():
|
||||
global llm_engine
|
||||
while True:
|
||||
global llm_engine, shutdown_event
|
||||
while not shutdown_event.is_set():
|
||||
if hasattr(llm_engine, "worker_proc") and llm_engine.worker_proc.poll() is not None:
|
||||
console_logger.error(
|
||||
f"Worker process has died in the background (code={llm_engine.worker_proc.returncode}). API server is forced to stop."
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
@@ -1032,6 +1034,15 @@ def run_worker_proc() -> None:
|
||||
"""
|
||||
start worker process
|
||||
"""
|
||||
# 设置信号处理器以优雅退出
|
||||
def signal_handler(signum, frame):
    """Log the received SIGINT/SIGTERM and exit the worker process with status 0.

    Args:
        signum: Number of the signal that was delivered.
        frame: Stack frame that was interrupted (unused).
    """
    if signum == signal.SIGINT:
        received = "SIGINT"
    else:
        received = "SIGTERM"
    logger.info(f"Worker process received {received}, shutting down gracefully...")
    sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Get args from Engine
|
||||
args = parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user