[BugFix] fix too many open files problem (#3275)

Author: ltd0924 (committed by GitHub)
Date: 2025-08-08 20:11:32 +08:00
Parent: 1b6f482c15
Commit: 6706ccb37e
6 changed files with 177 additions and 22 deletions
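
Summary: the API server was leaking file descriptors because every incoming request held its socket open regardless of capacity. This change adds admission control at two levels. At the HTTP layer, each worker gets a StatefulSemaphore sized to --max-concurrency split evenly across workers (ceil division); a request that cannot take a slot within 1 ms is rejected immediately with HTTP 429. At the engine layer, EngineClient gets a semaphore sized from FD_SUPPORT_MAX_CONNECTIONS; admitted requests wait up to --max-waiting-time seconds (default -1, wait indefinitely) and get code 408 on timeout. Both semaphores are released in finally blocks, including a wrapped streaming generator, so slots are reclaimed even when a stream aborts.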

View File

@@ -24,7 +24,7 @@ from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import EngineError, api_server_logger
+from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger


 class EngineClient:
@@ -44,6 +44,7 @@ class EngineClient:
         reasoning_parser=None,
         data_parallel_size=1,
         enable_logprob=False,
+        workers=1,
     ):
         input_processor = InputPreprocessor(
             tokenizer,
@@ -76,6 +77,7 @@ class EngineClient:
             suffix=pid,
             create=False,
         )
+        self.semaphore = StatefulSemaphore((envs.FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)

     def create_zmq_client(self, model, mode):
         """

View File

@@ -14,15 +14,17 @@
 # limitations under the License.
 """

+import asyncio
 import os
 import threading
 import time
+from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
 from multiprocessing import current_process

 import uvicorn
 import zmq
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import CONTENT_TYPE_LATEST
@@ -48,6 +50,7 @@ from fastdeploy.metrics.metrics import (
 from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
 from fastdeploy.utils import (
     FlexibleArgumentParser,
+    StatefulSemaphore,
     api_server_logger,
     console_logger,
     is_port_available,
@@ -60,6 +63,13 @@ parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the htt
parser.add_argument("--workers", default=1, type=int, help="number of workers") parser.add_argument("--workers", default=1, type=int, help="number of workers")
parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server") parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server")
parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server") parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
parser.add_argument(
"--max-waiting-time",
default=-1,
type=int,
help="max waiting time for connection, if set value -1 means no waiting time limit",
)
parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
args.model = retrive_model_from_server(args.model, args.revision) args.model = retrive_model_from_server(args.model, args.revision)
@@ -115,10 +125,11 @@ async def lifespan(app: FastAPI):
         args.reasoning_parser,
         args.data_parallel_size,
         args.enable_logprob,
+        args.workers,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.ips)
-    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
+    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips, args.max_waiting_time)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
     app.state.engine_client = engine_client
@@ -140,6 +151,41 @@ app = FastAPI(lifespan=lifespan)
 instrument(app)


+MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers
+connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)
+
+
+@asynccontextmanager
+async def connection_manager():
+    """
+    Async context manager that admits a request only if the connection semaphore has a free slot.
+    """
+    try:
+        await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
+        yield
+    except asyncio.TimeoutError:
+        api_server_logger.info(f"Reached max request limit: {connection_semaphore.status()}")
+        if connection_semaphore.locked():
+            connection_semaphore.release()
+        raise HTTPException(status_code=429, detail="Too many requests")
+
+
+def wrap_streaming_generator(original_generator: AsyncGenerator):
+    """
+    Wrap an async generator so the connection semaphore is released when the stream finishes.
+    """
+
+    async def wrapped_generator():
+        try:
+            async for chunk in original_generator:
+                yield chunk
+        finally:
+            api_server_logger.debug(f"release: {connection_semaphore.status()}")
+            connection_semaphore.release()
+
+    return wrapped_generator


 # TODO: pass real engine values; get the status via pid
 @app.get("/health")
 def health(request: Request) -> Response:
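
Design note: the 0.001 s timeout on connection_semaphore.acquire() makes admission effectively non-blocking; a request either gets a slot immediately or fails fast with HTTP 429, so excess connections are turned away instead of accumulating open sockets. A minimal standalone sketch of the same fail-fast pattern (names here are illustrative, not part of the codebase):

import asyncio

async def try_acquire_fast(sem: asyncio.Semaphore, timeout: float = 0.001) -> bool:
    """Try to take a slot; give up almost immediately instead of queueing."""
    try:
        await asyncio.wait_for(sem.acquire(), timeout=timeout)
        return True
    except asyncio.TimeoutError:
        return False

async def main():
    sem = asyncio.Semaphore(1)
    print(await try_acquire_fast(sem))  # True: a slot was free
    print(await try_acquire_fast(sem))  # False: capacity exhausted, caller should return 429

asyncio.run(main())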
@@ -202,16 +248,23 @@ async def create_chat_completion(request: ChatCompletionRequest):
         status, msg = app.state.engine_client.is_workers_alive()
         if not status:
             return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
-    inject_to_metadata(request)
-    generator = await app.state.chat_handler.create_chat_completion(request)
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
-    elif isinstance(generator, ChatCompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-    return StreamingResponse(content=generator, media_type="text/event-stream")
+    try:
+        async with connection_manager():
+            inject_to_metadata(request)
+            generator = await app.state.chat_handler.create_chat_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
+            elif isinstance(generator, ChatCompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
+    except HTTPException as e:
+        api_server_logger.error(f"Error in chat completion: {str(e)}")
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})


 @app.post("/v1/completions")
@@ -224,13 +277,20 @@ async def create_completion(request: CompletionRequest):
         if not status:
             return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
-    generator = await app.state.completion_handler.create_completion(request)
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
-    elif isinstance(generator, CompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-    return StreamingResponse(content=generator, media_type="text/event-stream")
+    try:
+        async with connection_manager():
+            generator = await app.state.completion_handler.create_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+            elif isinstance(generator, CompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
+    except HTTPException as e:
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})


 @app.get("/update_model_weight")

View File

@@ -49,10 +49,11 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """

-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, ips, max_waiting_time):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
+        self.max_waiting_time = max_waiting_time
         self.host_ip = get_host_ip()
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
@@ -94,6 +95,15 @@ class OpenAIServingChat:
         del current_req_dict

+        try:
+            api_server_logger.debug(f"{self.engine_client.semaphore.status()}")
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request queueing time exceeded {self.max_waiting_time}")
+
         if request.stream:
             return self.chat_completion_stream_generator(request, request_id, request.model, prompt_token_ids)
         else:
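
With the default --max-waiting-time -1 the handler waits on the engine semaphore indefinitely; with a positive value, asyncio.wait_for cancels the pending acquire on timeout, the bare except Exception catches the TimeoutError, and the client receives an ErrorResponse with code 408.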
@@ -310,6 +320,8 @@ class OpenAIServingChat:
yield f"data: {error_data}\n\n" yield f"data: {error_data}\n\n"
finally: finally:
dealer.close() dealer.close()
self.engine_client.semaphore.release()
api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
async def chat_completion_full_generator( async def chat_completion_full_generator(
@@ -384,6 +396,8 @@ class OpenAIServingChat:
                     break
         finally:
             dealer.close()
+            self.engine_client.semaphore.release()
+            api_server_logger.info(f"release {self.engine_client.semaphore.status()}")

         choices = []
         output = final_res["outputs"]

View File

@@ -40,11 +40,12 @@ from fastdeploy.worker.output import LogprobsLists
 class OpenAIServingCompletion:
-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, ips, max_waiting_time):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
         self.host_ip = get_host_ip()
+        self.max_waiting_time = max_waiting_time
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
                 self.master_ip = self.master_ip[0]
@@ -114,6 +115,14 @@ class OpenAIServingCompletion:
         del current_req_dict

+        try:
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request queueing time exceeded {self.max_waiting_time}")
+
         if request.stream:
             return self.completion_stream_generator(
                 request=request,
@@ -223,6 +232,7 @@ class OpenAIServingCompletion:
         finally:
             if dealer is not None:
                 dealer.close()
+            self.engine_client.semaphore.release()

     async def completion_stream_generator(
         self,
@@ -372,6 +382,7 @@ class OpenAIServingCompletion:
             del request
             if dealer is not None:
                 dealer.close()
+            self.engine_client.semaphore.release()
         yield "data: [DONE]\n\n"

     def request_output_to_completion_response(

View File

@@ -82,6 +82,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")), "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
# set trace attribute job_id. # set trace attribute job_id.
"FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"), "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
# support max connections
"FD_SUPPORT_MAX_CONNECTIONS": lambda: 768,
} }
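
Note that, unlike the neighboring entries, FD_SUPPORT_MAX_CONNECTIONS is a fixed constant: the lambda returns 768 without reading the environment, so in this commit the cap cannot be overridden by setting an environment variable.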

View File

@@ -15,6 +15,7 @@
""" """
import argparse import argparse
import asyncio
import codecs import codecs
import importlib import importlib
import logging import logging
@@ -291,6 +292,16 @@ def extract_tar(tar_path, output_dir):
raise RuntimeError(f"Extraction failed: {e!s}") raise RuntimeError(f"Extraction failed: {e!s}")
def get_limited_max_value(max_value):
def validator(value):
value = float(value)
if value > max_value:
raise argparse.ArgumentTypeError(f"The value cannot exceed {max_value}")
return value
return validator
def download_model(url, output_dir, temp_tar): def download_model(url, output_dir, temp_tar):
""" """
下载模型,并将其解压到指定目录。 下载模型,并将其解压到指定目录。
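
get_limited_max_value is a factory for argparse type callables. A usage sketch follows; the flag name and cap are illustrative, since the hunks shown here add the helper without wiring it to a specific flag:

import argparse

def get_limited_max_value(max_value):
    def validator(value):
        value = float(value)
        if value > max_value:
            raise argparse.ArgumentTypeError(f"The value cannot exceed {max_value}")
        return value
    return validator

parser = argparse.ArgumentParser()
# Values above 3600 fail at parse time with "The value cannot exceed 3600".
parser.add_argument("--timeout", type=get_limited_max_value(3600))
print(parser.parse_args(["--timeout", "30"]).timeout)  # 30.0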
@@ -596,6 +607,61 @@ def version():
     return content


+class StatefulSemaphore:
+    """
+    Wraps an asyncio.Semaphore and exposes additional state (acquired count, capacity, uptime).
+    """
+
+    __slots__ = ("_semaphore", "_max_value", "_acquired_count", "_last_reset")
+
+    def __init__(self, value: int):
+        """
+        StatefulSemaphore constructor.
+        """
+        if value < 0:
+            raise ValueError("Value must be non-negative.")
+        self._semaphore = asyncio.Semaphore(value)
+        self._max_value = value
+        self._acquired_count = 0
+        self._last_reset = time.monotonic()
+
+    async def acquire(self):
+        await self._semaphore.acquire()
+        self._acquired_count += 1
+
+    def release(self):
+        self._semaphore.release()
+        self._acquired_count = max(0, self._acquired_count - 1)
+
+    def locked(self) -> bool:
+        return self._semaphore.locked()
+
+    @property
+    def available(self) -> int:
+        return self._max_value - self._acquired_count
+
+    @property
+    def acquired(self) -> int:
+        return self._acquired_count
+
+    @property
+    def max_value(self) -> int:
+        return self._max_value
+
+    @property
+    def uptime(self) -> float:
+        return time.monotonic() - self._last_reset
+
+    def status(self) -> dict:
+        return {
+            "available": self.available,
+            "acquired": self.acquired,
+            "max_value": self.max_value,
+            "uptime": round(self.uptime, 2),
+        }
+
+
 llm_logger = get_logger("fastdeploy", "fastdeploy.log")
 data_processor_logger = get_logger("data_processor", "data_processor.log")
 scheduler_logger = get_logger("scheduler", "scheduler.log")
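
A minimal usage sketch of StatefulSemaphore, assuming the class above is importable from fastdeploy.utils:

import asyncio

from fastdeploy.utils import StatefulSemaphore  # assumes fastdeploy is installed

async def main():
    sem = StatefulSemaphore(2)
    await sem.acquire()
    await sem.acquire()
    print(sem.status())   # {'available': 0, 'acquired': 2, 'max_value': 2, 'uptime': ...}
    print(sem.locked())   # True: no free slots remain
    sem.release()
    print(sem.available)  # 1

asyncio.run(main())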