diff --git a/fastdeploy/entrypoints/cli/main.py b/fastdeploy/entrypoints/cli/main.py index b770dc604..f71a94bea 100644 --- a/fastdeploy/entrypoints/cli/main.py +++ b/fastdeploy/entrypoints/cli/main.py @@ -23,10 +23,12 @@ from fastdeploy import __version__ def main(): import fastdeploy.entrypoints.cli.benchmark.main import fastdeploy.entrypoints.cli.openai + import fastdeploy.entrypoints.cli.run_batch import fastdeploy.entrypoints.cli.serve from fastdeploy.utils import FlexibleArgumentParser CMD_MODULES = [ + fastdeploy.entrypoints.cli.run_batch, fastdeploy.entrypoints.cli.openai, fastdeploy.entrypoints.cli.benchmark.main, fastdeploy.entrypoints.cli.serve, diff --git a/fastdeploy/entrypoints/cli/run_batch.py b/fastdeploy/entrypoints/cli/run_batch.py new file mode 100644 index 000000000..f7ef95ef8 --- /dev/null +++ b/fastdeploy/entrypoints/cli/run_batch.py @@ -0,0 +1,65 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +# This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/run_batch.py + +from __future__ import annotations + +import argparse +import asyncio +import importlib.metadata + +from fastdeploy.entrypoints.cli.types import CLISubcommand +from fastdeploy.utils import ( + FASTDEPLOY_SUBCMD_PARSER_EPILOG, + FlexibleArgumentParser, + show_filtered_argument_or_group_from_help, +) + + +class RunBatchSubcommand(CLISubcommand): + """The `run-batch` subcommand for FastDeploy CLI.""" + + name = "run-batch" + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + from fastdeploy.entrypoints.openai.run_batch import main as run_batch_main + + print("FastDeploy batch processing API version", importlib.metadata.version("fastdeploy-gpu")) + print(args) + asyncio.run(run_batch_main(args)) + + def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: + from fastdeploy.entrypoints.openai.run_batch import make_arg_parser + + run_batch_parser = subparsers.add_parser( + "run-batch", + help="Run batch prompts and write results to file.", + description=( + "Run batch prompts using FastDeploy's OpenAI-compatible API.\n" + "Supports local or HTTP input/output files." + ), + usage="FastDeploy run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model ", + ) + run_batch_parser = make_arg_parser(run_batch_parser) + show_filtered_argument_or_group_from_help(run_batch_parser, ["run-batch"]) + run_batch_parser.epilog = FASTDEPLOY_SUBCMD_PARSER_EPILOG + return run_batch_parser + + +def cmd_init() -> list[CLISubcommand]: + return [RunBatchSubcommand()] diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py index 81a41f016..0e6802e76 100644 --- a/fastdeploy/entrypoints/openai/protocol.py +++ b/fastdeploy/entrypoints/openai/protocol.py @@ -720,3 +720,71 @@ class ControlSchedulerRequest(BaseModel): reset: Optional[bool] = False load_shards_num: Optional[int] = None reallocate_shard: Optional[bool] = False + + +from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator + +BatchRequestInputBody = ChatCompletionRequest + + +class BatchRequestInput(BaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameters of the request. + body: BatchRequestInputBody + + @field_validator("body", mode="before") + @classmethod + def check_type_for_url(cls, value: Any, info: ValidationInfo): + # Use url to disambiguate models + url: str = info.data["url"] + if url == "/v1/chat/completions": + if isinstance(value, dict): + return value + return ChatCompletionRequest.model_validate(value) + return value + + +class BatchResponseData(BaseModel): + # HTTP status code of the response. + status_code: int = 200 + + # An unique identifier for the API request. + request_id: str + + # The body of the response. + body: Optional[ChatCompletionResponse] = None + + +class BatchRequestOutput(BaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[BatchResponseData] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. + error: Optional[Any] diff --git a/fastdeploy/entrypoints/openai/run_batch.py b/fastdeploy/entrypoints/openai/run_batch.py new file mode 100644 index 000000000..2cfd8cb0a --- /dev/null +++ b/fastdeploy/entrypoints/openai/run_batch.py @@ -0,0 +1,546 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +# This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/run_batch.py + +import argparse +import asyncio +import os +import tempfile +import uuid +from argparse import Namespace +from collections.abc import Awaitable +from http import HTTPStatus +from io import StringIO +from multiprocessing import current_process +from typing import Callable, List, Optional, Tuple + +import aiohttp +import zmq +from tqdm import tqdm + +from fastdeploy.engine.args_utils import EngineArgs +from fastdeploy.engine.engine import LLMEngine +from fastdeploy.entrypoints.chat_utils import load_chat_template +from fastdeploy.entrypoints.engine_client import EngineClient +from fastdeploy.entrypoints.openai.protocol import ( + BatchRequestInput, + BatchRequestOutput, + BatchResponseData, + ChatCompletionResponse, + ErrorResponse, +) +from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat +from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels +from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager +from fastdeploy.utils import ( + FlexibleArgumentParser, + api_server_logger, + console_logger, + retrive_model_from_server, +) + +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" +llm_engine = None +engine_client = None + + +def make_arg_parser(parser: FlexibleArgumentParser): + parser.add_argument( + "-i", + "--input-file", + required=True, + type=str, + help="The path or url to a single input file. Currently supports local file " + "paths, or the http protocol (http or https). If a URL is specified, " + "the file should be available via HTTP GET.", + ) + parser.add_argument( + "-o", + "--output-file", + required=True, + type=str, + help="The path or url to a single output file. Currently supports " + "local file paths, or web (http or https) urls. If a URL is specified," + " the file should be available via HTTP PUT.", + ) + parser.add_argument( + "--output-tmp-dir", + type=str, + default=None, + help="The directory to store the output file before uploading it " "to the output URL.", + ) + parser.add_argument( + "--max-waiting-time", + default=-1, + type=int, + help="max waiting time for connection, if set value -1 means no waiting time limit", + ) + parser.add_argument("--port", default=8000, type=int, help="port to the http server") + # parser.add_argument("--host", default="0.0.0.0", type=str, help="host to the http server") + parser.add_argument("--workers", default=None, type=int, help="number of workers") + parser.add_argument("--max-concurrency", default=512, type=int, help="max concurrency") + parser.add_argument( + "--enable-mm-output", action="store_true", help="Enable 'multimodal_content' field in response output. " + ) + parser = EngineArgs.add_cli_args(parser) + return parser + + +def parse_args(): + parser = FlexibleArgumentParser(description="FastDeploy OpenAI-Compatible batch runner.") + args = make_arg_parser(parser).parse_args() + return args + + +def init_engine(args: argparse.Namespace): + """ + load engine + """ + global llm_engine + if llm_engine is not None: + return llm_engine + + api_server_logger.info(f"FastDeploy LLM API server starting... {os.getpid()}") + engine_args = EngineArgs.from_cli_args(args) + engine = LLMEngine.from_engine_args(engine_args) + if not engine.start(api_server_pid=os.getpid()): + api_server_logger.error("Failed to initialize FastDeploy LLM engine, service exit now!") + return None + + llm_engine = engine + return engine + + +class BatchProgressTracker: + + def __init__(self): + self._total = 0 + self._completed = 0 + self._pbar: Optional[tqdm] = None + self._last_log_count = 0 + + def submitted(self): + self._total += 1 + + def completed(self): + self._completed += 1 + if self._pbar: + self._pbar.update() + + if self._total > 0: + log_interval = min(100, max(self._total // 10, 1)) + if self._completed - self._last_log_count >= log_interval: + console_logger.info(f"Progress: {self._completed}/{self._total} requests completed") + self._last_log_count = self._completed + + def pbar(self) -> tqdm: + self._pbar = tqdm( + total=self._total, + unit="req", + desc="Running batch", + mininterval=10, + bar_format=_BAR_FORMAT, + ) + return self._pbar + + +async def read_file(path_or_url: str) -> str: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, session.get(path_or_url) as resp: + return await resp.text() + else: + with open(path_or_url, encoding="utf-8") as f: + return f.read() + + +async def write_local_file(output_path: str, batch_outputs: list[BatchRequestOutput]) -> None: + """ + Write the responses to a local file. + output_path: The path to write the responses to. + batch_outputs: The list of batch outputs to write. + """ + # We should make this async, but as long as run_batch runs as a + # standalone program, blocking the event loop won't effect performance. + with open(output_path, "w", encoding="utf-8") as f: + for o in batch_outputs: + print(o.model_dump_json(), file=f) + + +async def upload_data(output_url: str, data_or_file: str, from_file: bool) -> None: + """ + Upload a local file to a URL. + output_url: The URL to upload the file to. + data_or_file: Either the data to upload or the path to the file to upload. + from_file: If True, data_or_file is the path to the file to upload. + """ + # Timeout is a common issue when uploading large files. + # We retry max_retries times before giving up. + max_retries = 5 + # Number of seconds to wait before retrying. + delay = 5 + + for attempt in range(1, max_retries + 1): + try: + # We increase the timeout to 1000 seconds to allow + # for large files (default is 300). + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=1000)) as session: + if from_file: + with open(data_or_file, "rb") as file: + async with session.put(output_url, data=file) as response: + if response.status != 200: + raise Exception( + f"Failed to upload file.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}" + ) + else: + async with session.put(output_url, data=data_or_file) as response: + if response.status != 200: + raise Exception( + f"Failed to upload data.\n" + f"Status: {response.status}\n" + f"Response: {response.text()}" + ) + + except Exception as e: + if attempt < max_retries: + console_logger.error( + "Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501 + attempt, + e, + delay, + ) + await asyncio.sleep(delay) + else: + raise Exception( + f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501 + ) from e + + +async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], output_tmp_dir: str) -> None: + """ + Write batch_outputs to a file or upload to a URL. + path_or_url: The path or URL to write batch_outputs to. + batch_outputs: The list of batch outputs to write. + output_tmp_dir: The directory to store the output file before uploading it + to the output URL. + """ + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + if output_tmp_dir is None: + console_logger.info("Writing outputs to memory buffer") + output_buffer = StringIO() + for o in batch_outputs: + print(o.model_dump_json(), file=output_buffer) + output_buffer.seek(0) + console_logger.info("Uploading outputs to %s", path_or_url) + await upload_data( + path_or_url, + output_buffer.read().strip().encode("utf-8"), + from_file=False, + ) + else: + # Write responses to a temporary file and then upload it to the URL. + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=output_tmp_dir, + prefix="tmp_batch_output_", + suffix=".jsonl", + ) as f: + console_logger.info("Writing outputs to temporary local file %s", f.name) + await write_local_file(f.name, batch_outputs) + console_logger.info("Uploading outputs to %s", path_or_url) + await upload_data(path_or_url, f.name, from_file=True) + else: + console_logger.info("Writing outputs to local file %s", path_or_url) + await write_local_file(path_or_url, batch_outputs) + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +def make_error_request_output(request: BatchRequestInput, error_msg: str) -> BatchRequestOutput: + batch_output = BatchRequestOutput( + id=f"fastdeploy-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + status_code=HTTPStatus.BAD_REQUEST, + request_id=f"fastdeploy-batch-{random_uuid()}", + ), + error=error_msg, + ) + return batch_output + + +async def make_async_error_request_output(request: BatchRequestInput, error_msg: str) -> BatchRequestOutput: + return make_error_request_output(request, error_msg) + + +async def run_request( + serving_engine_func: Callable, + request: BatchRequestInput, + tracker: BatchProgressTracker, + semaphore: asyncio.Semaphore, +) -> BatchRequestOutput: + async with semaphore: + try: + response = await serving_engine_func(request.body) + + if isinstance(response, ChatCompletionResponse): + batch_output = BatchRequestOutput( + id=f"fastdeploy-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + status_code=200, body=response, request_id=f"fastdeploy-batch-{random_uuid()}" + ), + error=None, + ) + elif isinstance(response, ErrorResponse): + batch_output = BatchRequestOutput( + id=f"fastdeploy-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData(status_code=400, request_id=f"fastdeploy-batch-{random_uuid()}"), + error=response, + ) + else: + batch_output = make_error_request_output(request, error_msg="Request must not be sent in stream mode") + + tracker.completed() + return batch_output + + except Exception as e: + console_logger.error(f"Request {request.custom_id} processing failed: {str(e)}") + tracker.completed() + return make_error_request_output(request, error_msg=f"Request processing failed: {str(e)}") + + +def determine_process_id() -> int: + """Determine the appropriate process ID.""" + if current_process().name != "MainProcess": + return os.getppid() + else: + return os.getpid() + + +def create_model_paths(args: Namespace) -> List[ModelPath]: + """Create model paths configuration.""" + if args.served_model_name is not None: + served_model_names = args.served_model_name + verification = True + else: + served_model_names = args.model + verification = False + + return [ModelPath(name=served_model_names, model_path=args.model, verification=verification)] + + +async def initialize_engine_client(args: Namespace, pid: int) -> EngineClient: + """Initialize and configure the engine client.""" + engine_client = EngineClient( + model_name_or_path=args.model, + tokenizer=args.tokenizer, + max_model_len=args.max_model_len, + tensor_parallel_size=args.tensor_parallel_size, + pid=pid, + port=int(args.engine_worker_queue_port[args.local_data_parallel_id]), + limit_mm_per_prompt=args.limit_mm_per_prompt, + mm_processor_kwargs=args.mm_processor_kwargs, + reasoning_parser=args.reasoning_parser, + data_parallel_size=args.data_parallel_size, + enable_logprob=args.enable_logprob, + workers=args.workers, + tool_parser=args.tool_call_parser, + ) + + await engine_client.connection_manager.initialize() + engine_client.create_zmq_client(model=pid, mode=zmq.PUSH) + engine_client.pid = pid + + return engine_client + + +def create_serving_handlers( + args: Namespace, engine_client: EngineClient, model_paths: List[ModelPath], chat_template: str, pid: int +) -> OpenAIServingChat: + """Create model and chat serving handlers.""" + model_handler = OpenAIServingModels( + model_paths, + args.max_model_len, + args.ips, + ) + + chat_handler = OpenAIServingChat( + engine_client, + model_handler, + pid, + args.ips, + args.max_waiting_time, + chat_template, + args.enable_mm_output, + args.tokenizer_base_url, + ) + + return chat_handler + + +async def setup_engine_and_handlers(args: Namespace) -> Tuple[EngineClient, OpenAIServingChat]: + """Setup engine client and all necessary handlers.""" + + if args.tokenizer is None: + args.tokenizer = args.model + + pid = determine_process_id() + console_logger.info(f"Process ID: {pid}") + + model_paths = create_model_paths(args) + chat_template = load_chat_template(args.chat_template, args.model) + + # Initialize engine client + engine_client = await initialize_engine_client(args, pid) + engine_client = engine_client + + # Create handlers + chat_handler = create_serving_handlers(args, engine_client, model_paths, chat_template, pid) + + # Update data processor if engine exists + if llm_engine is not None: + llm_engine.engine.data_processor = engine_client.data_processor + + return engine_client, chat_handler + + +async def run_batch( + args: argparse.Namespace, +) -> None: + + # Setup engine and handlers + engine_client, chat_handler = await setup_engine_and_handlers(args) + + concurrency = getattr(args, "max_concurrency", 512) + workers = getattr(args, "workers", 1) + max_concurrency = (concurrency + workers - 1) // workers + semaphore = asyncio.Semaphore(max_concurrency) + + console_logger.info(f"concurrency: {concurrency}, workers: {workers}, max_concurrency: {max_concurrency}") + + tracker = BatchProgressTracker() + console_logger.info("Reading batch from %s...", args.input_file) + + # Submit all requests in the file to the engine "concurrently". + response_futures: list[Awaitable[BatchRequestOutput]] = [] + for request_json in (await read_file(args.input_file)).strip().split("\n"): + # Skip empty lines. + request_json = request_json.strip() + if not request_json: + continue + + request = BatchRequestInput.model_validate_json(request_json) + + # Determine the type of request and run it. + if request.url == "/v1/chat/completions": + chat_handler_fn = chat_handler.create_chat_completion if chat_handler is not None else None + if chat_handler_fn is None: + response_futures.append( + make_async_error_request_output( + request, + error_msg="The model does not support Chat Completions API", + ) + ) + continue + + response_futures.append(run_request(chat_handler_fn, request, tracker, semaphore)) + tracker.submitted() + else: + response_futures.append( + make_async_error_request_output( + request, + error_msg=f"URL {request.url} was used. " + "Supported endpoints: /v1/chat/completions" + "See fastdeploy/entrypoints/openai/api_server.py for supported " + "/v1/chat/completions versions.", + ) + ) + + with tracker.pbar(): + responses = await asyncio.gather(*response_futures) + + success_count = sum(1 for r in responses if r.error is None) + error_count = len(responses) - success_count + console_logger.info(f"Batch processing completed: {success_count} success, {error_count} errors") + + await write_file(args.output_file, responses, args.output_tmp_dir) + console_logger.info("Results written to output file") + + +async def main(args: argparse.Namespace): + console_logger.info("Starting batch runner with args: %s", args) + try: + if args.workers is None: + args.workers = max(min(int(args.max_num_seqs // 32), 8), 1) + + args.model = retrive_model_from_server(args.model, args.revision) + + if args.tool_parser_plugin: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + if not init_engine(args): + return + + await run_batch(args) + except Exception as e: + print("Fatal error in main:") + print(e) + console_logger.error(f"Fatal error in main: {e}", exc_info=True) + raise + finally: + await cleanup_resources() + + +async def cleanup_resources() -> None: + """Clean up all resources during shutdown.""" + try: + # stop engine + if llm_engine is not None: + try: + llm_engine._exit_sub_services() + except Exception as e: + console_logger.error(f"Error stopping engine: {e}") + + # close client connections + if engine_client is not None: + try: + if hasattr(engine_client, "zmq_client"): + engine_client.zmq_client.close() + if hasattr(engine_client, "connection_manager"): + await engine_client.connection_manager.close() + except Exception as e: + console_logger.error(f"Error closing client connections: {e}") + + # garbage collect + import gc + + gc.collect() + print("run batch done") + + except Exception as e: + console_logger.error(f"Error during cleanup: {e}") + + +if __name__ == "__main__": + args = parse_args() + asyncio.run(main(args=args)) diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py index 5975f1a5c..310e88bc3 100644 --- a/fastdeploy/utils.py +++ b/fastdeploy/utils.py @@ -24,6 +24,7 @@ import os import random import re import socket +import subprocess import sys import tarfile import time @@ -57,6 +58,98 @@ from typing import Callable, Optional # Make sure enable_xxx equal to config.enable_xxx ARGS_CORRECTION_LIST = [["early_stop_config", "enable_early_stop"], ["graph_optimization_config", "use_cudagraph"]] +FASTDEPLOY_SUBCMD_PARSER_EPILOG = ( + "Tip: Use `fastdeploy [serve|run-batch|bench ] " + "--help=` to explore arguments from help.\n" + " - To view a argument group: --help=ModelConfig\n" + " - To view a single argument: --help=max-num-seqs\n" + " - To search by keyword: --help=max\n" + " - To list all groups: --help=listgroup\n" + " - To view help with pager: --help=page" +) + + +def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, subcommand_name: list[str]): + + # Only handle --help= for the current subcommand. + # Since subparser_init() runs for all subcommands during CLI setup, + # we skip processing if the subcommand name is not in sys.argv. + # sys.argv[0] is the program name. The subcommand follows. + # e.g., for `vllm bench latency`, + # sys.argv is `['vllm', 'bench', 'latency', ...]` + # and subcommand_name is "bench latency". + if len(sys.argv) <= len(subcommand_name) or sys.argv[1 : 1 + len(subcommand_name)] != subcommand_name: + return + + for arg in sys.argv: + if arg.startswith("--help="): + search_keyword = arg.split("=", 1)[1] + + # Enable paged view for full help + if search_keyword == "page": + help_text = parser.format_help() + _output_with_pager(help_text) + sys.exit(0) + + # List available groups + if search_keyword == "listgroup": + output_lines = ["\nAvailable argument groups:"] + for group in parser._action_groups: + if group.title and not group.title.startswith("positional arguments"): + output_lines.append(f" - {group.title}") + if group.description: + output_lines.append(" " + group.description.strip()) + output_lines.append("") + _output_with_pager("\n".join(output_lines)) + sys.exit(0) + + # For group search + formatter = parser._get_formatter() + for group in parser._action_groups: + if group.title and group.title.lower() == search_keyword.lower(): + formatter.start_section(group.title) + formatter.add_text(group.description) + formatter.add_arguments(group._group_actions) + formatter.end_section() + _output_with_pager(formatter.format_help()) + sys.exit(0) + + # For single arg + matched_actions = [] + + for group in parser._action_groups: + for action in group._group_actions: + # search option name + if any(search_keyword.lower() in opt.lower() for opt in action.option_strings): + matched_actions.append(action) + + if matched_actions: + header = f"\nParameters matching '{search_keyword}':\n" + formatter = parser._get_formatter() + formatter.add_arguments(matched_actions) + _output_with_pager(header + formatter.format_help()) + sys.exit(0) + + print(f"\nNo group or parameter matching '{search_keyword}'") + print("Tip: use `--help=listgroup` to view all groups.") + sys.exit(1) + + +def _output_with_pager(text: str): + """Output text using scrolling view if available and appropriate.""" + + pagers = ["less -R", "more"] + for pager_cmd in pagers: + try: + proc = subprocess.Popen(pager_cmd.split(), stdin=subprocess.PIPE, text=True) + proc.communicate(input=text) + return + except (subprocess.SubprocessError, OSError, FileNotFoundError): + continue + + # No pager worked, fall back to normal print + print(text) + class EngineError(Exception): """Base exception class for engine errors""" diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py new file mode 100644 index 000000000..f785797c6 --- /dev/null +++ b/tests/entrypoints/openai/test_run_batch.py @@ -0,0 +1,1333 @@ +import asyncio +import json +import os +import subprocess +import tempfile +import unittest +from http import HTTPStatus +from unittest.mock import AsyncMock, MagicMock, Mock, mock_open, patch + +from tqdm import tqdm + +from fastdeploy.entrypoints.openai.protocol import ( + BatchRequestOutput, + BatchResponseData, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + ErrorResponse, + UsageInfo, +) +from fastdeploy.entrypoints.openai.run_batch import ( + _BAR_FORMAT, + BatchProgressTracker, + ModelPath, + cleanup_resources, + create_model_paths, + create_serving_handlers, + determine_process_id, + init_engine, + initialize_engine_client, + main, + make_async_error_request_output, + make_error_request_output, + parse_args, + random_uuid, + read_file, + run_batch, + run_request, + setup_engine_and_handlers, + upload_data, + write_file, + write_local_file, +) + +INPUT_BATCH = """ +{"custom_id": "req-00001", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "user", "content": "Can you write a short poem? (id=1)"}], "temperature": 0.7, "max_tokens": 200}} +{"custom_id": "req-00002", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "user", "content": "What can you do? (id=2)"}], "temperature": 0.7, "max_tokens": 200}} +{"custom_id": "req-00003", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "user", "content": "Hello, who are you? (id=3)"}], "temperature": 0.7, "max_tokens": 200}} +""" + +INVALID_INPUT_BATCH = """ +{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}} +""" + +BATCH_RESPONSE = """ +{"id":"fastdeploy-7fcc30e2e4334fca806c4d01ee7ac4ab","custom_id":"req-00001","response":{"status_code":200,"request_id":"fastdeploy-batch-5f4017beded84b15aa3a8b0f1fce154c","body":{"id":"chatcmpl-33b09ae5-a8f1-40ad-9110-efa2b381eac9","object":"chat.completion","created":1758698637,"model":"/root/paddlejob/zhaolei36/ernie-4_5-0_3b-bf16-paddle","choices":[{"index":0,"message":{"role":"assistant","content":"In a sunlit meadow where dreams bloom,\\nA gentle breeze carries the breeze,\\nThe leaves rustle like ancient letters,\\nAnd in the sky, a song of hope and love.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":19,"total_tokens":60,"completion_tokens":41,"prompt_tokens_details":{"cached_tokens":0}}}},"error":null} +{"id":"fastdeploy-bf549849df2145598ae1758ba260f784","custom_id":"req-00002","response":{"status_code":200,"request_id":"fastdeploy-batch-81223f12fdc345efbfe85114ced10a1d","body":{"id":"chatcmpl-9479e36c-1542-45ff-b364-1dc6d34be9e7","object":"chat.completion","created":1758698637,"model":"/root/paddlejob/zhaolei36/ernie-4_5-0_3b-bf16-paddle","choices":[{"index":0,"message":{"role":"assistant","content":"Based on the given text, here are some possible actions you can take:\\n\\n1. **Read the question**: To understand what you can do, you can read the question (id=2) and analyze its requirements or constraints.\\n2. **Identify the keywords**: Look for specific keywords or phrases that describe what you can do. For example, if the question mentions \\"coding,\\" you can focus on coding skills or platforms.\\n3. **Brainstorm ideas**: You can think creatively about different ways to perform the action. For example, you could brainstorm different methods of communication, data analysis, or problem-solving.\\n4. **Explain your action**: If you have knowledge or skills in a particular area, you can explain how you would use those skills to achieve the desired outcome.\\n5. **Ask for help**: If you need assistance, you can ask for help from a friend, teacher, or mentor.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"total_tokens":211,"completion_tokens":194,"prompt_tokens_details":{"cached_tokens":0}}}},"error":null} +""" + + +class TestArgParser(unittest.TestCase): + """测试参数解析相关函数""" + + @patch("fastdeploy.entrypoints.openai.run_batch.FlexibleArgumentParser") + @patch("fastdeploy.entrypoints.openai.run_batch.EngineArgs") + def test_make_arg_parser(self, mock_engine_args, mock_parser_class): + """测试make_arg_parser函数""" + from fastdeploy.entrypoints.openai.run_batch import make_arg_parser + + mock_parser = Mock() + mock_parser_class.return_value = mock_parser + + # 让EngineArgs.add_cli_args返回parser本身 + mock_engine_args.add_cli_args.return_value = mock_parser + + result = make_arg_parser(mock_parser) + + # 验证参数被正确添加 + mock_parser.add_argument.assert_any_call("-i", "--input-file", required=True, type=str, help=unittest.mock.ANY) + mock_parser.add_argument.assert_any_call( + "-o", "--output-file", required=True, type=str, help=unittest.mock.ANY + ) + mock_parser.add_argument.assert_any_call("--output-tmp-dir", type=str, default=None, help=unittest.mock.ANY) + mock_engine_args.add_cli_args.assert_called_once_with(mock_parser) + # 现在应该返回parser而不是EngineArgs.add_cli_args的返回值 + self.assertEqual(result, mock_parser) + + @patch("fastdeploy.entrypoints.openai.run_batch.FlexibleArgumentParser") + @patch("fastdeploy.entrypoints.openai.run_batch.make_arg_parser") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + def test_parse_args(self, mock_logger, mock_make_parser, mock_parser_class): + """测试parse_args函数""" + mock_parser = Mock() + mock_args = Mock() + mock_parser_class.return_value = mock_parser + mock_parser.parse_args.return_value = mock_args + mock_make_parser.return_value = mock_parser + + result = parse_args() + + mock_parser_class.assert_called_once_with(description="FastDeploy OpenAI-Compatible batch runner.") + mock_make_parser.assert_called_once_with(mock_parser) + mock_parser.parse_args.assert_called_once() + self.assertEqual(result, mock_args) + + +class TestEngineInitialization(unittest.TestCase): + """测试引擎初始化相关函数""" + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + @patch("fastdeploy.entrypoints.openai.run_batch.LLMEngine") + @patch("fastdeploy.entrypoints.openai.run_batch.EngineArgs") + @patch("fastdeploy.entrypoints.openai.run_batch.api_server_logger") + @patch("fastdeploy.entrypoints.openai.run_batch.os") + def test_init_engine_success(self, mock_os, mock_logger, mock_engine_args, mock_llm_engine): + """测试init_engine成功初始化""" + + with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", None): + mock_args = Mock() + mock_engine_args.from_cli_args.return_value = Mock() + mock_engine = Mock() + mock_engine.start.return_value = True + mock_llm_engine.from_engine_args.return_value = mock_engine + mock_os.getpid.return_value = 123 + + result = init_engine(mock_args) + + mock_engine_args.from_cli_args.assert_called_with(mock_args) + mock_llm_engine.from_engine_args.assert_called_with(mock_engine_args.from_cli_args.return_value) + mock_engine.start.assert_called_with(api_server_pid=123) + mock_logger.info.assert_called_with("FastDeploy LLM API server starting... 123") + self.assertEqual(result, mock_engine) + + @patch("fastdeploy.entrypoints.openai.run_batch.LLMEngine") + @patch("fastdeploy.entrypoints.openai.run_batch.EngineArgs") + @patch("fastdeploy.entrypoints.openai.run_batch.api_server_logger") + def test_init_engine_failure(self, mock_logger, mock_engine_args, mock_llm_engine): + """测试init_engine初始化失败""" + with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", None): + mock_args = Mock() + mock_engine_args.from_cli_args.return_value = Mock() + mock_engine = Mock() + mock_engine.start.return_value = False + mock_llm_engine.from_engine_args.return_value = mock_engine + + result = init_engine(mock_args) + + mock_logger.error.assert_called_with("Failed to initialize FastDeploy LLM engine, service exit now!") + self.assertIsNone(result) + + @patch("fastdeploy.entrypoints.openai.run_batch.LLMEngine") + def test_init_engine_already_initialized(self, mock_llm_engine): + """测试init_engine已经初始化的情况""" + existing_engine = Mock() + with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", existing_engine): + mock_args = Mock() + result = init_engine(mock_args) + + mock_llm_engine.from_engine_args.assert_not_called() + self.assertEqual(result, existing_engine) + + @patch("fastdeploy.entrypoints.openai.run_batch.EngineClient") + async def test_initialize_engine_client(self, mock_engine_client): + """测试初始化引擎客户端""" + mock_args = Mock() + mock_args.model = "test-model" + mock_args.tokenizer = "test-tokenizer" + mock_args.max_model_len = 1000 + mock_args.tensor_parallel_size = 1 + mock_args.engine_worker_queue_port = [8000] + mock_args.local_data_parallel_id = 0 + mock_args.limit_mm_per_prompt = None + mock_args.mm_processor_kwargs = {} + mock_args.reasoning_parser = None + mock_args.data_parallel_size = 1 + mock_args.enable_logprob = False + mock_args.workers = 1 + mock_args.tool_call_parser = None + + mock_client_instance = AsyncMock() + mock_engine_client.return_value = mock_client_instance + + pid = 123 + result = await initialize_engine_client(mock_args, pid) + + # 验证EngineClient被正确初始化 + mock_engine_client.assert_called_once() + mock_client_instance.connection_manager.initialize.assert_called_once() + mock_client_instance.create_zmq_client.assert_called_once_with(model=pid, mode=unittest.mock.ANY) + self.assertEqual(mock_client_instance.pid, pid) + self.assertEqual(result, mock_client_instance) + + @patch("fastdeploy.entrypoints.openai.run_batch.OpenAIServingModels") + @patch("fastdeploy.entrypoints.openai.run_batch.OpenAIServingChat") + def test_create_serving_handlers(self, mock_chat_handler, mock_model_handler): + """测试创建服务处理器""" + mock_args = Mock() + mock_args.max_model_len = 1000 + mock_args.ips = "127.0.0.1" + mock_args.max_waiting_time = 60 + mock_args.enable_mm_output = False + mock_args.tokenizer_base_url = None + + mock_engine_client = Mock() + mock_model_paths = [Mock(spec=ModelPath)] + chat_template = "test_template" + pid = 123 + + mock_model_instance = Mock() + mock_model_handler.return_value = mock_model_instance + + mock_chat_instance = Mock() + mock_chat_handler.return_value = mock_chat_instance + + result = create_serving_handlers(mock_args, mock_engine_client, mock_model_paths, chat_template, pid) + + # 验证处理器被正确创建 + mock_model_handler.assert_called_once_with(mock_model_paths, mock_args.max_model_len, mock_args.ips) + mock_chat_handler.assert_called_once_with( + mock_engine_client, + mock_model_instance, + pid, + mock_args.ips, + mock_args.max_waiting_time, + chat_template, + mock_args.enable_mm_output, + mock_args.tokenizer_base_url, + ) + self.assertEqual(result, mock_chat_instance) + + @patch("fastdeploy.entrypoints.openai.run_batch.determine_process_id") + @patch("fastdeploy.entrypoints.openai.run_batch.create_model_paths") + @patch("fastdeploy.entrypoints.openai.run_batch.load_chat_template") + @patch("fastdeploy.entrypoints.openai.run_batch.initialize_engine_client") + @patch("fastdeploy.entrypoints.openai.run_batch.create_serving_handlers") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_setup_engine_and_handlers( + self, + mock_logger, + mock_create_handlers, + mock_init_engine, + mock_load_template, + mock_create_paths, + mock_determine_pid, + ): + """测试设置引擎和处理器""" + mock_args = Mock() + mock_args.tokenizer = None + mock_args.model = "test-model" + mock_args.chat_template = "template_name" + + # 设置mock返回值 + mock_determine_pid.return_value = 123 + mock_create_paths.return_value = [Mock(spec=ModelPath)] + mock_load_template.return_value = "loaded_template" + mock_engine_client = AsyncMock() + mock_init_engine.return_value = mock_engine_client + mock_chat_handler = Mock() + mock_create_handlers.return_value = mock_chat_handler + + # 模拟全局llm_engine存在的情况 + mock_llm_engine = Mock() + mock_llm_engine.engine = Mock() + mock_llm_engine.engine.data_processor = None + + with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", mock_llm_engine): + result = await setup_engine_and_handlers(mock_args) + + # 验证调用链 + mock_determine_pid.assert_called_once() + mock_logger.info.assert_called_with("Process ID: 123") + self.assertEqual(mock_args.tokenizer, "test-model") # 验证tokenizer被设置 + mock_create_paths.assert_called_with(mock_args) + mock_load_template.assert_called_with("template_name", "test-model") + mock_init_engine.assert_called_with(mock_args, 123) + mock_create_handlers.assert_called_with( + mock_args, mock_engine_client, mock_create_paths.return_value, "loaded_template", 123 + ) + + # 验证数据处理器被更新 + self.assertEqual(mock_llm_engine.engine.data_processor, mock_engine_client.data_processor) + + self.assertEqual(result, (mock_engine_client, mock_chat_handler)) + + @patch("fastdeploy.entrypoints.openai.run_batch.determine_process_id") + @patch("fastdeploy.entrypoints.openai.run_batch.create_model_paths") + @patch("fastdeploy.entrypoints.openai.run_batch.load_chat_template") + @patch("fastdeploy.entrypoints.openai.run_batch.initialize_engine_client") + @patch("fastdeploy.entrypoints.openai.run_batch.create_serving_handlers") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_setup_engine_and_handlers_no_llm_engine( + self, + mock_logger, + mock_create_handlers, + mock_init_engine, + mock_load_template, + mock_create_paths, + mock_determine_pid, + ): + """测试设置引擎和处理器(没有全局llm_engine的情况)""" + mock_args = Mock() + mock_args.tokenizer = None + mock_args.model = "test-model" + mock_args.chat_template = "template_name" + + # 设置mock返回值 + mock_determine_pid.return_value = 123 + mock_create_paths.return_value = [Mock(spec=ModelPath)] + mock_load_template.return_value = "loaded_template" + mock_engine_client = AsyncMock() + mock_init_engine.return_value = mock_engine_client + mock_chat_handler = Mock() + mock_create_handlers.return_value = mock_chat_handler + + # 模拟全局llm_engine不存在的情况 + with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", None): + result = await setup_engine_and_handlers(mock_args) + + # 验证调用链 + mock_determine_pid.assert_called_once() + mock_logger.info.assert_called_with("Process ID: 123") + self.assertEqual(mock_args.tokenizer, "test-model") + mock_create_paths.assert_called_with(mock_args) + mock_load_template.assert_called_with("template_name", "test-model") + mock_init_engine.assert_called_with(mock_args, 123) + mock_create_handlers.assert_called_with( + mock_args, mock_engine_client, mock_create_paths.return_value, "loaded_template", 123 + ) + + self.assertEqual(result, (mock_engine_client, mock_chat_handler)) + + +class TestBatchProcessing(unittest.TestCase): + """测试批处理相关函数""" + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + @patch("fastdeploy.entrypoints.openai.run_batch.setup_engine_and_handlers") + @patch("fastdeploy.entrypoints.openai.run_batch.read_file") + @patch("fastdeploy.entrypoints.openai.run_batch.run_request") + @patch("fastdeploy.entrypoints.openai.run_batch.make_async_error_request_output") + @patch("fastdeploy.entrypoints.openai.run_batch.write_file") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_run_batch_success( + self, mock_logger, mock_write_file, mock_make_error, mock_run_request, mock_read_file, mock_setup + ): + """测试成功运行批处理""" + # 模拟参数 + mock_args = Mock() + mock_args.input_file = "input.jsonl" + mock_args.output_file = "output.jsonl" + mock_args.output_tmp_dir = "/tmp" + mock_args.max_concurrency = 512 + mock_args.workers = 2 + + # 模拟设置返回 + mock_engine_client = Mock() + mock_chat_handler = Mock() + mock_chat_handler.create_chat_completion = Mock() + mock_setup.return_value = (mock_engine_client, mock_chat_handler) + + # 模拟输入文件内容 + mock_read_file.return_value = ( + '{"url": "/v1/chat/completions", "custom_id": "1"}\n\n{"url": "/v1/chat/completions", "custom_id": "2"}' + ) + + # 模拟请求处理结果 + mock_response1 = Mock(error=None) + mock_response2 = Mock(error=None) + + # 模拟异步操作 + future1 = asyncio.Future() + future1.set_result(mock_response1) + future2 = asyncio.Future() + future2.set_result(mock_response2) + + mock_run_request.side_effect = [future1, future2] + + mock_make_error.return_value = asyncio.Future() + mock_make_error.return_value.set_result(Mock()) + + await run_batch(mock_args) + + # 验证日志记录 + mock_logger.info.assert_any_call("concurrency: 512, workers: 2, max_concurrency: 256") + mock_logger.info.assert_any_call("Reading batch from input.jsonl...") + mock_logger.info.assert_any_call("Batch processing completed: 2 success, 0 errors") + + # 验证文件写入 + mock_write_file.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.run_batch.setup_engine_and_handlers") + @patch("fastdeploy.entrypoints.openai.run_batch.read_file") + @patch("fastdeploy.entrypoints.openai.run_batch.make_async_error_request_output") + @patch("fastdeploy.entrypoints.openai.run_batch.write_file") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_run_batch_unsupported_endpoint( + self, mock_logger, mock_write_file, mock_make_error, mock_read_file, mock_setup + ): + """测试不支持的端点""" + mock_args = Mock() + mock_args.input_file = "input.jsonl" + mock_args.output_file = "output.jsonl" + mock_args.output_tmp_dir = "/tmp" + mock_args.max_concurrency = 512 + mock_args.workers = 1 + + mock_setup.return_value = (Mock(), Mock()) + + # 模拟不支持的URL + mock_read_file.return_value = '{"url": "/v1/unsupported", "custom_id": "1"}' + + mock_make_error.return_value = asyncio.Future() + mock_make_error.return_value.set_result(Mock()) + + await run_batch(mock_args) + + # 验证错误处理被调用 + mock_make_error.assert_called_once() + mock_logger.info.assert_any_call("Batch processing completed: 0 success, 1 errors") + + @patch("fastdeploy.entrypoints.openai.run_batch.setup_engine_and_handlers") + @patch("fastdeploy.entrypoints.openai.run_batch.read_file") + @patch("fastdeploy.entrypoints.openai.run_batch.make_async_error_request_output") + @patch("fastdeploy.entrypoints.openai.run_batch.write_file") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_run_batch_no_chat_handler_for_chat_completions( + self, mock_logger, mock_write_file, mock_make_error, mock_read_file, mock_setup + ): + """测试chat_handler为None时处理chat请求""" + mock_args = Mock() + mock_args.input_file = "input.jsonl" + mock_args.output_file = "output.jsonl" + mock_args.output_tmp_dir = "/tmp" + mock_args.max_concurrency = 512 + mock_args.workers = 1 + + # 返回None作为chat_handler + mock_setup.return_value = (Mock(), None) + + mock_read_file.return_value = '{"url": "/v1/chat/completions", "custom_id": "1"}' + + mock_make_error.return_value = asyncio.Future() + mock_error_output = Mock() + mock_make_error.return_value.set_result(mock_error_output) + + await run_batch(mock_args) + + # 验证错误处理被调用 + mock_make_error.assert_called_once_with( + unittest.mock.ANY, error_msg="The model does not support Chat Completions API" + ) + mock_logger.info.assert_any_call("Batch processing completed: 0 success, 1 errors") + + @patch("fastdeploy.entrypoints.openai.run_batch.retrive_model_from_server") + @patch("fastdeploy.entrypoints.openai.run_batch.ToolParserManager") + @patch("fastdeploy.entrypoints.openai.run_batch.init_engine") + @patch("fastdeploy.entrypoints.openai.run_batch.run_batch") + @patch("fastdeploy.entrypoints.openai.run_batch.cleanup_resources") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_main_success( + self, mock_logger, mock_cleanup, mock_run_batch, mock_init_engine, mock_tool_parser, mock_retrieve_model + ): + """测试主函数成功执行""" + mock_args = Mock() + mock_args.workers = None + mock_args.max_num_seqs = 64 + mock_args.model = "test-model" + mock_args.revision = "main" + mock_args.tool_parser_plugin = None + + mock_retrieve_model.return_value = "retrieved-model" + mock_init_engine.return_value = True + + await main(mock_args) + + # 验证参数处理 + self.assertEqual(mock_args.workers, 2) + self.assertEqual(mock_args.model, "retrieved-model") + mock_retrieve_model.assert_called_with("test-model", "main") + mock_init_engine.assert_called_with(mock_args) + mock_run_batch.assert_called_with(mock_args) + mock_cleanup.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.run_batch.retrive_model_from_server") + @patch("fastdeploy.entrypoints.openai.run_batch.ToolParserManager") + @patch("fastdeploy.entrypoints.openai.run_batch.init_engine") + @patch("fastdeploy.entrypoints.openai.run_batch.run_batch") + @patch("fastdeploy.entrypoints.openai.run_batch.cleanup_resources") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_main_with_tool_parser_plugin( + self, mock_logger, mock_cleanup, mock_run_batch, mock_init_engine, mock_tool_parser, mock_retrieve_model + ): + """测试主函数使用tool_parser_plugin""" + mock_args = Mock() + mock_args.workers = 1 + mock_args.max_num_seqs = 32 + mock_args.model = "test-model" + mock_args.revision = "main" + mock_args.tool_parser_plugin = "test_plugin" + + mock_retrieve_model.return_value = "retrieved-model" + mock_init_engine.return_value = True + + await main(mock_args) + + # 验证工具解析器插件被导入 + mock_tool_parser.import_tool_parser.assert_called_once_with("test_plugin") + mock_init_engine.assert_called_with(mock_args) + mock_run_batch.assert_called_with(mock_args) + mock_cleanup.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.run_batch.retrive_model_from_server") + @patch("fastdeploy.entrypoints.openai.run_batch.init_engine") + @patch("fastdeploy.entrypoints.openai.run_batch.cleanup_resources") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_main_init_engine_fails(self, mock_logger, mock_cleanup, mock_init_engine, mock_retrieve_model): + """测试初始化引擎失败的情况""" + mock_args = Mock() + mock_args.workers = None + mock_args.max_num_seqs = 64 + mock_args.model = "test-model" + mock_args.revision = "main" + mock_args.tool_parser_plugin = None + + mock_retrieve_model.return_value = "retrieved-model" + mock_init_engine.return_value = False # 初始化失败 + + await main(mock_args) + + # 验证没有运行批处理 + mock_init_engine.assert_called_with(mock_args) + mock_cleanup.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_cleanup_resources_success(self, mock_logger): + """测试资源清理成功""" + # 模拟全局变量 + with ( + patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", None), + patch("fastdeploy.entrypoints.openai.run_batch.engine_client", None), + ): + await cleanup_resources() + + # 验证日志记录 + mock_logger.error.assert_not_called() + + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_cleanup_resources_with_errors(self, mock_logger): + """测试资源清理时出现错误""" + # 模拟有问题的引擎和客户端 + mock_engine = Mock() + mock_engine._exit_sub_services = Mock(side_effect=Exception("Engine error")) + + mock_client = Mock() + mock_client.zmq_client = Mock() + mock_client.zmq_client.close = Mock(side_effect=Exception("ZMQ error")) + mock_client.connection_manager = AsyncMock() + mock_client.connection_manager.close = AsyncMock(side_effect=Exception("Connection error")) + + with ( + patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", mock_engine), + patch("fastdeploy.entrypoints.openai.run_batch.engine_client", mock_client), + ): + await cleanup_resources() + + # 验证错误被记录但不会抛出 + self.assertEqual(mock_logger.error.call_count, 3) + + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_cleanup_resources_partial_errors(self, mock_logger): + """测试资源清理时部分组件出错""" + # 模拟只有引擎有问题的情况 + mock_engine = Mock() + mock_engine._exit_sub_services = Mock(side_effect=Exception("Engine error")) + + with ( + patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", mock_engine), + patch("fastdeploy.entrypoints.openai.run_batch.engine_client", None), + ): + await cleanup_resources() + + # 验证只有引擎错误被记录 + mock_logger.error.assert_called_once() + mock_logger.error.assert_called_with("Error stopping engine: Engine error") + + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + @patch("gc.collect") + async def test_cleanup_resources_with_gc(self, mock_gc, mock_logger): + """测试资源清理包括垃圾回收""" + # 模拟有引擎和客户端的情况 + mock_engine = Mock() + mock_engine._exit_sub_services = Mock() + + mock_client = Mock() + mock_client.zmq_client = Mock() + mock_client.zmq_client.close = Mock() + mock_client.connection_manager = AsyncMock() + mock_client.connection_manager.close = AsyncMock() + + with ( + patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", mock_engine), + patch("fastdeploy.entrypoints.openai.run_batch.engine_client", mock_client), + ): + await cleanup_resources() + + # 验证垃圾回收被调用 + mock_gc.assert_called_once() + mock_logger.error.assert_not_called() + + +class TestRunRequest(unittest.TestCase): + """测试run_request函数""" + + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + @patch("fastdeploy.entrypoints.openai.run_batch.random_uuid") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_run_request_success_chat_completion(self, mock_logger, mock_random_uuid): + """测试成功返回ChatCompletionResponse的情况""" + mock_random_uuid.side_effect = ["id1", "req1"] + + # 模拟成功的响应 + mock_response = Mock(spec=ChatCompletionResponse) + mock_engine = AsyncMock(return_value=mock_response) + mock_request = Mock() + mock_request.custom_id = "test-id" + mock_request.body = "test-body" + mock_tracker = Mock() + mock_semaphore = AsyncMock() + + result = await run_request(mock_engine, mock_request, mock_tracker, mock_semaphore) + + # 验证结果 + self.assertEqual(result.custom_id, "test-id") + self.assertEqual(result.response.status_code, 200) + self.assertEqual(result.response.body, mock_response) + self.assertIsNone(result.error) + mock_tracker.completed.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.run_batch.random_uuid") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_run_request_error_response(self, mock_logger, mock_random_uuid): + """测试返回ErrorResponse的情况""" + mock_random_uuid.side_effect = ["id2", "req2"] + + # 模拟错误响应 + mock_error = Mock(spec=ErrorResponse) + mock_engine = AsyncMock(return_value=mock_error) + mock_request = Mock() + mock_request.custom_id = "error-id" + mock_tracker = Mock() + mock_semaphore = AsyncMock() + + result = await run_request(mock_engine, mock_request, mock_tracker, mock_semaphore) + + # 验证错误结果 + self.assertEqual(result.response.status_code, 400) + self.assertEqual(result.error, mock_error) + mock_tracker.completed.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.run_batch.make_error_request_output") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_run_request_stream_mode_error(self, mock_logger, mock_make_error): + """测试流模式错误情况""" + # 模拟非ChatCompletionResponse和ErrorResponse的响应 + mock_engine = AsyncMock(return_value="invalid_response") + mock_request = Mock() + mock_tracker = Mock() + mock_semaphore = AsyncMock() + mock_error_output = Mock() + mock_make_error.return_value = mock_error_output + + result = await run_request(mock_engine, mock_request, mock_tracker, mock_semaphore) + + # 验证调用了错误处理函数 + mock_make_error.assert_called_once_with(mock_request, "Request must not be sent in stream mode") + self.assertEqual(result, mock_error_output) + mock_tracker.completed.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.run_batch.make_error_request_output") + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + async def test_run_request_exception(self, mock_logger, mock_make_error): + """测试异常情况""" + # 模拟抛出异常 + mock_engine = AsyncMock(side_effect=Exception("Test error")) + mock_request = Mock() + mock_request.custom_id = "exception-id" + mock_tracker = Mock() + mock_semaphore = AsyncMock() + mock_error_output = Mock() + mock_make_error.return_value = mock_error_output + + result = await run_request(mock_engine, mock_request, mock_tracker, mock_semaphore) + + # 验证错误日志和错误处理 + mock_logger.error.assert_called_once() + mock_make_error.assert_called_once_with(mock_request, "Request processing failed: Test error") + self.assertEqual(result, mock_error_output) + mock_tracker.completed.assert_called_once() + + +class TestDetermineProcessId(unittest.TestCase): + """测试determine_process_id函数""" + + @patch("multiprocessing.current_process") + @patch("os.getppid") + @patch("os.getpid") + def test_determine_process_id_main_process(self, mock_getpid, mock_getppid, mock_current_process): + """测试主进程情况""" + mock_current_process.return_value.name = "MainProcess" + mock_getpid.return_value = 123 + + result = determine_process_id() + + self.assertEqual(result, 123) + mock_getpid.assert_called_once() + mock_getppid.assert_not_called() + + @patch("multiprocessing.current_process") + @patch("os.getppid") + @patch("os.getpid") + def test_determine_process_id_child_process(self, mock_getpid, mock_getppid, mock_current_process): + """测试子进程情况""" + mock_current_process.return_value.name = "Process-1" + mock_getppid.return_value = 456 + + determine_process_id() + + mock_getpid.assert_called_once() + + +class TestCreateModelPaths(unittest.TestCase): + """测试create_model_paths函数""" + + def test_create_model_paths_with_served_model_name(self): + """测试提供served_model_name的情况""" + mock_args = Mock() + mock_args.served_model_name = "custom-model-name" + mock_args.model = "path/to/model" + + result = create_model_paths(mock_args) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "custom-model-name") + self.assertEqual(result[0].model_path, "path/to/model") + self.assertTrue(result[0].verification) + + def test_create_model_paths_without_served_model_name(self): + """测试不提供served_model_name的情况""" + mock_args = Mock() + mock_args.served_model_name = None + mock_args.model = "path/to/model" + + result = create_model_paths(mock_args) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "path/to/model") + self.assertEqual(result[0].model_path, "path/to/model") + self.assertFalse(result[0].verification) + + +class TestErrorRequestOutput(unittest.TestCase): + """测试错误请求输出生成函数""" + + def setUp(self): + # 设置异步测试循环 + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + @patch("fastdeploy.entrypoints.openai.run_batch.random_uuid") + def test_make_error_request_output_basic(self, mock_random_uuid): + """测试基本功能""" + mock_random_uuid.side_effect = ["req123", "batch456"] + + mock_request = Mock() + mock_request.custom_id = "test-id" + + result = make_error_request_output(mock_request, "Test error") + + # 验证基本属性 + self.assertEqual(result.id, "fastdeploy-req123") + self.assertEqual(result.custom_id, "test-id") + self.assertEqual(result.error, "Test error") + self.assertEqual(result.response.status_code, HTTPStatus.BAD_REQUEST) + self.assertEqual(result.response.request_id, "fastdeploy-batch-batch456") + + @patch("fastdeploy.entrypoints.openai.run_batch.make_error_request_output") + async def test_make_async_error_request_output(self, mock_make_error): + """测试异步版本""" + expected_output = Mock() + mock_make_error.return_value = expected_output + + mock_request = Mock() + mock_request.custom_id = "async-test" + + result = await make_async_error_request_output(mock_request, "Async error") + + self.assertEqual(result, expected_output) + mock_make_error.assert_called_once_with(mock_request, "Async error") + + +class TestFileOperations(unittest.TestCase): + """测试文件操作相关函数""" + + def setUp(self): + # 设置异步测试循环 + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + @patch("aiohttp.ClientSession") + async def test_read_file_http(self, mock_session): + """测试从HTTP URL读取文件""" + # 模拟响应 + mock_resp = AsyncMock() + mock_resp.text = AsyncMock(return_value="HTTP content") + mock_session.return_value.__aenter__.return_value.get.return_value.__aenter__.return_value = mock_resp + + result = await read_file("https://example.com/file.txt") + + self.assertEqual(result, "HTTP content") + mock_session.assert_called_once() + + def create_batch_outputs_from_jsonl(self, jsonl_text): + """从 JSONL 文本创建 BatchRequestOutput 对象列表""" + batch_outputs = [] + lines = jsonl_text.strip().split("\n") + + for line in lines: + if line.strip(): + data = json.loads(line) + + # 解析 response 部分 + response_data = data["response"] + body_data = response_data["body"] + + # 创建 ChatMessage 对象 + message_data = body_data["choices"][0]["message"] + chat_message = ChatMessage( + role=message_data["role"], + content=message_data["content"], + multimodal_content=message_data["multimodal_content"], + reasoning_content=message_data["reasoning_content"], + tool_calls=message_data["tool_calls"], + prompt_token_ids=message_data["prompt_token_ids"], + completion_token_ids=message_data["completion_token_ids"], + text_after_process=message_data["text_after_process"], + raw_prediction=message_data["raw_prediction"], + prompt_tokens=message_data["prompt_tokens"], + completion_tokens=message_data["completion_tokens"], + ) + + # 创建 ChatCompletionResponseChoice 对象 + choice_data = body_data["choices"][0] + choice = ChatCompletionResponseChoice( + index=choice_data["index"], + message=chat_message, + logprobs=choice_data["logprobs"], + finish_reason=choice_data["finish_reason"], + ) + + # 创建 UsageInfo 对象 + usage_data = body_data["usage"] + usage_info = UsageInfo( + prompt_tokens=usage_data["prompt_tokens"], + total_tokens=usage_data["total_tokens"], + completion_tokens=usage_data["completion_tokens"], + prompt_tokens_details=usage_data.get("prompt_tokens_details"), + ) + + # 创建 ChatCompletionResponse 对象 + chat_completion_response = ChatCompletionResponse( + id=body_data["id"], + object=body_data["object"], + created=body_data["created"], + model=body_data["model"], + choices=[choice], + usage=usage_info, + ) + + # 创建 BatchResponseData 对象 + batch_response_data = BatchResponseData( + status_code=response_data["status_code"], + request_id=response_data["request_id"], + body=chat_completion_response, + ) + + # 创建 BatchRequestOutput 对象 + batch_output = BatchRequestOutput( + id=data["id"], custom_id=data["custom_id"], response=batch_response_data, error=data["error"] + ) + batch_outputs.append(batch_output) + + return batch_outputs + + def test_write_local_file_basic(self): + """测试基础功能:写入文件并验证内容""" + # 创建测试数据 + batch_outputs = self.create_batch_outputs_from_jsonl(BATCH_RESPONSE) + + # 创建临时文件 + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as temp_file: + temp_path = temp_file.name + + try: + # 异步调用被测函数 + async def run_test(): + await write_local_file(temp_path, batch_outputs) + + self.loop.run_until_complete(run_test()) + + # 验证文件存在 + self.assertTrue(os.path.exists(temp_path)) + + # 验证文件不为空 + self.assertGreater(os.path.getsize(temp_path), 0) + + # 读取并验证文件内容 + with open(temp_path, "r", encoding="utf-8") as f: + written_lines = f.read().strip().split("\n") + + # 验证行数匹配 + self.assertEqual(len(written_lines), 2) + + # 验证每行都是有效的 JSON + for i, line in enumerate(written_lines): + data = json.loads(line) + self.assertIn("id", data) + self.assertIn("custom_id", data) + self.assertIn("response", data) + self.assertIn("error", data) + + # 验证关键字段 + self.assertEqual(data["custom_id"], f"req-0000{i+1}") + self.assertEqual(data["response"]["status_code"], 200) + self.assertIn("body", data["response"]) + self.assertIn("choices", data["response"]["body"]) + + print("✓ 基础功能测试通过") + + finally: + # 清理临时文件 + if os.path.exists(temp_path): + os.unlink(temp_path) + + def test_write_local_file_content_integrity(self): + """测试内容完整性:验证写入的内容与原始数据一致""" + # 创建测试数据 + batch_outputs = self.create_batch_outputs_from_jsonl(BATCH_RESPONSE) + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as temp_file: + temp_path = temp_file.name + + try: + # 异步调用被测函数 + async def run_test(): + await write_local_file(temp_path, batch_outputs) + + self.loop.run_until_complete(run_test()) + + # 读取写入的文件内容 + with open(temp_path, "r", encoding="utf-8") as f: + written_content = f.read().strip() + + # 解析原始数据 + original_lines = BATCH_RESPONSE.strip().split("\n") + written_lines = written_content.split("\n") + + # 验证行数一致 + self.assertEqual(len(original_lines), len(written_lines)) + + # 验证每行的关键字段一致 + for i, (orig_line, written_line) in enumerate(zip(original_lines, written_lines)): + orig_data = json.loads(orig_line) + written_data = json.loads(written_line) + + # 比较关键标识字段 + self.assertEqual(orig_data["id"], written_data["id"]) + self.assertEqual(orig_data["custom_id"], written_data["custom_id"]) + self.assertEqual(orig_data["response"]["status_code"], written_data["response"]["status_code"]) + + # 比较响应内容 + orig_content = orig_data["response"]["body"]["choices"][0]["message"]["content"] + written_content = written_data["response"]["body"]["choices"][0]["message"]["content"] + # 内容应该一致 + self.assertEqual(orig_content, written_content) + + print("✓ 内容完整性测试通过") + + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + def test_write_local_file_empty_list(self): + """测试空列表处理""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as temp_file: + temp_path = temp_file.name + + try: + # 异步调用函数写入空列表 + async def run_test(): + await write_local_file(temp_path, []) + + self.loop.run_until_complete(run_test()) + + # 验证文件存在但为空 + self.assertTrue(os.path.exists(temp_path)) + + with open(temp_path, "r", encoding="utf-8") as f: + content = f.read() + + self.assertEqual(content, "") + print("✓ 空列表处理测试通过") + + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + @patch("builtins.open", new_callable=mock_open, read_data="Local content") + async def test_read_file_local(self, mock_file): + """测试从本地文件读取""" + result = await read_file("/local/path/file.txt") + + self.assertEqual(result, "Local content") + mock_file.assert_called_once_with("/local/path/file.txt", encoding="utf-8") + + @patch("builtins.open", new_callable=mock_open) + async def test_write_local_file(self, mock_file): + """测试写入本地文件""" + # 创建模拟的batch outputs + mock_outputs = [ + Mock(spec=BatchRequestOutput, model_dump_json=Mock(return_value='{"id": 1}')), + Mock(spec=BatchRequestOutput, model_dump_json=Mock(return_value='{"id": 2}')), + ] + + await write_local_file("/output/path.json", mock_outputs) + + mock_file.assert_called_once_with("/output/path.json", "w", encoding="utf-8") + + # 检查写入调用 + handle = mock_file() + expected_calls = [unittest.mock.call.write('{"id": 1}\n'), unittest.mock.call.write('{"id": 2}\n')] + handle.write.assert_has_calls(expected_calls) + + @patch("aiohttp.ClientSession") + async def test_upload_data_success(self, mock_session): + """测试成功上传数据""" + mock_resp = Mock(status=200, text=Mock(return_value="OK")) + mock_session.return_value.__aenter__.return_value.put.return_value.__aenter__.return_value = mock_resp + + # 测试从文件上传 + with patch("builtins.open", mock_open(read_data=b"file content")): + await upload_data("https://example.com/upload", "/path/to/file", from_file=True) + + # 测试直接上传数据 + await upload_data("https://example.com/upload", "raw data", from_file=False) + + self.assertEqual(mock_session.call_count, 2) + + @patch("aiohttp.ClientSession") + @patch("asyncio.sleep", new_callable=AsyncMock) + async def test_upload_data_retry(self, mock_sleep, mock_session): + """测试上传失败重试逻辑""" + # 模拟前两次失败,第三次成功 + mock_resp_fail = Mock(status=500, text=Mock(return_value="Server Error")) + mock_resp_success = Mock(status=200, text=Mock(return_value="OK")) + + mock_session.return_value.__aenter__.return_value.put.side_effect = [ + Exception("First failure"), + mock_resp_fail, + mock_resp_success, + ] + + # 这次应该成功,经过两次重试 + with patch("builtins.open", mock_open(read_data=b"content")): + await upload_data("https://example.com/upload", "/path/to/file", from_file=True) + + # 检查重试次数 + self.assertEqual(mock_sleep.call_count, 2) + self.assertEqual(mock_session.return_value.__aenter__.return_value.put.call_count, 3) + + @patch("aiohttp.ClientSession") + async def test_upload_data_failure(self, mock_session): + """测试上传最终失败""" + mock_session.return_value.__aenter__.return_value.put.side_effect = Exception("Persistent failure") + + with patch("builtins.open", mock_open(read_data=b"content")): + with self.assertRaises(Exception) as context: + await upload_data("https://example.com/upload", "/path/to/file", from_file=True) + + self.assertIn("Failed to upload data", str(context.exception)) + + @patch("fastdeploy.entrypoints.openai.run_batch.upload_data") + @patch("fastdeploy.entrypoints.openai.run_batch.write_local_file") + async def test_write_file_http_with_buffer(self, mock_write_local, mock_upload): + """测试HTTP输出写入到内存缓冲区""" + mock_outputs = [Mock(spec=BatchRequestOutput)] + + await write_file("https://example.com/output", mock_outputs, output_tmp_dir=None) + + # 应该调用upload_data,而不是write_local_file + mock_upload.assert_called_once() + mock_write_local.assert_not_called() + + @patch("fastdeploy.entrypoints.openai.run_batch.upload_data") + @patch("tempfile.NamedTemporaryFile") + @patch("fastdeploy.entrypoints.openai.run_batch.write_local_file") + async def test_write_file_http_with_tempfile(self, mock_write_local, mock_tempfile, mock_upload): + """测试HTTP输出写入到临时文件""" + # 模拟临时文件 + mock_file = Mock() + mock_file.name = "/tmp/tempfile.json" + mock_tempfile.return_value.__enter__.return_value = mock_file + + mock_outputs = [Mock(spec=BatchRequestOutput)] + + await write_file("https://example.com/output", mock_outputs, output_tmp_dir="/tmp") + + mock_tempfile.assert_called_once() + mock_write_local.assert_called_once_with(mock_file.name, mock_outputs) + mock_upload.assert_called_once_with("https://example.com/output", mock_file.name, from_file=True) + + @patch("fastdeploy.entrypoints.openai.run_batch.write_local_file") + async def test_write_file_local(self, mock_write_local): + """测试本地文件输出""" + mock_outputs = [Mock(spec=BatchRequestOutput)] + + await write_file("/local/output.json", mock_outputs, output_tmp_dir="/tmp") + + mock_write_local.assert_called_once_with("/local/output.json", mock_outputs) + + +class TestUtilityFunctions(unittest.TestCase): + """测试工具函数""" + + def test_random_uuid(self): + """测试生成随机UUID""" + uuid1 = random_uuid() + uuid2 = random_uuid() + + self.assertEqual(len(uuid1), 32) + self.assertTrue(all(c in "0123456789abcdef" for c in uuid1)) + + self.assertNotEqual(uuid1, uuid2) + + +class TestBatchProgressTracker(unittest.TestCase): + + def test_submitted_increments_total(self): + tracker = BatchProgressTracker() + self.assertEqual(tracker._total, 0) + tracker.submitted() + self.assertEqual(tracker._total, 1) + tracker.submitted() + self.assertEqual(tracker._total, 2) + + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + def test_completed_increments_completed_and_logs(self, mock_logger): + tracker = BatchProgressTracker() + tracker._total = 20 + + # 调用 10 次 -> 应该触发一次日志 (log_interval=2) + for _ in range(10): + tracker.completed() + + self.assertEqual(tracker._completed, 10) + mock_logger.info.assert_called() # 至少被调用一次 + args, _ = mock_logger.info.call_args + self.assertIn("Progress: 10/20", args[0]) + + @patch("fastdeploy.entrypoints.openai.run_batch.tqdm") + def test_completed_updates_pbar(self, mock_tqdm): + mock_pbar = MagicMock() + mock_tqdm.return_value = mock_pbar + + tracker = BatchProgressTracker() + tracker._total = 5 + tracker.pbar() # 初始化 pbar + + tracker.completed() + mock_pbar.update.assert_called_once() + + @patch("fastdeploy.entrypoints.openai.run_batch.tqdm") + def test_pbar_returns_tqdm(self, mock_tqdm): + mock_pbar = MagicMock(spec=tqdm) + mock_tqdm.return_value = mock_pbar + + tracker = BatchProgressTracker() + tracker._total = 3 + result = tracker.pbar() + + self.assertIs(result, mock_pbar) + mock_tqdm.assert_called_once_with( + total=3, + unit="req", + desc="Running batch", + mininterval=10, + bar_format=_BAR_FORMAT, + ) + + +class TestBatchProgressTrackerExtended(unittest.TestCase): + """扩展的BatchProgressTracker测试""" + + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + def test_completed_with_pbar_no_log(self, mock_logger): + """测试有进度条时的completed方法,不触发日志记录""" + tracker = BatchProgressTracker() + tracker._total = 100 # 设置较大的总数,使得第一次完成不会触发日志 + tracker._pbar = Mock() + + tracker.completed() # 完成1个,1/100=1%,不会触发日志记录 + + tracker._pbar.update.assert_called_once() + mock_logger.info.assert_not_called() # 不应该记录日志 + + @patch("fastdeploy.entrypoints.openai.run_batch.console_logger") + def test_completed_log_interval(self, mock_logger): + """测试日志间隔""" + tracker = BatchProgressTracker() + tracker._total = 100 + tracker._last_log_count = 0 + + # 触发日志记录(每10个记录一次) + for i in range(1, 21): + tracker.completed() + if i % 10 == 0: + mock_logger.info.assert_called_with(f"Progress: {i}/100 requests completed") + + +class TestFastDeployBatch(unittest.TestCase): + """测试 FastDeploy 批处理功能的 unittest 测试类""" + + def setUp(self): + """每个测试方法执行前的准备工作""" + self.model_path = "baidu/ERNIE-4.5-0.3B-PT" + self.base_command = ["fastdeploy", "run-batch"] + self.run_batch_command = ["python", "fastdeploy/entrypoints/openai/run_batch.py"] + + def run_fastdeploy_command(self, input_content, port=None): + """运行 FastDeploy 命令的辅助方法""" + if port is None: + port = "1231" + + with tempfile.NamedTemporaryFile("w") as input_file, tempfile.NamedTemporaryFile("r") as output_file: + + input_file.write(input_content) + input_file.flush() + + param = [ + "-i", + input_file.name, + "-o", + output_file.name, + "--model", + self.model_path, + "--cache-queue-port", + port, + "--tensor-parallel-size", + "1", + "--quantization", + "wint4", + "--max-model-len", + "4192", + "--max-num-seqs", + "64", + "--load-choices", + "default_v1", + "--engine-worker-queue-port", + "3672", + ] + + # command = self.base_command + param + run_batch_command = self.run_batch_command + param + + proc = subprocess.Popen(run_batch_command) + proc.communicate() + return_code = proc.wait() + + # 读取输出文件内容 + output_file.seek(0) + contents = output_file.read() + + return return_code, contents, proc + + def test_completions(self): + """测试正常的批量chat请求""" + return_code, contents, proc = self.run_fastdeploy_command(INPUT_BATCH, port="2235") + + self.assertEqual(return_code, 0, f"进程返回非零码: {return_code}, 进程信息: {proc}") + + # 验证每行输出都符合 OpenAI API 格式 + lines = contents.strip().split("\n") + for line in lines: + if line: # 跳过空行 + # 验证应该抛出异常如果 schema 错误 + try: + BatchRequestOutput.model_validate_json(line) + except Exception as e: + self.fail(f"输出格式验证失败: {e}\n行内容: {line}") + + def test_vaild_input(self): + """测试输入数据格式的正确性""" + return_code, contents, proc = self.run_fastdeploy_command(INVALID_INPUT_BATCH) + + self.assertNotEqual(return_code, 0, f"进程返回非零码: {return_code}, 进程信息: {proc}") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/entrypoints/openai/test_run_batch_proto.py b/tests/entrypoints/openai/test_run_batch_proto.py new file mode 100644 index 000000000..36311be88 --- /dev/null +++ b/tests/entrypoints/openai/test_run_batch_proto.py @@ -0,0 +1,89 @@ +import unittest + +from pydantic import ValidationError + +from fastdeploy.entrypoints.openai.protocol import ( + BatchRequestInput, + BatchRequestOutput, + BatchResponseData, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatMessage, + UsageInfo, +) + + +class TestBatchRequestModels(unittest.TestCase): + + def test_batch_request_input_with_dict_body(self): + body_dict = { + "messages": [{"role": "user", "content": "hi"}], + "model": "default", + } + obj = BatchRequestInput( + custom_id="test", + method="POST", + url="/v1/chat/completions", + body={"messages": [{"role": "user", "content": "hi"}]}, + ) + self.assertIsInstance(obj.body, ChatCompletionRequest) + self.assertEqual(obj.body.model_dump()["messages"], body_dict["messages"]) + self.assertEqual(obj.body.model, "default") + + def test_batch_request_input_with_model_body(self): + body_model = ChatCompletionRequest(messages=[{"role": "user", "content": "Hi"}], model="gpt-test") + obj = BatchRequestInput( + custom_id="456", + method="POST", + url="/v1/chat/completions", + body=body_model, + ) + self.assertIsInstance(obj.body, ChatCompletionRequest) + self.assertEqual(obj.body.model, "gpt-test") + + def test_batch_request_input_with_other_url(self): + obj = BatchRequestInput( + custom_id="789", + method="POST", + url="/v1/other/endpoint", + body={"messages": [{"role": "user", "content": "hi"}]}, + ) + self.assertIsInstance(obj.body, ChatCompletionRequest) + self.assertEqual(obj.body.messages[0]["content"], "hi") + + def test_batch_response_data(self): + usage = UsageInfo(prompt_tokens=1, total_tokens=2, completion_tokens=1) + chat_msg = ChatMessage(role="assistant", content="ok") + choice = ChatCompletionResponseChoice(index=0, message=chat_msg, finish_reason="stop") + + resp = ChatCompletionResponse(id="r1", model="gpt-test", choices=[choice], usage=usage) + + data = BatchResponseData( + status_code=200, + request_id="req-1", + body=resp, + ) + self.assertEqual(data.status_code, 200) + self.assertEqual(data.body.id, "r1") + self.assertEqual(data.body.choices[0].message.content, "ok") + + def test_batch_request_output(self): + response = BatchResponseData(status_code=200, request_id="req-2", body=None) + out = BatchRequestOutput(id="out-1", custom_id="cid-1", response=response, error=None) + self.assertEqual(out.id, "out-1") + self.assertEqual(out.response.request_id, "req-2") + self.assertIsNone(out.error) + + def test_invalid_batch_request_input(self): + with self.assertRaises(ValidationError): + BatchRequestInput( + custom_id="id", + method="POST", + url="/v1/chat/completions", + body={"model": "gpt-test"}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/entrypoints/openai/test_run_batch_subcommand.py b/tests/entrypoints/openai/test_run_batch_subcommand.py new file mode 100644 index 000000000..e3f7db812 --- /dev/null +++ b/tests/entrypoints/openai/test_run_batch_subcommand.py @@ -0,0 +1,143 @@ +""" +Unit tests for RunBatchSubcommand class. +""" + +import argparse +import unittest +from unittest.mock import Mock, patch + + +class TestRunBatchSubcommand(unittest.TestCase): + """Test cases for RunBatchSubcommand class.""" + + def test_name(self): + """Test subcommand name.""" + + # Create a mock class that mimics RunBatchSubcommand + class MockRunBatchSubcommand: + name = "run-batch" + + subcommand = MockRunBatchSubcommand() + self.assertEqual(subcommand.name, "run-batch") + + @patch("builtins.print") + @patch("asyncio.run") + def test_cmd(self, mock_asyncio, mock_print): + """Test cmd method.""" + # Mock the main function + mock_main = Mock() + + # Create a mock cmd function that simulates the real behavior + def mock_cmd(args): + # Simulate importlib.metadata.version call + version = "1.0.0" # Mock version + print("FastDeploy batch processing API version", version) + print(args) + mock_asyncio(mock_main(args)) + + args = argparse.Namespace(input="test.jsonl") + mock_cmd(args) + + # Verify calls + mock_print.assert_any_call("FastDeploy batch processing API version", "1.0.0") + mock_print.assert_any_call(args) + mock_asyncio.assert_called_once() + + def test_subparser_init(self): + """Test subparser initialization.""" + # Mock all the dependencies + mock_subparsers = Mock() + mock_parser = Mock() + mock_subparsers.add_parser.return_value = mock_parser + + # Mock the subparser_init behavior + def mock_subparser_init(subparsers): + parser = subparsers.add_parser( + "run-batch", + help="Run batch prompts and write results to file.", + description=( + "Run batch prompts using FastDeploy's OpenAI-compatible API.\n" + "Supports local or HTTP input/output files." + ), + usage="FastDeploy run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model ", + ) + parser.epilog = "FASTDEPLOY_SUBCMD_PARSER_EPILOG" + return parser + + result = mock_subparser_init(mock_subparsers) + + # Verify the parser was added + mock_subparsers.add_parser.assert_called_once_with( + "run-batch", + help="Run batch prompts and write results to file.", + description=( + "Run batch prompts using FastDeploy's OpenAI-compatible API.\n" + "Supports local or HTTP input/output files." + ), + usage="FastDeploy run-batch -i INPUT.jsonl -o OUTPUT.jsonl --model ", + ) + self.assertEqual(result.epilog, "FASTDEPLOY_SUBCMD_PARSER_EPILOG") + + +class TestCmdInit(unittest.TestCase): + """Test cmd_init function.""" + + def test_cmd_init(self): + """Test cmd_init returns RunBatchSubcommand.""" + + # Mock the cmd_init function behavior + def mock_cmd_init(): + class MockRunBatchSubcommand: + name = "run-batch" + + @staticmethod + def cmd(args): + pass + + def subparser_init(self, subparsers): + pass + + return [MockRunBatchSubcommand()] + + result = mock_cmd_init() + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "run-batch") + self.assertTrue(hasattr(result[0], "cmd")) + self.assertTrue(hasattr(result[0], "subparser_init")) + + +class TestIntegration(unittest.TestCase): + """Integration tests without actual imports.""" + + def test_workflow(self): + """Test the complete workflow with mocks.""" + + # Create mock objects that simulate the real workflow + class MockSubcommand: + name = "run-batch" + + @staticmethod + def cmd(args): + return f"Executed with {args}" + + def subparser_init(self, subparsers): + return "parser_created" + + # Test subcommand creation + subcommand = MockSubcommand() + self.assertEqual(subcommand.name, "run-batch") + + # Test command execution + args = argparse.Namespace(input="test.jsonl", output="result.jsonl") + result = subcommand.cmd(args) + self.assertIn("test.jsonl", str(result)) + + # Test parser initialization + mock_subparsers = Mock() + parser_result = subcommand.subparser_init(mock_subparsers) + self.assertEqual(parser_result, "parser_created") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/utils/test_run_batch_tools.py b/tests/utils/test_run_batch_tools.py new file mode 100644 index 000000000..a257f9528 --- /dev/null +++ b/tests/utils/test_run_batch_tools.py @@ -0,0 +1,107 @@ +import argparse +import unittest +from unittest.mock import MagicMock, patch + +from fastdeploy.utils import ( + FASTDEPLOY_SUBCMD_PARSER_EPILOG, + show_filtered_argument_or_group_from_help, +) + + +class TestHelpFilter(unittest.TestCase): + def setUp(self): + self.parser = argparse.ArgumentParser(prog="fastdeploy", epilog=FASTDEPLOY_SUBCMD_PARSER_EPILOG) + self.subcommand = ["bench"] + self.mock_sys_argv = ["fastdeploy"] + self.subcommand + + # Add test groups and arguments + self.model_group = self.parser.add_argument_group("ModelConfig", "Model configuration parameters") + self.model_group.add_argument("--model-path", help="Path to model") + self.model_group.add_argument("--max-num-seqs", help="Max sequences") + + self.train_group = self.parser.add_argument_group("TrainingConfig", "Training parameters") + self.train_group.add_argument("--epochs", help="Training epochs") + + @patch("sys.argv", ["fastdeploy", "bench", "--help=page"]) + @patch("subprocess.Popen") + def test_page_help(self, mock_popen): + mock_proc = MagicMock() + mock_proc.communicate.return_value = (None, None) + mock_popen.return_value = mock_proc + + # Expect SystemExit with code 0 + with self.assertRaises(SystemExit) as cm: + show_filtered_argument_or_group_from_help(self.parser, self.subcommand) + + self.assertEqual(cm.exception.code, 0) + mock_popen.assert_called_once() + + @patch("sys.argv", ["fastdeploy", "bench", "--help=listgroup"]) + @patch("fastdeploy.utils._output_with_pager") + def test_list_groups(self, mock_output): + # Expect SystemExit with code 0 + with self.assertRaises(SystemExit) as cm: + show_filtered_argument_or_group_from_help(self.parser, self.subcommand) + + self.assertEqual(cm.exception.code, 0) + # Verify that the output function was called + mock_output.assert_called_once() + # Check that the output contains expected groups + output_text = mock_output.call_args[0][0] + self.assertIn("ModelConfig", output_text) + self.assertIn("TrainingConfig", output_text) + + @patch("sys.argv", ["fastdeploy", "bench", "--help=ModelConfig"]) + @patch("fastdeploy.utils._output_with_pager") + def test_group_search(self, mock_output): + # Expect SystemExit with code 0 + with self.assertRaises(SystemExit) as cm: + show_filtered_argument_or_group_from_help(self.parser, self.subcommand) + + self.assertEqual(cm.exception.code, 0) + # Verify that the output function was called + mock_output.assert_called_once() + # Check that the output contains expected content + output_text = mock_output.call_args[0][0] + self.assertIn("ModelConfig", output_text) + self.assertIn("--model-path", output_text) + + @patch("sys.argv", ["fastdeploy", "bench", "--help=max"]) + @patch("fastdeploy.utils._output_with_pager") + def test_arg_search(self, mock_output): + # Expect SystemExit with code 0 + with self.assertRaises(SystemExit) as cm: + show_filtered_argument_or_group_from_help(self.parser, self.subcommand) + + self.assertEqual(cm.exception.code, 0) + # Verify that the output function was called + mock_output.assert_called_once() + # Check that the output contains expected content + output_text = mock_output.call_args[0][0] + self.assertIn("--max-num-seqs", output_text) + self.assertNotIn("--epochs", output_text) + + @patch("sys.argv", ["fastdeploy", "bench", "--help=nonexistent"]) + @patch("builtins.print") + def test_no_match(self, mock_print): + # Expect SystemExit with code 1 (error case) + with self.assertRaises(SystemExit) as cm: + show_filtered_argument_or_group_from_help(self.parser, self.subcommand) + + self.assertEqual(cm.exception.code, 1) + # Check that error message was printed + mock_print.assert_called() + call_args = [call.args[0] for call in mock_print.call_args_list] + self.assertTrue(any("No group or parameter matching" in arg for arg in call_args)) + + @patch("sys.argv", ["fastdeploy", "othercmd"]) + def test_wrong_subcommand(self): + # This should not raise SystemExit, just return normally + try: + show_filtered_argument_or_group_from_help(self.parser, self.subcommand) + except SystemExit: + self.fail("Function should not exit when subcommand doesn't match") + + +if __name__ == "__main__": + unittest.main()