mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00

[BugFix] fix too many open files problem (#3275)
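In short: every in-flight request holds an open ZMQ dealer socket, so unbounded request concurrency could exhaust the process file-descriptor limit. This change bounds concurrency with two semaphores: a per-process connection_semaphore in the API server that rejects excess requests with HTTP 429, and a per-engine EngineClient.semaphore that requests must acquire before reaching the engine, returning 408 once --max-waiting-time elapses.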
@@ -24,7 +24,7 @@ from fastdeploy.input.preprocess import InputPreprocessor
 from fastdeploy.inter_communicator import IPCSignal, ZmqClient
 from fastdeploy.metrics.work_metrics import work_process_metrics
 from fastdeploy.platforms import current_platform
-from fastdeploy.utils import EngineError, api_server_logger
+from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger


 class EngineClient:
@@ -44,6 +44,7 @@ class EngineClient:
         reasoning_parser=None,
         data_parallel_size=1,
         enable_logprob=False,
+        workers=1,
     ):
         input_processor = InputPreprocessor(
             tokenizer,
@@ -76,6 +77,7 @@ class EngineClient:
             suffix=pid,
             create=False,
         )
+        self.semaphore = StatefulSemaphore((envs.FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers)

     def create_zmq_client(self, model, mode):
         """
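A note on the sizing expression: (envs.FD_SUPPORT_MAX_CONNECTIONS + workers - 1) // workers is ceiling division, so each worker process gets an equal share of the global connection budget, rounded up. A minimal sketch (per_worker_limit is a hypothetical helper for illustration, not FastDeploy API):

    # Ceiling division, as used for the per-worker semaphore size above.
    def per_worker_limit(max_connections: int, workers: int) -> int:
        return (max_connections + workers - 1) // workers

    assert per_worker_limit(768, 8) == 96   # divides evenly
    assert per_worker_limit(768, 5) == 154  # rounded up: 5 * 154 = 770 >= 768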
@@ -14,15 +14,17 @@
 # limitations under the License.
 """

+import asyncio
 import os
 import threading
 import time
+from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
 from multiprocessing import current_process

 import uvicorn
 import zmq
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import CONTENT_TYPE_LATEST
@@ -48,6 +50,7 @@ from fastdeploy.metrics.metrics import (
 from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
 from fastdeploy.utils import (
     FlexibleArgumentParser,
+    StatefulSemaphore,
     api_server_logger,
     console_logger,
     is_port_available,
@@ -60,6 +63,13 @@ parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the htt
 parser.add_argument("--workers", default=1, type=int, help="number of workers")
 parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server")
 parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
+parser.add_argument(
+    "--max-waiting-time",
+    default=-1,
+    type=int,
+    help="max waiting time for connection; -1 means no waiting time limit",
+)
+parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
@@ -115,10 +125,11 @@ async def lifespan(app: FastAPI):
         args.reasoning_parser,
         args.data_parallel_size,
         args.enable_logprob,
+        args.workers,
     )
     app.state.dynamic_load_weight = args.dynamic_load_weight
-    chat_handler = OpenAIServingChat(engine_client, pid, args.ips)
-    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips)
+    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
+    completion_handler = OpenAIServingCompletion(engine_client, pid, args.ips, args.max_waiting_time)
     engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
     engine_client.pid = pid
     app.state.engine_client = engine_client
@@ -140,6 +151,41 @@ app = FastAPI(lifespan=lifespan)
 instrument(app)


+MAX_CONCURRENT_CONNECTIONS = (args.max_concurrency + args.workers - 1) // args.workers
+connection_semaphore = StatefulSemaphore(MAX_CONCURRENT_CONNECTIONS)
+
+
+@asynccontextmanager
+async def connection_manager():
+    """
+    Async context manager that admits a request only if the connection
+    semaphore can be acquired immediately.
+    """
+    try:
+        await asyncio.wait_for(connection_semaphore.acquire(), timeout=0.001)
+        yield
+    except asyncio.TimeoutError:
+        api_server_logger.info(f"Reach max request release: {connection_semaphore.status()}")
+        if connection_semaphore.locked():
+            connection_semaphore.release()
+        raise HTTPException(status_code=429, detail="Too many requests")
+
+
+def wrap_streaming_generator(original_generator: AsyncGenerator):
+    """
+    Wrap an async generator to release the connection semaphore when the generator is finished.
+    """
+
+    async def wrapped_generator():
+        try:
+            async for chunk in original_generator:
+                yield chunk
+        finally:
+            api_server_logger.debug(f"release: {connection_semaphore.status()}")
+            connection_semaphore.release()
+
+    return wrapped_generator
+
+
 # TODO: pass the real engine value; get its status via pid
 @app.get("/health")
 def health(request: Request) -> Response:
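The admission pattern above is fail-fast: acquire with a near-zero timeout (0.001 s) so that a saturated server answers 429 immediately instead of queueing connections. Below is a self-contained sketch of the same pattern using plain asyncio primitives; names and capacity are illustrative, and unlike the real code it releases unconditionally rather than deferring release to a streaming wrapper:

    import asyncio
    from contextlib import asynccontextmanager

    semaphore = asyncio.Semaphore(2)  # illustrative capacity

    @asynccontextmanager
    async def admit():
        try:
            # Near-zero timeout: if no slot is free right now, give up.
            await asyncio.wait_for(semaphore.acquire(), timeout=0.001)
        except asyncio.TimeoutError:
            raise RuntimeError("429: too many requests")  # stand-in for HTTPException
        try:
            yield
        finally:
            semaphore.release()

    async def handler(i: int) -> None:
        try:
            async with admit():
                await asyncio.sleep(0.05)  # simulated request work
                print(f"request {i}: served")
        except RuntimeError as exc:
            print(f"request {i}: {exc}")

    async def main() -> None:
        await asyncio.gather(*(handler(i) for i in range(4)))

    asyncio.run(main())  # two requests are served, two are rejected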
@@ -202,16 +248,23 @@ async def create_chat_completion(request: ChatCompletionRequest):
     status, msg = app.state.engine_client.is_workers_alive()
     if not status:
         return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)
-    inject_to_metadata(request)
-    generator = await app.state.chat_handler.create_chat_completion(request)
-
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
-
-    elif isinstance(generator, ChatCompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
+    try:
+        async with connection_manager():
+            inject_to_metadata(request)
+            generator = await app.state.chat_handler.create_chat_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
+            elif isinstance(generator, ChatCompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
+
+    except HTTPException as e:
+        api_server_logger.error(f"Error in chat completion: {str(e)}")
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})


 @app.post("/v1/completions")
@@ -224,13 +277,20 @@ async def create_completion(request: CompletionRequest):
     if not status:
         return JSONResponse(content={"error": "Worker Service Not Healthy"}, status_code=304)

-    generator = await app.state.completion_handler.create_completion(request)
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
-    elif isinstance(generator, CompletionResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
+    try:
+        async with connection_manager():
+            generator = await app.state.completion_handler.create_completion(request)
+            if isinstance(generator, ErrorResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+            elif isinstance(generator, CompletionResponse):
+                connection_semaphore.release()
+                return JSONResponse(content=generator.model_dump())
+            else:
+                wrapped_generator = wrap_streaming_generator(generator)
+                return StreamingResponse(content=wrapped_generator(), media_type="text/event-stream")
+    except HTTPException as e:
+        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})


 @app.get("/update_model_weight")
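Note the release discipline in both endpoints: for non-streaming results the handler releases connection_semaphore explicitly before returning a JSONResponse, while for streaming results ownership of the slot transfers to wrap_streaming_generator, whose finally block releases it when the stream ends or the client disconnects. On admission timeout, connection_manager releases one slot if the semaphore is fully locked, guarding against a slot lost to the cancelled acquire, and then raises 429.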
@@ -49,10 +49,11 @@ class OpenAIServingChat:
     OpenAI-style chat completions serving
     """

-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, ips, max_waiting_time):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
+        self.max_waiting_time = max_waiting_time
         self.host_ip = get_host_ip()
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
@@ -94,6 +95,15 @@ class OpenAIServingChat:

         del current_req_dict

+        try:
+            api_server_logger.debug(f"{self.engine_client.semaphore.status()}")
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request queued time exceeded {self.max_waiting_time}")
+
         if request.stream:
             return self.chat_completion_stream_generator(request, request_id, request.model, prompt_token_ids)
         else:
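The acquire above waits up to max_waiting_time seconds for an engine slot and maps a timeout to a 408 ErrorResponse; a negative value means wait indefinitely. A standalone sketch of that logic (acquire_slot is a hypothetical helper, not FastDeploy API):

    import asyncio

    async def acquire_slot(semaphore: asyncio.Semaphore, max_waiting_time: float) -> bool:
        """Return True once a slot is acquired, False if the queue wait times out."""
        try:
            if max_waiting_time < 0:
                await semaphore.acquire()  # no limit: wait as long as it takes
            else:
                await asyncio.wait_for(semaphore.acquire(), timeout=max_waiting_time)
        except asyncio.TimeoutError:
            return False  # the caller maps this to an HTTP 408 ErrorResponse
        return True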
@@ -310,6 +320,8 @@ class OpenAIServingChat:
                 yield f"data: {error_data}\n\n"
         finally:
             dealer.close()
+            self.engine_client.semaphore.release()
+            api_server_logger.info(f"release {self.engine_client.semaphore.status()}")
         yield "data: [DONE]\n\n"

     async def chat_completion_full_generator(
@@ -384,6 +396,8 @@ class OpenAIServingChat:
                     break
         finally:
             dealer.close()
+            self.engine_client.semaphore.release()
+            api_server_logger.info(f"release {self.engine_client.semaphore.status()}")

         choices = []
         output = final_res["outputs"]
@@ -40,11 +40,12 @@ from fastdeploy.worker.output import LogprobsLists


 class OpenAIServingCompletion:
-    def __init__(self, engine_client, pid, ips):
+    def __init__(self, engine_client, pid, ips, max_waiting_time):
         self.engine_client = engine_client
         self.pid = pid
         self.master_ip = ips
         self.host_ip = get_host_ip()
+        self.max_waiting_time = max_waiting_time
         if self.master_ip is not None:
             if isinstance(self.master_ip, list):
                 self.master_ip = self.master_ip[0]
@@ -114,6 +115,14 @@ class OpenAIServingCompletion:

         del current_req_dict

+        try:
+            if self.max_waiting_time < 0:
+                await self.engine_client.semaphore.acquire()
+            else:
+                await asyncio.wait_for(self.engine_client.semaphore.acquire(), timeout=self.max_waiting_time)
+        except Exception:
+            return ErrorResponse(code=408, message=f"Request queued time exceeded {self.max_waiting_time}")
+
         if request.stream:
             return self.completion_stream_generator(
                 request=request,
@@ -223,6 +232,7 @@ class OpenAIServingCompletion:
         finally:
             if dealer is not None:
                 dealer.close()
+                self.engine_client.semaphore.release()

     async def completion_stream_generator(
         self,
@@ -372,6 +382,7 @@ class OpenAIServingCompletion:
             del request
             if dealer is not None:
                 dealer.close()
+            self.engine_client.semaphore.release()
         yield "data: [DONE]\n\n"

     def request_output_to_completion_response(
@@ -82,6 +82,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "ENABLE_V1_KVCACHE_SCHEDULER": lambda: int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")),
     # set trace attribute job_id.
     "FD_JOB_ID": lambda: os.getenv("FD_JOB_ID"),
+    # support max connections
+    "FD_SUPPORT_MAX_CONNECTIONS": lambda: 768,
 }

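Worth noting: unlike the neighboring entries, the FD_SUPPORT_MAX_CONNECTIONS lambda returns the constant 768 and never consults os.getenv, so as written the global connection budget cannot be overridden through the environment.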
@@ -15,6 +15,7 @@
 """

 import argparse
+import asyncio
 import codecs
 import importlib
 import logging
@@ -291,6 +292,16 @@ def extract_tar(tar_path, output_dir):
         raise RuntimeError(f"Extraction failed: {e!s}")


+def get_limited_max_value(max_value):
+    def validator(value):
+        value = float(value)
+        if value > max_value:
+            raise argparse.ArgumentTypeError(f"The value cannot exceed {max_value}")
+        return value
+
+    return validator
+
+
 def download_model(url, output_dir, temp_tar):
     """
     Download the model and extract it to the specified directory.
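get_limited_max_value returns an argparse type validator that rejects values above a cap. A hypothetical usage (the parser and the --temperature flag below are illustrative, not part of FastDeploy):

    import argparse

    def get_limited_max_value(max_value):
        def validator(value):
            value = float(value)
            if value > max_value:
                raise argparse.ArgumentTypeError(f"The value cannot exceed {max_value}")
            return value

        return validator

    parser = argparse.ArgumentParser()
    parser.add_argument("--temperature", type=get_limited_max_value(2.0))
    print(parser.parse_args(["--temperature", "1.5"]).temperature)  # 1.5
    # "--temperature 3.0" would exit with: The value cannot exceed 2.0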
@@ -596,6 +607,61 @@ def version():
     return content


+class StatefulSemaphore:
+    """
+    StatefulSemaphore wraps an asyncio.Semaphore and exposes additional state:
+    the acquired count, the capacity, and the time since creation.
+    """
+
+    __slots__ = ("_semaphore", "_max_value", "_acquired_count", "_last_reset")
+
+    def __init__(self, value: int):
+        """
+        StatefulSemaphore constructor
+        """
+        if value < 0:
+            raise ValueError("Value must be non-negative.")
+        self._semaphore = asyncio.Semaphore(value)
+        self._max_value = value
+        self._acquired_count = 0
+        self._last_reset = time.monotonic()
+
+    async def acquire(self):
+        await self._semaphore.acquire()
+        self._acquired_count += 1
+
+    def release(self):
+        self._semaphore.release()
+        self._acquired_count = max(0, self._acquired_count - 1)
+
+    def locked(self) -> bool:
+        return self._semaphore.locked()
+
+    @property
+    def available(self) -> int:
+        return self._max_value - self._acquired_count
+
+    @property
+    def acquired(self) -> int:
+        return self._acquired_count
+
+    @property
+    def max_value(self) -> int:
+        return self._max_value
+
+    @property
+    def uptime(self) -> float:
+        return time.monotonic() - self._last_reset
+
+    def status(self) -> dict:
+        return {
+            "available": self.available,
+            "acquired": self.acquired,
+            "max_value": self.max_value,
+            "uptime": round(self.uptime, 2),
+        }
+
+
 llm_logger = get_logger("fastdeploy", "fastdeploy.log")
 data_processor_logger = get_logger("data_processor", "data_processor.log")
 scheduler_logger = get_logger("scheduler", "scheduler.log")
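A minimal usage sketch of StatefulSemaphore, assuming fastdeploy is installed so the class can be imported from fastdeploy.utils (as the api_server imports above do):

    import asyncio

    from fastdeploy.utils import StatefulSemaphore

    async def main() -> None:
        sem = StatefulSemaphore(2)
        await sem.acquire()
        # e.g. {'available': 1, 'acquired': 1, 'max_value': 2, 'uptime': 0.0}
        print(sem.status())
        sem.release()
        print(sem.available)  # 2

    asyncio.run(main())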