[Feature] Setting number of apiserver workers automatically (#3790)

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
This commit is contained in:
Jiang-Jia-Jun
2025-09-02 14:17:48 +08:00
committed by GitHub
parent bf0cf5167a
commit 0e4df5a6f4
3 changed files with 21 additions and 2 deletions

View File

@@ -589,7 +589,7 @@ class EngineSevice:
else:
err, data = self.zmq_server.receive_pyobj_once(block)
if err is not None:
llm_logger.error("Engine stops inserting zmq task into scheduler, err:{err}")
llm_logger.error(f"Engine stops inserting zmq task into scheduler, err:{err}")
break
request, insert_task = None, []

View File

@@ -60,6 +60,7 @@ from fastdeploy.utils import (
StatefulSemaphore,
api_server_logger,
console_logger,
is_package_installed,
is_port_available,
retrive_model_from_server,
)
@@ -67,7 +68,7 @@ from fastdeploy.utils import (
parser = FlexibleArgumentParser()
parser.add_argument("--port", default=8000, type=int, help="port to the http server")
parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server")
parser.add_argument("--workers", default=1, type=int, help="number of workers")
parser.add_argument("--workers", default=None, type=int, help="number of workers")
parser.add_argument("--metrics-port", default=8001, type=int, help="port for metrics server")
parser.add_argument("--controller-port", default=-1, type=int, help="port for controller server")
parser.add_argument(
@@ -82,6 +83,15 @@ parser.add_argument(
)
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.workers is None:
# In GPU, the workers of uvicorn will be set according to the parameter `max-num-seqs`
if is_package_installed("paddlepaddle-gpu"):
args.workers = max(min(int(args.max_num_seqs // 32), 8), 1)
else:
args.workers = 1
console_logger.info(f"Number of api-server workers: {args.workers}.")
args.model = retrive_model_from_server(args.model, args.revision)
chat_template = load_chat_template(args.chat_template, args.model)
if args.tool_parser_plugin:

View File

@@ -27,6 +27,7 @@ import sys
import tarfile
import time
from datetime import datetime
from importlib.metadata import PackageNotFoundError, distribution
from logging.handlers import BaseRotatingHandler
from pathlib import Path
from typing import Literal, TypeVar, Union
@@ -668,6 +669,14 @@ def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
return module
def is_package_installed(package_name):
try:
distribution(package_name)
return True
except PackageNotFoundError:
return False
def version():
"""
Prints the contents of the version.txt file located in the parent directory of this script.