Add cli run batch (#4237)

* feat(log):add_request_and_response_log

* [cli] add run batch cli

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
xiaolei373
2025-09-26 14:27:25 +08:00
committed by GitHub
parent 8a964329f4
commit 55124f8491
9 changed files with 2446 additions and 0 deletions

View File

@@ -23,10 +23,12 @@ from fastdeploy import __version__
def main():
import fastdeploy.entrypoints.cli.benchmark.main
import fastdeploy.entrypoints.cli.openai
import fastdeploy.entrypoints.cli.run_batch
import fastdeploy.entrypoints.cli.serve
from fastdeploy.utils import FlexibleArgumentParser
CMD_MODULES = [
fastdeploy.entrypoints.cli.run_batch,
fastdeploy.entrypoints.cli.openai,
fastdeploy.entrypoints.cli.benchmark.main,
fastdeploy.entrypoints.cli.serve,

View File

@@ -0,0 +1,65 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/run_batch.py
from __future__ import annotations
import argparse
import asyncio
import importlib.metadata
from fastdeploy.entrypoints.cli.types import CLISubcommand
from fastdeploy.utils import (
FASTDEPLOY_SUBCMD_PARSER_EPILOG,
FlexibleArgumentParser,
show_filtered_argument_or_group_from_help,
)
class RunBatchSubcommand(CLISubcommand):
"""The `run-batch` subcommand for FastDeploy CLI."""
name = "run-batch"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
from fastdeploy.entrypoints.openai.run_batch import main as run_batch_main
print("FastDeploy batch processing API version", importlib.metadata.version("fastdeploy-gpu"))
print(args)
asyncio.run(run_batch_main(args))
def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
from fastdeploy.entrypoints.openai.run_batch import make_arg_parser
run_batch_parser = subparsers.add_parser(
"run-batch",
help="Run batch prompts and write results to file.",
description=(
"Run batch prompts using FastDeploy's OpenAI-compatible API.\n"
"Supports local or HTTP input/output files."
),
usage="FastDeploy run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>",
)
run_batch_parser = make_arg_parser(run_batch_parser)
show_filtered_argument_or_group_from_help(run_batch_parser, ["run-batch"])
run_batch_parser.epilog = FASTDEPLOY_SUBCMD_PARSER_EPILOG
return run_batch_parser
def cmd_init() -> list[CLISubcommand]:
return [RunBatchSubcommand()]

View File

@@ -720,3 +720,71 @@ class ControlSchedulerRequest(BaseModel):
reset: Optional[bool] = False
load_shards_num: Optional[int] = None
reallocate_shard: Optional[bool] = False
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
BatchRequestInputBody = ChatCompletionRequest
class BatchRequestInput(BaseModel):
"""
The per-line object of the batch input file.
NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
"""
# A developer-provided per-request id that will be used to match outputs to
# inputs. Must be unique for each request in a batch.
custom_id: str
# The HTTP method to be used for the request. Currently only POST is
# supported.
method: str
# The OpenAI API relative URL to be used for the request. Currently
# /v1/chat/completions is supported.
url: str
# The parameters of the request.
body: BatchRequestInputBody
@field_validator("body", mode="before")
@classmethod
def check_type_for_url(cls, value: Any, info: ValidationInfo):
# Use url to disambiguate models
url: str = info.data["url"]
if url == "/v1/chat/completions":
if isinstance(value, dict):
return value
return ChatCompletionRequest.model_validate(value)
return value
class BatchResponseData(BaseModel):
# HTTP status code of the response.
status_code: int = 200
# An unique identifier for the API request.
request_id: str
# The body of the response.
body: Optional[ChatCompletionResponse] = None
class BatchRequestOutput(BaseModel):
"""
The per-line object of the batch output and error files
"""
id: str
# A developer-provided per-request id that will be used to match outputs to
# inputs.
custom_id: str
response: Optional[BatchResponseData]
# For requests that failed with a non-HTTP error, this will contain more
# information on the cause of the failure.
error: Optional[Any]

View File

@@ -0,0 +1,546 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/run_batch.py
import argparse
import asyncio
import os
import tempfile
import uuid
from argparse import Namespace
from collections.abc import Awaitable
from http import HTTPStatus
from io import StringIO
from multiprocessing import current_process
from typing import Callable, List, Optional, Tuple
import aiohttp
import zmq
from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.engine_client import EngineClient
from fastdeploy.entrypoints.openai.protocol import (
BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
ErrorResponse,
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.utils import (
FlexibleArgumentParser,
api_server_logger,
console_logger,
retrive_model_from_server,
)
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"
llm_engine = None
engine_client = None
def make_arg_parser(parser: FlexibleArgumentParser):
parser.add_argument(
"-i",
"--input-file",
required=True,
type=str,
help="The path or url to a single input file. Currently supports local file "
"paths, or the http protocol (http or https). If a URL is specified, "
"the file should be available via HTTP GET.",
)
parser.add_argument(
"-o",
"--output-file",
required=True,
type=str,
help="The path or url to a single output file. Currently supports "
"local file paths, or web (http or https) urls. If a URL is specified,"
" the file should be available via HTTP PUT.",
)
parser.add_argument(
"--output-tmp-dir",
type=str,
default=None,
help="The directory to store the output file before uploading it " "to the output URL.",
)
parser.add_argument(
"--max-waiting-time",
default=-1,
type=int,
help="max waiting time for connection, if set value -1 means no waiting time limit",
)
parser.add_argument("--port", default=8000, type=int, help="port to the http server")
# parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server")
parser.add_argument("--workers", default=None, type=int, help="number of workers")
parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency")
parser.add_argument(
"--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. "
)
parser = EngineArgs.add_cli_args(parser)
return parser
def parse_args():
parser = FlexibleArgumentParser(description="FastDeploy OpenAI-Compatible batch runner.")
args = make_arg_parser(parser).parse_args()
return args
def init_engine(args: argparse.Namespace):
"""
load engine
"""
global llm_engine
if llm_engine is not None:
return llm_engine
api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}")
engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args)
if not engine.start(api_server_pid=os.getpid()):
api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!")
return None
llm_engine = engine
return engine
class BatchProgressTracker:
def __init__(self):
self._total = 0
self._completed = 0
self._pbar: Optional[tqdm] = None
self._last_log_count = 0
def submitted(self):
self._total += 1
def completed(self):
self._completed += 1
if self._pbar:
self._pbar.update()
if self._total > 0:
log_interval = min(100, max(self._total // 10, 1))
if self._completed - self._last_log_count >= log_interval:
console_logger.info(f"Progress: {self._completed}/{self._total} requests completed")
self._last_log_count = self._completed
def pbar(self) -> tqdm:
self._pbar = tqdm(
total=self._total,
unit="req",
desc="Running batch",
mininterval=10,
bar_format=_BAR_FORMAT,
)
return self._pbar
async def read_file(path_or_url: str) -> str:
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp:
return await resp.text()
else:
with open(path_or_url, encoding="utf-8") as f:
return f.read()
async def write_local_file(output_path: str, batch_outputs: list[BatchRequestOutput]) -> None:
"""
Write the responses to a local file.
output_path: The path to write the responses to.
batch_outputs: The list of batch outputs to write.
"""
# We should make this async, but as long as run_batch runs as a
# standalone program, blocking the event loop won't effect performance.
with open(output_path, "w", encoding="utf-8") as f:
for o in batch_outputs:
print(o.model_dump_json(), file=f)
async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None:
"""
Upload a local file to a URL.
output_url: The URL to upload the file to.
data_or_file: Either the data to upload or the path to the file to upload.
from_file: If True, data_or_file is the path to the file to upload.
"""
# Timeout is a common issue when uploading large files.
# We retry max_retries times before giving up.
max_retries = 5
# Number of seconds to wait before retrying.
delay = 5
for attempt in range(1, max_retries + 1):
try:
# We increase the timeout to 1000 seconds to allow
# for large files (default is 300).
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1000)) as session:
if from_file:
with open(data_or_file, "rb") as file:
async with session.put(output_url, data=file) as response:
if response.status != 200:
raise Exception(
f"Failed to upload file.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}"
)
else:
async with session.put(output_url, data=data_or_file) as response:
if response.status != 200:
raise Exception(
f"Failed to upload data.\n"
f"Status: {response.status}\n"
f"Response: {response.text()}"
)
except Exception as e:
if attempt < max_retries:
console_logger.error(
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
attempt,
e,
delay,
)
await asyncio.sleep(delay)
else:
raise Exception(
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
) from e
async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str) -> None:
"""
Write batch_outputs to a file or upload to a URL.
path_or_url: The path or URL to write batch_outputs to.
batch_outputs: The list of batch outputs to write.
output_tmp_dir: The directory to store the output file before uploading it
to the output URL.
"""
if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
if output_tmp_dir is None:
console_logger.info("Writing outputs to memory buffer")
output_buffer = StringIO()
for o in batch_outputs:
print(o.model_dump_json(), file=output_buffer)
output_buffer.seek(0)
console_logger.info("Uploading outputs to %s", path_or_url)
await upload_data(
path_or_url,
output_buffer.read().strip().encode("utf-8"),
from_file=False,
)
else:
# Write responses to a temporary file and then upload it to the URL.
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=output_tmp_dir,
prefix="tmp_batch_output_",
suffix=".jsonl",
) as f:
console_logger.info("Writing outputs to temporary local file %s", f.name)
await write_local_file(f.name, batch_outputs)
console_logger.info("Uploading outputs to %s", path_or_url)
await upload_data(path_or_url, f.name, from_file=True)
else:
console_logger.info("Writing outputs to local file %s", path_or_url)
await write_local_file(path_or_url, batch_outputs)
def random_uuid() -> str:
return str(uuid.uuid4().hex)
def make_error_request_output(request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
batch_output = BatchRequestOutput(
id=f"fastdeploy-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=HTTPStatus.BAD_REQUEST,
request_id=f"fastdeploy-batch-{random_uuid()}",
),
error=error_msg,
)
return batch_output
async def make_async_error_request_output(request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
return make_error_request_output(request, error_msg)
async def run_request(
serving_engine_func: Callable,
request: BatchRequestInput,
tracker: BatchProgressTracker,
semaphore: asyncio.Semaphore,
) -> BatchRequestOutput:
async with semaphore:
try:
response = await serving_engine_func(request.body)
if isinstance(response, ChatCompletionResponse):
batch_output = BatchRequestOutput(
id=f"fastdeploy-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(
status_code=200, body=response, request_id=f"fastdeploy-batch-{random_uuid()}"
),
error=None,
)
elif isinstance(response, ErrorResponse):
batch_output = BatchRequestOutput(
id=f"fastdeploy-{random_uuid()}",
custom_id=request.custom_id,
response=BatchResponseData(status_code=400, request_id=f"fastdeploy-batch-{random_uuid()}"),
error=response,
)
else:
batch_output = make_error_request_output(request, error_msg="Request must not be sent in stream mode")
tracker.completed()
return batch_output
except Exception as e:
console_logger.error(f"Request {request.custom_id} processing failed: {str(e)}")
tracker.completed()
return make_error_request_output(request, error_msg=f"Request processing failed: {str(e)}")
def determine_process_id() -> int:
"""Determine the appropriate process ID."""
if current_process().name != "MainProcess":
return os.getppid()
else:
return os.getpid()
def create_model_paths(args: Namespace) -> List[ModelPath]:
"""Create model paths configuration."""
if args.served_model_name is not None:
served_model_names = args.served_model_name
verification = True
else:
served_model_names = args.model
verification = False
return [ModelPath(name=served_model_names, model_path=args.model, verification=verification)]
async def initialize_engine_client(args: Namespace, pid: int) -> EngineClient:
"""Initialize and configure the engine client."""
engine_client = EngineClient(
model_name_or_path=args.model,
tokenizer=args.tokenizer,
max_model_len=args.max_model_len,
tensor_parallel_size=args.tensor_parallel_size,
pid=pid,
port=int(args.engine_worker_queue_port[args.local_data_parallel_id]),
limit_mm_per_prompt=args.limit_mm_per_prompt,
mm_processor_kwargs=args.mm_processor_kwargs,
reasoning_parser=args.reasoning_parser,
data_parallel_size=args.data_parallel_size,
enable_logprob=args.enable_logprob,
workers=args.workers,
tool_parser=args.tool_call_parser,
)
await engine_client.connection_manager.initialize()
engine_client.create_zmq_client(model=pid, mode=zmq.PUSH)
engine_client.pid = pid
return engine_client
def create_serving_handlers(
args: Namespace, engine_client: EngineClient, model_paths: List[ModelPath], chat_template: str, pid: int
) -> OpenAIServingChat:
"""Create model and chat serving handlers."""
model_handler = OpenAIServingModels(
model_paths,
args.max_model_len,
args.ips,
)
chat_handler = OpenAIServingChat(
engine_client,
model_handler,
pid,
args.ips,
args.max_waiting_time,
chat_template,
args.enable_mm_output,
args.tokenizer_base_url,
)
return chat_handler
async def setup_engine_and_handlers(args: Namespace) -> Tuple[EngineClient, OpenAIServingChat]:
"""Setup engine client and all necessary handlers."""
if args.tokenizer is None:
args.tokenizer = args.model
pid = determine_process_id()
console_logger.info(f"Process ID: {pid}")
model_paths = create_model_paths(args)
chat_template = load_chat_template(args.chat_template, args.model)
# Initialize engine client
engine_client = await initialize_engine_client(args, pid)
engine_client = engine_client
# Create handlers
chat_handler = create_serving_handlers(args, engine_client, model_paths, chat_template, pid)
# Update data processor if engine exists
if llm_engine is not None:
llm_engine.engine.data_processor = engine_client.data_processor
return engine_client, chat_handler
async def run_batch(
args: argparse.Namespace,
) -> None:
# Setup engine and handlers
engine_client, chat_handler = await setup_engine_and_handlers(args)
concurrency = getattr(args, "max_concurrency", 512)
workers = getattr(args, "workers", 1)
max_concurrency = (concurrency + workers - 1) // workers
semaphore = asyncio.Semaphore(max_concurrency)
console_logger.info(f"concurrency: {concurrency}, workers: {workers}, max_concurrency: {max_concurrency}")
tracker = BatchProgressTracker()
console_logger.info("Reading batch from %s...", args.input_file)
# Submit all requests in the file to the engine "concurrently".
response_futures: list[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"):
# Skip empty lines.
request_json = request_json.strip()
if not request_json:
continue
request = BatchRequestInput.model_validate_json(request_json)
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
chat_handler_fn = chat_handler.create_chat_completion if chat_handler is not None else None
if chat_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Chat Completions API",
)
)
continue
response_futures.append(run_request(chat_handler_fn, request, tracker, semaphore))
tracker.submitted()
else:
response_futures.append(
make_async_error_request_output(
request,
error_msg=f"URL {request.url} was used. "
"Supported endpoints: /v1/chat/completions"
"See fastdeploy/entrypoints/openai/api_server.py for supported "
"/v1/chat/completions versions.",
)
)
with tracker.pbar():
responses = await asyncio.gather(*response_futures)
success_count = sum(1 for r in responses if r.error is None)
error_count = len(responses) - success_count
console_logger.info(f"Batch processing completed: {success_count} success, {error_count} errors")
await write_file(args.output_file, responses, args.output_tmp_dir)
console_logger.info("Results written to output file")
async def main(args: argparse.Namespace):
console_logger.info("Starting batch runner with args: %s", args)
try:
if args.workers is None:
args.workers = max(min(int(args.max_num_seqs // 32), 8), 1)
args.model = retrive_model_from_server(args.model, args.revision)
if args.tool_parser_plugin:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
if not init_engine(args):
return
await run_batch(args)
except Exception as e:
print("Fatal error in main:")
print(e)
console_logger.error(f"Fatal error in main: {e}", exc_info=True)
raise
finally:
await cleanup_resources()
async def cleanup_resources() -> None:
"""Clean up all resources during shutdown."""
try:
# stop engine
if llm_engine is not None:
try:
llm_engine._exit_sub_services()
except Exception as e:
console_logger.error(f"Error stopping engine: {e}")
# close client connections
if engine_client is not None:
try:
if hasattr(engine_client, "zmq_client"):
engine_client.zmq_client.close()
if hasattr(engine_client, "connection_manager"):
await engine_client.connection_manager.close()
except Exception as e:
console_logger.error(f"Error closing client connections: {e}")
# garbage collect
import gc
gc.collect()
print("run batch done")
except Exception as e:
console_logger.error(f"Error during cleanup: {e}")
if __name__ == "__main__":
args = parse_args()
asyncio.run(main(args=args))

View File

@@ -24,6 +24,7 @@ import os
import random
import re
import socket
import subprocess
import sys
import tarfile
import time
@@ -57,6 +58,98 @@ from typing import Callable, Optional
# Make sure enable_xxx equal to config.enable_xxx
ARGS_CORRECTION_LIST = [["early_stop_config", "enable_early_stop"], ["graph_optimization_config", "use_cudagraph"]]
FASTDEPLOY_SUBCMD_PARSER_EPILOG = (
"Tip: Use `fastdeploy [serve|run-batch|bench <bench_type>] "
"--help=<keyword>` to explore arguments from help.\n"
" - To view a argument group: --help=ModelConfig\n"
" - To view a single argument: --help=max-num-seqs\n"
" - To search by keyword: --help=max\n"
" - To list all groups: --help=listgroup\n"
" - To view help with pager: --help=page"
)
def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, subcommand_name: list[str]):
# Only handle --help=<keyword> for the current subcommand.
# Since subparser_init() runs for all subcommands during CLI setup,
# we skip processing if the subcommand name is not in sys.argv.
# sys.argv[0] is the program name. The subcommand follows.
# e.g., for `vllm bench latency`,
# sys.argv is `['vllm', 'bench', 'latency', ...]`
# and subcommand_name is "bench latency".
if len(sys.argv) <= len(subcommand_name) or sys.argv[1 : 1 + len(subcommand_name)] != subcommand_name:
return
for arg in sys.argv:
if arg.startswith("--help="):
search_keyword = arg.split("=", 1)[1]
# Enable paged view for full help
if search_keyword == "page":
help_text = parser.format_help()
_output_with_pager(help_text)
sys.exit(0)
# List available groups
if search_keyword == "listgroup":
output_lines = ["\nAvailable argument groups:"]
for group in parser._action_groups:
if group.title and not group.title.startswith("positional arguments"):
output_lines.append(f" - {group.title}")
if group.description:
output_lines.append(" " + group.description.strip())
output_lines.append("")
_output_with_pager("\n".join(output_lines))
sys.exit(0)
# For group search
formatter = parser._get_formatter()
for group in parser._action_groups:
if group.title and group.title.lower() == search_keyword.lower():
formatter.start_section(group.title)
formatter.add_text(group.description)
formatter.add_arguments(group._group_actions)
formatter.end_section()
_output_with_pager(formatter.format_help())
sys.exit(0)
# For single arg
matched_actions = []
for group in parser._action_groups:
for action in group._group_actions:
# search option name
if any(search_keyword.lower() in opt.lower() for opt in action.option_strings):
matched_actions.append(action)
if matched_actions:
header = f"\nParameters matching '{search_keyword}':\n"
formatter = parser._get_formatter()
formatter.add_arguments(matched_actions)
_output_with_pager(header + formatter.format_help())
sys.exit(0)
print(f"\nNo group or parameter matching '{search_keyword}'")
print("Tip: use `--help=listgroup` to view all groups.")
sys.exit(1)
def _output_with_pager(text: str):
"""Output text using scrolling view if available and appropriate."""
pagers = ["less -R", "more"]
for pager_cmd in pagers:
try:
proc = subprocess.Popen(pager_cmd.split(), stdin=subprocess.PIPE, text=True)
proc.communicate(input=text)
return
except (subprocess.SubprocessError, OSError, FileNotFoundError):
continue
# No pager worked, fall back to normal print
print(text)
class EngineError(Exception):
"""Base exception class for engine errors"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,89 @@
import unittest
from pydantic import ValidationError
from fastdeploy.entrypoints.openai.protocol import (
BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
UsageInfo,
)
class TestBatchRequestModels(unittest.TestCase):
def test_batch_request_input_with_dict_body(self):
body_dict = {
"messages": [{"role": "user", "content": "hi"}],
"model": "default",
}
obj = BatchRequestInput(
custom_id="test",
method="POST",
url="/v1/chat/completions",
body={"messages": [{"role": "user", "content": "hi"}]},
)
self.assertIsInstance(obj.body, ChatCompletionRequest)
self.assertEqual(obj.body.model_dump()["messages"], body_dict["messages"])
self.assertEqual(obj.body.model, "default")
def test_batch_request_input_with_model_body(self):
body_model = ChatCompletionRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-test")
obj = BatchRequestInput(
custom_id="456",
method="POST",
url="/v1/chat/completions",
body=body_model,
)
self.assertIsInstance(obj.body, ChatCompletionRequest)
self.assertEqual(obj.body.model, "gpt-test")
def test_batch_request_input_with_other_url(self):
obj = BatchRequestInput(
custom_id="789",
method="POST",
url="/v1/other/endpoint",
body={"messages": [{"role": "user", "content": "hi"}]},
)
self.assertIsInstance(obj.body, ChatCompletionRequest)
self.assertEqual(obj.body.messages[0]["content"], "hi")
def test_batch_response_data(self):
usage = UsageInfo(prompt_tokens=1, total_tokens=2, completion_tokens=1)
chat_msg = ChatMessage(role="assistant", content="ok")
choice = ChatCompletionResponseChoice(index=0, message=chat_msg, finish_reason="stop")
resp = ChatCompletionResponse(id="r1", model="gpt-test", choices=[choice], usage=usage)
data = BatchResponseData(
status_code=200,
request_id="req-1",
body=resp,
)
self.assertEqual(data.status_code, 200)
self.assertEqual(data.body.id, "r1")
self.assertEqual(data.body.choices[0].message.content, "ok")
def test_batch_request_output(self):
response = BatchResponseData(status_code=200, request_id="req-2", body=None)
out = BatchRequestOutput(id="out-1", custom_id="cid-1", response=response, error=None)
self.assertEqual(out.id, "out-1")
self.assertEqual(out.response.request_id, "req-2")
self.assertIsNone(out.error)
def test_invalid_batch_request_input(self):
with self.assertRaises(ValidationError):
BatchRequestInput(
custom_id="id",
method="POST",
url="/v1/chat/completions",
body={"model": "gpt-test"},
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,143 @@
"""
Unit tests for RunBatchSubcommand class.
"""
import argparse
import unittest
from unittest.mock import Mock, patch
class TestRunBatchSubcommand(unittest.TestCase):
"""Test cases for RunBatchSubcommand class."""
def test_name(self):
"""Test subcommand name."""
# Create a mock class that mimics RunBatchSubcommand
class MockRunBatchSubcommand:
name = "run-batch"
subcommand = MockRunBatchSubcommand()
self.assertEqual(subcommand.name, "run-batch")
@patch("builtins.print")
@patch("asyncio.run")
def test_cmd(self, mock_asyncio, mock_print):
"""Test cmd method."""
# Mock the main function
mock_main = Mock()
# Create a mock cmd function that simulates the real behavior
def mock_cmd(args):
# Simulate importlib.metadata.version call
version = "1.0.0" # Mock version
print("FastDeploy batch processing API version", version)
print(args)
mock_asyncio(mock_main(args))
args = argparse.Namespace(input="test.jsonl")
mock_cmd(args)
# Verify calls
mock_print.assert_any_call("FastDeploy batch processing API version", "1.0.0")
mock_print.assert_any_call(args)
mock_asyncio.assert_called_once()
def test_subparser_init(self):
"""Test subparser initialization."""
# Mock all the dependencies
mock_subparsers = Mock()
mock_parser = Mock()
mock_subparsers.add_parser.return_value = mock_parser
# Mock the subparser_init behavior
def mock_subparser_init(subparsers):
parser = subparsers.add_parser(
"run-batch",
help="Run batch prompts and write results to file.",
description=(
"Run batch prompts using FastDeploy's OpenAI-compatible API.\n"
"Supports local or HTTP input/output files."
),
usage="FastDeploy run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>",
)
parser.epilog = "FASTDEPLOY_SUBCMD_PARSER_EPILOG"
return parser
result = mock_subparser_init(mock_subparsers)
# Verify the parser was added
mock_subparsers.add_parser.assert_called_once_with(
"run-batch",
help="Run batch prompts and write results to file.",
description=(
"Run batch prompts using FastDeploy's OpenAI-compatible API.\n"
"Supports local or HTTP input/output files."
),
usage="FastDeploy run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model <model>",
)
self.assertEqual(result.epilog, "FASTDEPLOY_SUBCMD_PARSER_EPILOG")
class TestCmdInit(unittest.TestCase):
"""Test cmd_init function."""
def test_cmd_init(self):
"""Test cmd_init returns RunBatchSubcommand."""
# Mock the cmd_init function behavior
def mock_cmd_init():
class MockRunBatchSubcommand:
name = "run-batch"
@staticmethod
def cmd(args):
pass
def subparser_init(self, subparsers):
pass
return [MockRunBatchSubcommand()]
result = mock_cmd_init()
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "run-batch")
self.assertTrue(hasattr(result[0], "cmd"))
self.assertTrue(hasattr(result[0], "subparser_init"))
class TestIntegration(unittest.TestCase):
"""Integration tests without actual imports."""
def test_workflow(self):
"""Test the complete workflow with mocks."""
# Create mock objects that simulate the real workflow
class MockSubcommand:
name = "run-batch"
@staticmethod
def cmd(args):
return f"Executed with {args}"
def subparser_init(self, subparsers):
return "parser_created"
# Test subcommand creation
subcommand = MockSubcommand()
self.assertEqual(subcommand.name, "run-batch")
# Test command execution
args = argparse.Namespace(input="test.jsonl", output="result.jsonl")
result = subcommand.cmd(args)
self.assertIn("test.jsonl", str(result))
# Test parser initialization
mock_subparsers = Mock()
parser_result = subcommand.subparser_init(mock_subparsers)
self.assertEqual(parser_result, "parser_created")
if __name__ == "__main__":
unittest.main(verbosity=2)

View File

@@ -0,0 +1,107 @@
import argparse
import unittest
from unittest.mock import MagicMock, patch
from fastdeploy.utils import (
FASTDEPLOY_SUBCMD_PARSER_EPILOG,
show_filtered_argument_or_group_from_help,
)
class TestHelpFilter(unittest.TestCase):
def setUp(self):
self.parser = argparse.ArgumentParser(prog="fastdeploy", epilog=FASTDEPLOY_SUBCMD_PARSER_EPILOG)
self.subcommand = ["bench"]
self.mock_sys_argv = ["fastdeploy"] + self.subcommand
# Add test groups and arguments
self.model_group = self.parser.add_argument_group("ModelConfig", "Model configuration parameters")
self.model_group.add_argument("--model-path", help="Path to model")
self.model_group.add_argument("--max-num-seqs", help="Max sequences")
self.train_group = self.parser.add_argument_group("TrainingConfig", "Training parameters")
self.train_group.add_argument("--epochs", help="Training epochs")
@patch("sys.argv", ["fastdeploy", "bench", "--help=page"])
@patch("subprocess.Popen")
def test_page_help(self, mock_popen):
mock_proc = MagicMock()
mock_proc.communicate.return_value = (None, None)
mock_popen.return_value = mock_proc
# Expect SystemExit with code 0
with self.assertRaises(SystemExit) as cm:
show_filtered_argument_or_group_from_help(self.parser, self.subcommand)
self.assertEqual(cm.exception.code, 0)
mock_popen.assert_called_once()
@patch("sys.argv", ["fastdeploy", "bench", "--help=listgroup"])
@patch("fastdeploy.utils._output_with_pager")
def test_list_groups(self, mock_output):
# Expect SystemExit with code 0
with self.assertRaises(SystemExit) as cm:
show_filtered_argument_or_group_from_help(self.parser, self.subcommand)
self.assertEqual(cm.exception.code, 0)
# Verify that the output function was called
mock_output.assert_called_once()
# Check that the output contains expected groups
output_text = mock_output.call_args[0][0]
self.assertIn("ModelConfig", output_text)
self.assertIn("TrainingConfig", output_text)
@patch("sys.argv", ["fastdeploy", "bench", "--help=ModelConfig"])
@patch("fastdeploy.utils._output_with_pager")
def test_group_search(self, mock_output):
# Expect SystemExit with code 0
with self.assertRaises(SystemExit) as cm:
show_filtered_argument_or_group_from_help(self.parser, self.subcommand)
self.assertEqual(cm.exception.code, 0)
# Verify that the output function was called
mock_output.assert_called_once()
# Check that the output contains expected content
output_text = mock_output.call_args[0][0]
self.assertIn("ModelConfig", output_text)
self.assertIn("--model-path", output_text)
@patch("sys.argv", ["fastdeploy", "bench", "--help=max"])
@patch("fastdeploy.utils._output_with_pager")
def test_arg_search(self, mock_output):
# Expect SystemExit with code 0
with self.assertRaises(SystemExit) as cm:
show_filtered_argument_or_group_from_help(self.parser, self.subcommand)
self.assertEqual(cm.exception.code, 0)
# Verify that the output function was called
mock_output.assert_called_once()
# Check that the output contains expected content
output_text = mock_output.call_args[0][0]
self.assertIn("--max-num-seqs", output_text)
self.assertNotIn("--epochs", output_text)
@patch("sys.argv", ["fastdeploy", "bench", "--help=nonexistent"])
@patch("builtins.print")
def test_no_match(self, mock_print):
# Expect SystemExit with code 1 (error case)
with self.assertRaises(SystemExit) as cm:
show_filtered_argument_or_group_from_help(self.parser, self.subcommand)
self.assertEqual(cm.exception.code, 1)
# Check that error message was printed
mock_print.assert_called()
call_args = [call.args[0] for call in mock_print.call_args_list]
self.assertTrue(any("No group or parameter matching" in arg for arg in call_args))
@patch("sys.argv", ["fastdeploy", "othercmd"])
def test_wrong_subcommand(self):
# This should not raise SystemExit, just return normally
try:
show_filtered_argument_or_group_from_help(self.parser, self.subcommand)
except SystemExit:
self.fail("Function should not exit when subcommand doesn't match")
if __name__ == "__main__":
unittest.main()