FastDeploy/fastdeploy/entrypoints/openai/run_batch.py
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/run_batch.py
import argparse
import asyncio
import os
import tempfile
import uuid
from argparse import Namespace
from collections.abc import Awaitable
from http import HTTPStatus
from io import StringIO
from multiprocessing import current_process
from typing import Callable, List, Optional, Tuple
import aiohttp
import zmq
from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.protocol import (
BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
ErrorResponse,
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.utils import (
FlexibleArgumentParser,
api_server_logger,
console_logger,
retrive_model_from_server,
)
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"
llm_engine = None
engine_client = None
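# Module-level handles so cleanup_resources() can shut down the engine and close
# the client connections on exit.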
def make_arg_parser(parser: FlexibleArgumentParser):
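    """Register the batch-runner CLI options and the engine options from EngineArgs."""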
parser.add_argument(
"-i",
"--input-file",
required=True,
type=str,
help="The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET.",
)
parser.add_argument(
"-o",
"--output-file",
required=True,
type=str,
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.",
)
parser.add_argument(
"--output-tmp-dir",
type=str,
default=None,
help="The directory to store the output file before uploading it " "to the output URL.",
)
parser.add_argument(
"--max-waiting-time",
default=-1,
type=int,
help="max waiting time for connection, if set value -1 means no waiting time limit",
)
parser.add_argument("--port", default=8000, type=int, help="port to the http server")
# parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server")
parser.add_argument("--workers", default=None, type=int, help="number of workers")
parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
    parser.add_argument(
        "--enable-mm-output", action="store_true", help="Enable the 'multimodal_content' field in the response output."
    )
parser = EngineArgs.add_cli_args(parser)
return parser
def parse_args():
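    """Parse the command-line arguments for the batch runner."""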
parser = FlexibleArgumentParser(description="FastDeploy OpenAI-Compatible batch runner.")
args = make_arg_parser(parser).parse_args()
return args
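# Illustrative invocation (a sketch; flags such as --model are assumed to be
# registered by EngineArgs.add_cli_args):
#   python -m fastdeploy.entrypoints.openai.run_batch \
#       -i batch_input.jsonl -o batch_results.jsonl --model <model-name-or-path>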
def init_engine(args: argparse.Namespace):
"""
load engine
"""
global llm_engine
if llm_engine is not None:
return llm_engine
api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}")
engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args)
if not engine.start(api_server_pid=os.getpid()):
api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!")
return None
llm_engine = engine
return engine
class BatchProgressTracker:
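    """Track submitted/completed requests and report progress via tqdm and periodic log lines."""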
def __init__(self):
self._total = 0
self._completed = 0
self._pbar: Optional[tqdm] = None
self._last_log_count = 0
def submitted(self):
self._total += 1
def completed(self):
self._completed += 1
if self._pbar:
self._pbar.update()
if self._total > 0:
log_interval = min(100, max(self._total // 10, 1))
if self._completed - self._last_log_count >= log_interval:
console_logger.info(f"Progress: {self._completed}/{self._total} requests completed")
self._last_log_count = self._completed
def pbar(self) -> tqdm:
self._pbar = tqdm(
total=self._total,
unit="req",
desc="Running batch",
mininterval=10,
bar_format=_BAR_FORMAT,
)
return self._pbar
async def read_file(path_or_url: str) -> str:
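    """Return the batch input, fetched over HTTP(S) or read from a local file."""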
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, encoding="utf-8") as f:
return f.read()
async def write_local_file(output_path: str, batch_outputs: list[BatchRequestOutput]) -> None:
"""
Write the responses to a local file.
output_path: The path to write the responses to.
batch_outputs: The list of batch outputs to write.
"""
    # We should make this async, but as long as run_batch runs as a
    # standalone program, blocking the event loop won't affect performance.
with open(output_path, "w", encoding="utf-8") as f:
for o in batch_outputs:
print(o.model_dump_json(), file=f)
async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None:
"""
Upload a local file to a URL.
output_url: The URL to upload the file to.
data_or_file: Either the data to upload or the path to the file to upload.
from_file: If True, data_or_file is the path to the file to upload.
"""
# Timeout is a common issue when uploading large files.
# We retry max_retries times before giving up.
max_retries = 5
# Number of seconds to wait before retrying.
delay = 5
for attempt in range(1, max_retries + 1):
try:
# We increase the timeout to 1000 seconds to allow
# for large files (default is 300).
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1000)) as session:
                if from_file:
                    with open(data_or_file, "rb") as file:
                        async with session.put(output_url, data=file) as response:
                            if response.status != 200:
                                body = await response.text()
                                raise Exception(
                                    f"Failed to upload file.\n"
                                    f"Status: {response.status}\n"
                                    f"Response: {body}"
                                )
                else:
                    async with session.put(output_url, data=data_or_file) as response:
                        if response.status != 200:
                            body = await response.text()
                            raise Exception(
                                f"Failed to upload data.\n"
                                f"Status: {response.status}\n"
                                f"Response: {body}"
                            )
                # Stop retrying once the upload has succeeded.
                return
except Exception as e:
if attempt < max_retries:
console_logger.error(
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
attempt,
e,
delay,
)
await asyncio.sleep(delay)
else:
raise Exception(
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
) from e
async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str) -> None:
"""
Write batch_outputs to a file or upload to a URL.
path_or_url: The path or URL to write batch_outputs to.
batch_outputs: The list of batch outputs to write.
output_tmp_dir: The directory to store the output file before uploading it
to the output URL.
"""
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
if output_tmp_dir is None:
console_logger.info("Writing outputs to memory buffer")
output_buffer = StringIO()
for o in batch_outputs:
print(o.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
console_logger.info("Uploading outputs to %s", path_or_url)
await upload_data(
path_or_url,
output_buffer.read().strip().encode("utf-8"),
from_file=False,
)
else:
# Write responses to a temporary file and then upload it to the URL.
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=output_tmp_dir,
prefix="tmp_batch_output_",
suffix=".jsonl",
) as f:
console_logger.info("Writing outputs to temporary local file %s", f.name)
await write_local_file(f.name, batch_outputs)
console_logger.info("Uploading outputs to %s", path_or_url)
await upload_data(path_or_url, f.name, from_file=True)
else:
console_logger.info("Writing outputs to local file %s", path_or_url)
await write_local_file(path_or_url, batch_outputs)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
def make_error_request_output(request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
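    """Build a BatchRequestOutput that records an error (HTTP 400) for the given request."""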
batch_output = BatchRequestOutput(
id=f"fastdeploy-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=HTTPStatus.BAD_REQUEST,
request_id=f"fastdeploy-batch-{random_uuid()}",
),
error=error_msg,
)
return batch_output
async def make_async_error_request_output(request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
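    """Async wrapper around make_error_request_output so error outputs can be awaited like real responses."""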
return make_error_request_output(request, error_msg)
async def run_request(
serving_engine_func: Callable,
request: BatchRequestInput,
tracker: BatchProgressTracker,
semaphore: asyncio.Semaphore,
) -> BatchRequestOutput:
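    """Run a single request through the serving handler and wrap the result in a BatchRequestOutput."""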
async with semaphore:
try:
response = await serving_engine_func(request.body)
if isinstance(response, ChatCompletionResponse):
batch_output = BatchRequestOutput(
id=f"fastdeploy-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=200, body=response, request_id=f"fastdeploy-batch-{random_uuid()}"
),
error=None,
)
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"fastdeploy-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(status_code=400, request_id=f"fastdeploy-batch-{random_uuid()}"),
error=response,
)
else:
batch_output = make_error_request_output(request, error_msg="Request must not be sent in stream mode")
tracker.completed()
return batch_output
except Exception as e:
console_logger.error(f"Request {request.custom_id} processing failed: {str(e)}")
tracker.completed()
return make_error_request_output(request, error_msg=f"Request processing failed: {str(e)}")
def determine_process_id() -> int:
"""Determine the appropriate process ID."""
if current_process().name != "MainProcess":
return os.getppid()
else:
return os.getpid()
def create_model_paths(args: Namespace) -> List[ModelPath]:
"""Create model paths configuration."""
if args.served_model_name is not None:
served_model_names = args.served_model_name
verification = True
else:
served_model_names = args.model
verification = False
return [ModelPath(name=served_model_names, model_path=args.model, verification=verification)]
async def initialize_engine_client(args: Namespace, pid: int) -> EngineClient:
"""Initialize and configure the engine client."""
engine_client = EngineClient(
model_name_or_path=args.model,
tokenizer=args.tokenizer,
max_model_len=args.max_model_len,
tensor_parallel_size=args.tensor_parallel_size,
pid=pid,
port=int(args.engine_worker_queue_port[args.local_data_parallel_id]),
limit_mm_per_prompt=args.limit_mm_per_prompt,
mm_processor_kwargs=args.mm_processor_kwargs,
reasoning_parser=args.reasoning_parser,
data_parallel_size=args.data_parallel_size,
enable_logprob=args.enable_logprob,
workers=args.workers,
tool_parser=args.tool_call_parser,
)
await engine_client.connection_manager.initialize()
engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
engine_client.pid = pid
return engine_client
def create_serving_handlers(
args: Namespace, engine_client: EngineClient, model_paths: List[ModelPath], chat_template: str, pid: int
) -> OpenAIServingChat:
"""Create model and chat serving handlers."""
model_handler = OpenAIServingModels(
model_paths,
args.max_model_len,
args.ips,
)
chat_handler = OpenAIServingChat(
engine_client,
model_handler,
pid,
args.ips,
args.max_waiting_time,
chat_template,
args.enable_mm_output,
args.tokenizer_base_url,
)
return chat_handler
async def setup_engine_and_handlers(args: Namespace) -> Tuple[EngineClient, OpenAIServingChat]:
"""Setup engine client and all necessary handlers."""
if args.tokenizer is None:
args.tokenizer = args.model
pid = determine_process_id()
console_logger.info(f"Process ID: {pid}")
model_paths = create_model_paths(args)
chat_template = load_chat_template(args.chat_template, args.model)
    # Initialize the engine client and expose it via the module-level global so that
    # cleanup_resources() can close its connections on shutdown.
    global engine_client
    engine_client = await initialize_engine_client(args, pid)
# Create handlers
chat_handler = create_serving_handlers(args, engine_client, model_paths, chat_template, pid)
# Update data processor if engine exists
if llm_engine is not None:
llm_engine.engine.data_processor = engine_client.data_processor
return engine_client, chat_handler
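# Each non-empty line of the input file is one JSON request in the OpenAI batch format,
# e.g. (illustrative; see BatchRequestInput in protocol.py for the full schema):
#   {"custom_id": "request-1", "url": "/v1/chat/completions",
#    "body": {"messages": [{"role": "user", "content": "Hello"}]}}
# Only /v1/chat/completions is handled; any other URL yields an error entry in the output.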
async def run_batch(
args: argparse.Namespace,
) -> None:
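    """Read the input file, dispatch all requests concurrently, and write the collected results."""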
# Setup engine and handlers
engine_client, chat_handler = await setup_engine_and_handlers(args)
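    # Split the requested total concurrency evenly across workers (ceiling division).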
concurrency = getattr(args, "max_concurrency", 512)
workers = getattr(args, "workers", 1)
max_concurrency = (concurrency + workers - 1) // workers
semaphore = asyncio.Semaphore(max_concurrency)
console_logger.info(f"concurrency: {concurrency}, workers: {workers}, max_concurrency: {max_concurrency}")
tracker = BatchProgressTracker()
console_logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently".
response_futures: list[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
chat_handler_fn = chat_handler.create_chat_completion if chat_handler is not None else None
if chat_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Chat Completions API",
)
)
continue
response_futures.append(run_request(chat_handler_fn, request, tracker, semaphore))
tracker.submitted()
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions"
"See fastdeploy/entrypoints/openai/api_server.py for supported "
"/v1/chat/completions versions.",
)
)
with tracker.pbar():
responses = await asyncio.gather(*response_futures)
success_count = sum(1 for r in responses if r.error is None)
error_count = len(responses) - success_count
console_logger.info(f"Batch processing completed: {success_count} success, {error_count} errors")
await write_file(args.output_file, responses, args.output_tmp_dir)
console_logger.info("Results written to output file")
async def main(args: argparse.Namespace):
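    """Entry point: start the engine, run the batch, and always clean up resources afterwards."""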
console_logger.info("Starting batch runner with args: %s", args)
try:
if args.workers is None:
args.workers = max(min(int(args.max_num_seqs // 32), 8), 1)
args.model = retrive_model_from_server(args.model, args.revision)
if args.tool_parser_plugin:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
if not init_engine(args):
return
await run_batch(args)
except Exception as e:
print("Fatal error in main:")
print(e)
console_logger.error(f"Fatal error in main: {e}", exc_info=True)
raise
finally:
await cleanup_resources()
async def cleanup_resources() -> None:
"""Clean up all resources during shutdown."""
try:
# stop engine
if llm_engine is not None:
try:
llm_engine._exit_sub_services()
except Exception as e:
console_logger.error(f"Error stopping engine: {e}")
# close client connections
if engine_client is not None:
try:
if hasattr(engine_client, "zmq_client"):
engine_client.zmq_client.close()
if hasattr(engine_client, "connection_manager"):
await engine_client.connection_manager.close()
except Exception as e:
console_logger.error(f"Error closing client connections: {e}")
# garbage collect
import gc
gc.collect()
print("run batch done")
except Exception as e:
console_logger.error(f"Error during cleanup: {e}")
if __name__ == "__main__":
args = parse_args()
asyncio.run(main(args=args))