Files
FastDeploy/tests/entrypoints/openai/test_run_batch.py
xiaolei373 55124f8491 Add cli run batch (#4237)
* feat(log):add_request_and_response_log

* [cli] add run batch cli

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
2025-09-26 14:27:25 +08:00

1334 lines
56 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import os
import subprocess
import tempfile
import unittest
from http import HTTPStatus
from unittest.mock import AsyncMock, MagicMock, Mock, mock_open, patch
from tqdm import tqdm
from fastdeploy.entrypoints.openai.protocol import (
BatchRequestOutput,
BatchResponseData,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatMessage,
ErrorResponse,
UsageInfo,
)
from fastdeploy.entrypoints.openai.run_batch import (
_BAR_FORMAT,
BatchProgressTracker,
ModelPath,
cleanup_resources,
create_model_paths,
create_serving_handlers,
determine_process_id,
init_engine,
initialize_engine_client,
main,
make_async_error_request_output,
make_error_request_output,
parse_args,
random_uuid,
read_file,
run_batch,
run_request,
setup_engine_and_handlers,
upload_data,
write_file,
write_local_file,
)
INPUT_BATCH = """
{"custom_id": "req-00001", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "user", "content": "Can you write a short poem? (id=1)"}], "temperature": 0.7, "max_tokens": 200}}
{"custom_id": "req-00002", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "user", "content": "What can you do? (id=2)"}], "temperature": 0.7, "max_tokens": 200}}
{"custom_id": "req-00003", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "user", "content": "Hello, who are you? (id=3)"}], "temperature": 0.7, "max_tokens": 200}}
"""
INVALID_INPUT_BATCH = """
{"invalid_field": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 1000}}
"""
BATCH_RESPONSE = """
{"id":"fastdeploy-7fcc30e2e4334fca806c4d01ee7ac4ab","custom_id":"req-00001","response":{"status_code":200,"request_id":"fastdeploy-batch-5f4017beded84b15aa3a8b0f1fce154c","body":{"id":"chatcmpl-33b09ae5-a8f1-40ad-9110-efa2b381eac9","object":"chat.completion","created":1758698637,"model":"/root/paddlejob/zhaolei36/ernie-4_5-0_3b-bf16-paddle","choices":[{"index":0,"message":{"role":"assistant","content":"In a sunlit meadow where dreams bloom,\\nA gentle breeze carries the breeze,\\nThe leaves rustle like ancient letters,\\nAnd in the sky, a song of hope and love.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":19,"total_tokens":60,"completion_tokens":41,"prompt_tokens_details":{"cached_tokens":0}}}},"error":null}
{"id":"fastdeploy-bf549849df2145598ae1758ba260f784","custom_id":"req-00002","response":{"status_code":200,"request_id":"fastdeploy-batch-81223f12fdc345efbfe85114ced10a1d","body":{"id":"chatcmpl-9479e36c-1542-45ff-b364-1dc6d34be9e7","object":"chat.completion","created":1758698637,"model":"/root/paddlejob/zhaolei36/ernie-4_5-0_3b-bf16-paddle","choices":[{"index":0,"message":{"role":"assistant","content":"Based on the given text, here are some possible actions you can take:\\n\\n1. **Read the question**: To understand what you can do, you can read the question (id=2) and analyze its requirements or constraints.\\n2. **Identify the keywords**: Look for specific keywords or phrases that describe what you can do. For example, if the question mentions \\"coding,\\" you can focus on coding skills or platforms.\\n3. **Brainstorm ideas**: You can think creatively about different ways to perform the action. For example, you could brainstorm different methods of communication, data analysis, or problem-solving.\\n4. **Explain your action**: If you have knowledge or skills in a particular area, you can explain how you would use those skills to achieve the desired outcome.\\n5. **Ask for help**: If you need assistance, you can ask for help from a friend, teacher, or mentor.","multimodal_content":null,"reasoning_content":null,"tool_calls":null,"prompt_token_ids":null,"completion_token_ids":null,"text_after_process":null,"raw_prediction":null,"prompt_tokens":null,"completion_tokens":null},"logprobs":null,"finish_reason":"stop"}],"usage":{"prompt_tokens":17,"total_tokens":211,"completion_tokens":194,"prompt_tokens_details":{"cached_tokens":0}}}},"error":null}
"""
class TestArgParser(unittest.TestCase):
"""测试参数解析相关函数"""
@patch("fastdeploy.entrypoints.openai.run_batch.FlexibleArgumentParser")
@patch("fastdeploy.entrypoints.openai.run_batch.EngineArgs")
def test_make_arg_parser(self, mock_engine_args, mock_parser_class):
"""测试make_arg_parser函数"""
from fastdeploy.entrypoints.openai.run_batch import make_arg_parser
mock_parser = Mock()
mock_parser_class.return_value = mock_parser
# 让EngineArgs.add_cli_args返回parser本身
mock_engine_args.add_cli_args.return_value = mock_parser
result = make_arg_parser(mock_parser)
# 验证参数被正确添加
mock_parser.add_argument.assert_any_call("-i", "--input-file", required=True, type=str, help=unittest.mock.ANY)
mock_parser.add_argument.assert_any_call(
"-o", "--output-file", required=True, type=str, help=unittest.mock.ANY
)
mock_parser.add_argument.assert_any_call("--output-tmp-dir", type=str, default=None, help=unittest.mock.ANY)
mock_engine_args.add_cli_args.assert_called_once_with(mock_parser)
# 现在应该返回parser而不是EngineArgs.add_cli_args的返回值
self.assertEqual(result, mock_parser)
@patch("fastdeploy.entrypoints.openai.run_batch.FlexibleArgumentParser")
@patch("fastdeploy.entrypoints.openai.run_batch.make_arg_parser")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
def test_parse_args(self, mock_logger, mock_make_parser, mock_parser_class):
"""测试parse_args函数"""
mock_parser = Mock()
mock_args = Mock()
mock_parser_class.return_value = mock_parser
mock_parser.parse_args.return_value = mock_args
mock_make_parser.return_value = mock_parser
result = parse_args()
mock_parser_class.assert_called_once_with(description="FastDeploy OpenAI-Compatible batch runner.")
mock_make_parser.assert_called_once_with(mock_parser)
mock_parser.parse_args.assert_called_once()
self.assertEqual(result, mock_args)
class TestEngineInitialization(unittest.TestCase):
"""测试引擎初始化相关函数"""
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
@patch("fastdeploy.entrypoints.openai.run_batch.LLMEngine")
@patch("fastdeploy.entrypoints.openai.run_batch.EngineArgs")
@patch("fastdeploy.entrypoints.openai.run_batch.api_server_logger")
@patch("fastdeploy.entrypoints.openai.run_batch.os")
def test_init_engine_success(self, mock_os, mock_logger, mock_engine_args, mock_llm_engine):
"""测试init_engine成功初始化"""
with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", None):
mock_args = Mock()
mock_engine_args.from_cli_args.return_value = Mock()
mock_engine = Mock()
mock_engine.start.return_value = True
mock_llm_engine.from_engine_args.return_value = mock_engine
mock_os.getpid.return_value = 123
result = init_engine(mock_args)
mock_engine_args.from_cli_args.assert_called_with(mock_args)
mock_llm_engine.from_engine_args.assert_called_with(mock_engine_args.from_cli_args.return_value)
mock_engine.start.assert_called_with(api_server_pid=123)
mock_logger.info.assert_called_with("FastDeploy LLM API server starting... 123")
self.assertEqual(result, mock_engine)
@patch("fastdeploy.entrypoints.openai.run_batch.LLMEngine")
@patch("fastdeploy.entrypoints.openai.run_batch.EngineArgs")
@patch("fastdeploy.entrypoints.openai.run_batch.api_server_logger")
def test_init_engine_failure(self, mock_logger, mock_engine_args, mock_llm_engine):
"""测试init_engine初始化失败"""
with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", None):
mock_args = Mock()
mock_engine_args.from_cli_args.return_value = Mock()
mock_engine = Mock()
mock_engine.start.return_value = False
mock_llm_engine.from_engine_args.return_value = mock_engine
result = init_engine(mock_args)
mock_logger.error.assert_called_with("Failed to initialize FastDeploy LLM engine, service exit now!")
self.assertIsNone(result)
@patch("fastdeploy.entrypoints.openai.run_batch.LLMEngine")
def test_init_engine_already_initialized(self, mock_llm_engine):
"""测试init_engine已经初始化的情况"""
existing_engine = Mock()
with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", existing_engine):
mock_args = Mock()
result = init_engine(mock_args)
mock_llm_engine.from_engine_args.assert_not_called()
self.assertEqual(result, existing_engine)
@patch("fastdeploy.entrypoints.openai.run_batch.EngineClient")
async def test_initialize_engine_client(self, mock_engine_client):
"""测试初始化引擎客户端"""
mock_args = Mock()
mock_args.model = "test-model"
mock_args.tokenizer = "test-tokenizer"
mock_args.max_model_len = 1000
mock_args.tensor_parallel_size = 1
mock_args.engine_worker_queue_port = [8000]
mock_args.local_data_parallel_id = 0
mock_args.limit_mm_per_prompt = None
mock_args.mm_processor_kwargs = {}
mock_args.reasoning_parser = None
mock_args.data_parallel_size = 1
mock_args.enable_logprob = False
mock_args.workers = 1
mock_args.tool_call_parser = None
mock_client_instance = AsyncMock()
mock_engine_client.return_value = mock_client_instance
pid = 123
result = await initialize_engine_client(mock_args, pid)
# 验证EngineClient被正确初始化
mock_engine_client.assert_called_once()
mock_client_instance.connection_manager.initialize.assert_called_once()
mock_client_instance.create_zmq_client.assert_called_once_with(model=pid, mode=unittest.mock.ANY)
self.assertEqual(mock_client_instance.pid, pid)
self.assertEqual(result, mock_client_instance)
@patch("fastdeploy.entrypoints.openai.run_batch.OpenAIServingModels")
@patch("fastdeploy.entrypoints.openai.run_batch.OpenAIServingChat")
def test_create_serving_handlers(self, mock_chat_handler, mock_model_handler):
"""测试创建服务处理器"""
mock_args = Mock()
mock_args.max_model_len = 1000
mock_args.ips = "127.0.0.1"
mock_args.max_waiting_time = 60
mock_args.enable_mm_output = False
mock_args.tokenizer_base_url = None
mock_engine_client = Mock()
mock_model_paths = [Mock(spec=ModelPath)]
chat_template = "test_template"
pid = 123
mock_model_instance = Mock()
mock_model_handler.return_value = mock_model_instance
mock_chat_instance = Mock()
mock_chat_handler.return_value = mock_chat_instance
result = create_serving_handlers(mock_args, mock_engine_client, mock_model_paths, chat_template, pid)
# 验证处理器被正确创建
mock_model_handler.assert_called_once_with(mock_model_paths, mock_args.max_model_len, mock_args.ips)
mock_chat_handler.assert_called_once_with(
mock_engine_client,
mock_model_instance,
pid,
mock_args.ips,
mock_args.max_waiting_time,
chat_template,
mock_args.enable_mm_output,
mock_args.tokenizer_base_url,
)
self.assertEqual(result, mock_chat_instance)
@patch("fastdeploy.entrypoints.openai.run_batch.determine_process_id")
@patch("fastdeploy.entrypoints.openai.run_batch.create_model_paths")
@patch("fastdeploy.entrypoints.openai.run_batch.load_chat_template")
@patch("fastdeploy.entrypoints.openai.run_batch.initialize_engine_client")
@patch("fastdeploy.entrypoints.openai.run_batch.create_serving_handlers")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_setup_engine_and_handlers(
self,
mock_logger,
mock_create_handlers,
mock_init_engine,
mock_load_template,
mock_create_paths,
mock_determine_pid,
):
"""测试设置引擎和处理器"""
mock_args = Mock()
mock_args.tokenizer = None
mock_args.model = "test-model"
mock_args.chat_template = "template_name"
# 设置mock返回值
mock_determine_pid.return_value = 123
mock_create_paths.return_value = [Mock(spec=ModelPath)]
mock_load_template.return_value = "loaded_template"
mock_engine_client = AsyncMock()
mock_init_engine.return_value = mock_engine_client
mock_chat_handler = Mock()
mock_create_handlers.return_value = mock_chat_handler
# 模拟全局llm_engine存在的情况
mock_llm_engine = Mock()
mock_llm_engine.engine = Mock()
mock_llm_engine.engine.data_processor = None
with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", mock_llm_engine):
result = await setup_engine_and_handlers(mock_args)
# 验证调用链
mock_determine_pid.assert_called_once()
mock_logger.info.assert_called_with("Process ID: 123")
self.assertEqual(mock_args.tokenizer, "test-model") # 验证tokenizer被设置
mock_create_paths.assert_called_with(mock_args)
mock_load_template.assert_called_with("template_name", "test-model")
mock_init_engine.assert_called_with(mock_args, 123)
mock_create_handlers.assert_called_with(
mock_args, mock_engine_client, mock_create_paths.return_value, "loaded_template", 123
)
# 验证数据处理器被更新
self.assertEqual(mock_llm_engine.engine.data_processor, mock_engine_client.data_processor)
self.assertEqual(result, (mock_engine_client, mock_chat_handler))
@patch("fastdeploy.entrypoints.openai.run_batch.determine_process_id")
@patch("fastdeploy.entrypoints.openai.run_batch.create_model_paths")
@patch("fastdeploy.entrypoints.openai.run_batch.load_chat_template")
@patch("fastdeploy.entrypoints.openai.run_batch.initialize_engine_client")
@patch("fastdeploy.entrypoints.openai.run_batch.create_serving_handlers")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_setup_engine_and_handlers_no_llm_engine(
self,
mock_logger,
mock_create_handlers,
mock_init_engine,
mock_load_template,
mock_create_paths,
mock_determine_pid,
):
"""测试设置引擎和处理器没有全局llm_engine的情况"""
mock_args = Mock()
mock_args.tokenizer = None
mock_args.model = "test-model"
mock_args.chat_template = "template_name"
# 设置mock返回值
mock_determine_pid.return_value = 123
mock_create_paths.return_value = [Mock(spec=ModelPath)]
mock_load_template.return_value = "loaded_template"
mock_engine_client = AsyncMock()
mock_init_engine.return_value = mock_engine_client
mock_chat_handler = Mock()
mock_create_handlers.return_value = mock_chat_handler
# 模拟全局llm_engine不存在的情况
with patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", None):
result = await setup_engine_and_handlers(mock_args)
# 验证调用链
mock_determine_pid.assert_called_once()
mock_logger.info.assert_called_with("Process ID: 123")
self.assertEqual(mock_args.tokenizer, "test-model")
mock_create_paths.assert_called_with(mock_args)
mock_load_template.assert_called_with("template_name", "test-model")
mock_init_engine.assert_called_with(mock_args, 123)
mock_create_handlers.assert_called_with(
mock_args, mock_engine_client, mock_create_paths.return_value, "loaded_template", 123
)
self.assertEqual(result, (mock_engine_client, mock_chat_handler))
class TestBatchProcessing(unittest.TestCase):
"""测试批处理相关函数"""
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
@patch("fastdeploy.entrypoints.openai.run_batch.setup_engine_and_handlers")
@patch("fastdeploy.entrypoints.openai.run_batch.read_file")
@patch("fastdeploy.entrypoints.openai.run_batch.run_request")
@patch("fastdeploy.entrypoints.openai.run_batch.make_async_error_request_output")
@patch("fastdeploy.entrypoints.openai.run_batch.write_file")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_run_batch_success(
self, mock_logger, mock_write_file, mock_make_error, mock_run_request, mock_read_file, mock_setup
):
"""测试成功运行批处理"""
# 模拟参数
mock_args = Mock()
mock_args.input_file = "input.jsonl"
mock_args.output_file = "output.jsonl"
mock_args.output_tmp_dir = "/tmp"
mock_args.max_concurrency = 512
mock_args.workers = 2
# 模拟设置返回
mock_engine_client = Mock()
mock_chat_handler = Mock()
mock_chat_handler.create_chat_completion = Mock()
mock_setup.return_value = (mock_engine_client, mock_chat_handler)
# 模拟输入文件内容
mock_read_file.return_value = (
'{"url": "/v1/chat/completions", "custom_id": "1"}\n\n{"url": "/v1/chat/completions", "custom_id": "2"}'
)
# 模拟请求处理结果
mock_response1 = Mock(error=None)
mock_response2 = Mock(error=None)
# 模拟异步操作
future1 = asyncio.Future()
future1.set_result(mock_response1)
future2 = asyncio.Future()
future2.set_result(mock_response2)
mock_run_request.side_effect = [future1, future2]
mock_make_error.return_value = asyncio.Future()
mock_make_error.return_value.set_result(Mock())
await run_batch(mock_args)
# 验证日志记录
mock_logger.info.assert_any_call("concurrency: 512, workers: 2, max_concurrency: 256")
mock_logger.info.assert_any_call("Reading batch from input.jsonl...")
mock_logger.info.assert_any_call("Batch processing completed: 2 success, 0 errors")
# 验证文件写入
mock_write_file.assert_called_once()
@patch("fastdeploy.entrypoints.openai.run_batch.setup_engine_and_handlers")
@patch("fastdeploy.entrypoints.openai.run_batch.read_file")
@patch("fastdeploy.entrypoints.openai.run_batch.make_async_error_request_output")
@patch("fastdeploy.entrypoints.openai.run_batch.write_file")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_run_batch_unsupported_endpoint(
self, mock_logger, mock_write_file, mock_make_error, mock_read_file, mock_setup
):
"""测试不支持的端点"""
mock_args = Mock()
mock_args.input_file = "input.jsonl"
mock_args.output_file = "output.jsonl"
mock_args.output_tmp_dir = "/tmp"
mock_args.max_concurrency = 512
mock_args.workers = 1
mock_setup.return_value = (Mock(), Mock())
# 模拟不支持的URL
mock_read_file.return_value = '{"url": "/v1/unsupported", "custom_id": "1"}'
mock_make_error.return_value = asyncio.Future()
mock_make_error.return_value.set_result(Mock())
await run_batch(mock_args)
# 验证错误处理被调用
mock_make_error.assert_called_once()
mock_logger.info.assert_any_call("Batch processing completed: 0 success, 1 errors")
@patch("fastdeploy.entrypoints.openai.run_batch.setup_engine_and_handlers")
@patch("fastdeploy.entrypoints.openai.run_batch.read_file")
@patch("fastdeploy.entrypoints.openai.run_batch.make_async_error_request_output")
@patch("fastdeploy.entrypoints.openai.run_batch.write_file")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_run_batch_no_chat_handler_for_chat_completions(
self, mock_logger, mock_write_file, mock_make_error, mock_read_file, mock_setup
):
"""测试chat_handler为None时处理chat请求"""
mock_args = Mock()
mock_args.input_file = "input.jsonl"
mock_args.output_file = "output.jsonl"
mock_args.output_tmp_dir = "/tmp"
mock_args.max_concurrency = 512
mock_args.workers = 1
# 返回None作为chat_handler
mock_setup.return_value = (Mock(), None)
mock_read_file.return_value = '{"url": "/v1/chat/completions", "custom_id": "1"}'
mock_make_error.return_value = asyncio.Future()
mock_error_output = Mock()
mock_make_error.return_value.set_result(mock_error_output)
await run_batch(mock_args)
# 验证错误处理被调用
mock_make_error.assert_called_once_with(
unittest.mock.ANY, error_msg="The model does not support Chat Completions API"
)
mock_logger.info.assert_any_call("Batch processing completed: 0 success, 1 errors")
@patch("fastdeploy.entrypoints.openai.run_batch.retrive_model_from_server")
@patch("fastdeploy.entrypoints.openai.run_batch.ToolParserManager")
@patch("fastdeploy.entrypoints.openai.run_batch.init_engine")
@patch("fastdeploy.entrypoints.openai.run_batch.run_batch")
@patch("fastdeploy.entrypoints.openai.run_batch.cleanup_resources")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_main_success(
self, mock_logger, mock_cleanup, mock_run_batch, mock_init_engine, mock_tool_parser, mock_retrieve_model
):
"""测试主函数成功执行"""
mock_args = Mock()
mock_args.workers = None
mock_args.max_num_seqs = 64
mock_args.model = "test-model"
mock_args.revision = "main"
mock_args.tool_parser_plugin = None
mock_retrieve_model.return_value = "retrieved-model"
mock_init_engine.return_value = True
await main(mock_args)
# 验证参数处理
self.assertEqual(mock_args.workers, 2)
self.assertEqual(mock_args.model, "retrieved-model")
mock_retrieve_model.assert_called_with("test-model", "main")
mock_init_engine.assert_called_with(mock_args)
mock_run_batch.assert_called_with(mock_args)
mock_cleanup.assert_called_once()
@patch("fastdeploy.entrypoints.openai.run_batch.retrive_model_from_server")
@patch("fastdeploy.entrypoints.openai.run_batch.ToolParserManager")
@patch("fastdeploy.entrypoints.openai.run_batch.init_engine")
@patch("fastdeploy.entrypoints.openai.run_batch.run_batch")
@patch("fastdeploy.entrypoints.openai.run_batch.cleanup_resources")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_main_with_tool_parser_plugin(
self, mock_logger, mock_cleanup, mock_run_batch, mock_init_engine, mock_tool_parser, mock_retrieve_model
):
"""测试主函数使用tool_parser_plugin"""
mock_args = Mock()
mock_args.workers = 1
mock_args.max_num_seqs = 32
mock_args.model = "test-model"
mock_args.revision = "main"
mock_args.tool_parser_plugin = "test_plugin"
mock_retrieve_model.return_value = "retrieved-model"
mock_init_engine.return_value = True
await main(mock_args)
# 验证工具解析器插件被导入
mock_tool_parser.import_tool_parser.assert_called_once_with("test_plugin")
mock_init_engine.assert_called_with(mock_args)
mock_run_batch.assert_called_with(mock_args)
mock_cleanup.assert_called_once()
@patch("fastdeploy.entrypoints.openai.run_batch.retrive_model_from_server")
@patch("fastdeploy.entrypoints.openai.run_batch.init_engine")
@patch("fastdeploy.entrypoints.openai.run_batch.cleanup_resources")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_main_init_engine_fails(self, mock_logger, mock_cleanup, mock_init_engine, mock_retrieve_model):
"""测试初始化引擎失败的情况"""
mock_args = Mock()
mock_args.workers = None
mock_args.max_num_seqs = 64
mock_args.model = "test-model"
mock_args.revision = "main"
mock_args.tool_parser_plugin = None
mock_retrieve_model.return_value = "retrieved-model"
mock_init_engine.return_value = False # 初始化失败
await main(mock_args)
# 验证没有运行批处理
mock_init_engine.assert_called_with(mock_args)
mock_cleanup.assert_called_once()
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_cleanup_resources_success(self, mock_logger):
"""测试资源清理成功"""
# 模拟全局变量
with (
patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", None),
patch("fastdeploy.entrypoints.openai.run_batch.engine_client", None),
):
await cleanup_resources()
# 验证日志记录
mock_logger.error.assert_not_called()
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_cleanup_resources_with_errors(self, mock_logger):
"""测试资源清理时出现错误"""
# 模拟有问题的引擎和客户端
mock_engine = Mock()
mock_engine._exit_sub_services = Mock(side_effect=Exception("Engine error"))
mock_client = Mock()
mock_client.zmq_client = Mock()
mock_client.zmq_client.close = Mock(side_effect=Exception("ZMQ error"))
mock_client.connection_manager = AsyncMock()
mock_client.connection_manager.close = AsyncMock(side_effect=Exception("Connection error"))
with (
patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", mock_engine),
patch("fastdeploy.entrypoints.openai.run_batch.engine_client", mock_client),
):
await cleanup_resources()
# 验证错误被记录但不会抛出
self.assertEqual(mock_logger.error.call_count, 3)
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_cleanup_resources_partial_errors(self, mock_logger):
"""测试资源清理时部分组件出错"""
# 模拟只有引擎有问题的情况
mock_engine = Mock()
mock_engine._exit_sub_services = Mock(side_effect=Exception("Engine error"))
with (
patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", mock_engine),
patch("fastdeploy.entrypoints.openai.run_batch.engine_client", None),
):
await cleanup_resources()
# 验证只有引擎错误被记录
mock_logger.error.assert_called_once()
mock_logger.error.assert_called_with("Error stopping engine: Engine error")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
@patch("gc.collect")
async def test_cleanup_resources_with_gc(self, mock_gc, mock_logger):
"""测试资源清理包括垃圾回收"""
# 模拟有引擎和客户端的情况
mock_engine = Mock()
mock_engine._exit_sub_services = Mock()
mock_client = Mock()
mock_client.zmq_client = Mock()
mock_client.zmq_client.close = Mock()
mock_client.connection_manager = AsyncMock()
mock_client.connection_manager.close = AsyncMock()
with (
patch("fastdeploy.entrypoints.openai.run_batch.llm_engine", mock_engine),
patch("fastdeploy.entrypoints.openai.run_batch.engine_client", mock_client),
):
await cleanup_resources()
# 验证垃圾回收被调用
mock_gc.assert_called_once()
mock_logger.error.assert_not_called()
class TestRunRequest(unittest.TestCase):
"""测试run_request函数"""
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
@patch("fastdeploy.entrypoints.openai.run_batch.random_uuid")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_run_request_success_chat_completion(self, mock_logger, mock_random_uuid):
"""测试成功返回ChatCompletionResponse的情况"""
mock_random_uuid.side_effect = ["id1", "req1"]
# 模拟成功的响应
mock_response = Mock(spec=ChatCompletionResponse)
mock_engine = AsyncMock(return_value=mock_response)
mock_request = Mock()
mock_request.custom_id = "test-id"
mock_request.body = "test-body"
mock_tracker = Mock()
mock_semaphore = AsyncMock()
result = await run_request(mock_engine, mock_request, mock_tracker, mock_semaphore)
# 验证结果
self.assertEqual(result.custom_id, "test-id")
self.assertEqual(result.response.status_code, 200)
self.assertEqual(result.response.body, mock_response)
self.assertIsNone(result.error)
mock_tracker.completed.assert_called_once()
@patch("fastdeploy.entrypoints.openai.run_batch.random_uuid")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_run_request_error_response(self, mock_logger, mock_random_uuid):
"""测试返回ErrorResponse的情况"""
mock_random_uuid.side_effect = ["id2", "req2"]
# 模拟错误响应
mock_error = Mock(spec=ErrorResponse)
mock_engine = AsyncMock(return_value=mock_error)
mock_request = Mock()
mock_request.custom_id = "error-id"
mock_tracker = Mock()
mock_semaphore = AsyncMock()
result = await run_request(mock_engine, mock_request, mock_tracker, mock_semaphore)
# 验证错误结果
self.assertEqual(result.response.status_code, 400)
self.assertEqual(result.error, mock_error)
mock_tracker.completed.assert_called_once()
@patch("fastdeploy.entrypoints.openai.run_batch.make_error_request_output")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_run_request_stream_mode_error(self, mock_logger, mock_make_error):
"""测试流模式错误情况"""
# 模拟非ChatCompletionResponse和ErrorResponse的响应
mock_engine = AsyncMock(return_value="invalid_response")
mock_request = Mock()
mock_tracker = Mock()
mock_semaphore = AsyncMock()
mock_error_output = Mock()
mock_make_error.return_value = mock_error_output
result = await run_request(mock_engine, mock_request, mock_tracker, mock_semaphore)
# 验证调用了错误处理函数
mock_make_error.assert_called_once_with(mock_request, "Request must not be sent in stream mode")
self.assertEqual(result, mock_error_output)
mock_tracker.completed.assert_called_once()
@patch("fastdeploy.entrypoints.openai.run_batch.make_error_request_output")
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
async def test_run_request_exception(self, mock_logger, mock_make_error):
"""测试异常情况"""
# 模拟抛出异常
mock_engine = AsyncMock(side_effect=Exception("Test error"))
mock_request = Mock()
mock_request.custom_id = "exception-id"
mock_tracker = Mock()
mock_semaphore = AsyncMock()
mock_error_output = Mock()
mock_make_error.return_value = mock_error_output
result = await run_request(mock_engine, mock_request, mock_tracker, mock_semaphore)
# 验证错误日志和错误处理
mock_logger.error.assert_called_once()
mock_make_error.assert_called_once_with(mock_request, "Request processing failed: Test error")
self.assertEqual(result, mock_error_output)
mock_tracker.completed.assert_called_once()
class TestDetermineProcessId(unittest.TestCase):
"""测试determine_process_id函数"""
@patch("multiprocessing.current_process")
@patch("os.getppid")
@patch("os.getpid")
def test_determine_process_id_main_process(self, mock_getpid, mock_getppid, mock_current_process):
"""测试主进程情况"""
mock_current_process.return_value.name = "MainProcess"
mock_getpid.return_value = 123
result = determine_process_id()
self.assertEqual(result, 123)
mock_getpid.assert_called_once()
mock_getppid.assert_not_called()
@patch("multiprocessing.current_process")
@patch("os.getppid")
@patch("os.getpid")
def test_determine_process_id_child_process(self, mock_getpid, mock_getppid, mock_current_process):
"""测试子进程情况"""
mock_current_process.return_value.name = "Process-1"
mock_getppid.return_value = 456
determine_process_id()
mock_getpid.assert_called_once()
class TestCreateModelPaths(unittest.TestCase):
"""测试create_model_paths函数"""
def test_create_model_paths_with_served_model_name(self):
"""测试提供served_model_name的情况"""
mock_args = Mock()
mock_args.served_model_name = "custom-model-name"
mock_args.model = "path/to/model"
result = create_model_paths(mock_args)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "custom-model-name")
self.assertEqual(result[0].model_path, "path/to/model")
self.assertTrue(result[0].verification)
def test_create_model_paths_without_served_model_name(self):
"""测试不提供served_model_name的情况"""
mock_args = Mock()
mock_args.served_model_name = None
mock_args.model = "path/to/model"
result = create_model_paths(mock_args)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "path/to/model")
self.assertEqual(result[0].model_path, "path/to/model")
self.assertFalse(result[0].verification)
class TestErrorRequestOutput(unittest.TestCase):
"""测试错误请求输出生成函数"""
def setUp(self):
# 设置异步测试循环
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
@patch("fastdeploy.entrypoints.openai.run_batch.random_uuid")
def test_make_error_request_output_basic(self, mock_random_uuid):
"""测试基本功能"""
mock_random_uuid.side_effect = ["req123", "batch456"]
mock_request = Mock()
mock_request.custom_id = "test-id"
result = make_error_request_output(mock_request, "Test error")
# 验证基本属性
self.assertEqual(result.id, "fastdeploy-req123")
self.assertEqual(result.custom_id, "test-id")
self.assertEqual(result.error, "Test error")
self.assertEqual(result.response.status_code, HTTPStatus.BAD_REQUEST)
self.assertEqual(result.response.request_id, "fastdeploy-batch-batch456")
@patch("fastdeploy.entrypoints.openai.run_batch.make_error_request_output")
async def test_make_async_error_request_output(self, mock_make_error):
"""测试异步版本"""
expected_output = Mock()
mock_make_error.return_value = expected_output
mock_request = Mock()
mock_request.custom_id = "async-test"
result = await make_async_error_request_output(mock_request, "Async error")
self.assertEqual(result, expected_output)
mock_make_error.assert_called_once_with(mock_request, "Async error")
class TestFileOperations(unittest.TestCase):
"""测试文件操作相关函数"""
def setUp(self):
# 设置异步测试循环
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
@patch("aiohttp.ClientSession")
async def test_read_file_http(self, mock_session):
"""测试从HTTP URL读取文件"""
# 模拟响应
mock_resp = AsyncMock()
mock_resp.text = AsyncMock(return_value="HTTP content")
mock_session.return_value.__aenter__.return_value.get.return_value.__aenter__.return_value = mock_resp
result = await read_file("https://example.com/file.txt")
self.assertEqual(result, "HTTP content")
mock_session.assert_called_once()
def create_batch_outputs_from_jsonl(self, jsonl_text):
"""从 JSONL 文本创建 BatchRequestOutput 对象列表"""
batch_outputs = []
lines = jsonl_text.strip().split("\n")
for line in lines:
if line.strip():
data = json.loads(line)
# 解析 response 部分
response_data = data["response"]
body_data = response_data["body"]
# 创建 ChatMessage 对象
message_data = body_data["choices"][0]["message"]
chat_message = ChatMessage(
role=message_data["role"],
content=message_data["content"],
multimodal_content=message_data["multimodal_content"],
reasoning_content=message_data["reasoning_content"],
tool_calls=message_data["tool_calls"],
prompt_token_ids=message_data["prompt_token_ids"],
completion_token_ids=message_data["completion_token_ids"],
text_after_process=message_data["text_after_process"],
raw_prediction=message_data["raw_prediction"],
prompt_tokens=message_data["prompt_tokens"],
completion_tokens=message_data["completion_tokens"],
)
# 创建 ChatCompletionResponseChoice 对象
choice_data = body_data["choices"][0]
choice = ChatCompletionResponseChoice(
index=choice_data["index"],
message=chat_message,
logprobs=choice_data["logprobs"],
finish_reason=choice_data["finish_reason"],
)
# 创建 UsageInfo 对象
usage_data = body_data["usage"]
usage_info = UsageInfo(
prompt_tokens=usage_data["prompt_tokens"],
total_tokens=usage_data["total_tokens"],
completion_tokens=usage_data["completion_tokens"],
prompt_tokens_details=usage_data.get("prompt_tokens_details"),
)
# 创建 ChatCompletionResponse 对象
chat_completion_response = ChatCompletionResponse(
id=body_data["id"],
object=body_data["object"],
created=body_data["created"],
model=body_data["model"],
choices=[choice],
usage=usage_info,
)
# 创建 BatchResponseData 对象
batch_response_data = BatchResponseData(
status_code=response_data["status_code"],
request_id=response_data["request_id"],
body=chat_completion_response,
)
# 创建 BatchRequestOutput 对象
batch_output = BatchRequestOutput(
id=data["id"], custom_id=data["custom_id"], response=batch_response_data, error=data["error"]
)
batch_outputs.append(batch_output)
return batch_outputs
def test_write_local_file_basic(self):
"""测试基础功能:写入文件并验证内容"""
# 创建测试数据
batch_outputs = self.create_batch_outputs_from_jsonl(BATCH_RESPONSE)
# 创建临时文件
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as temp_file:
temp_path = temp_file.name
try:
# 异步调用被测函数
async def run_test():
await write_local_file(temp_path, batch_outputs)
self.loop.run_until_complete(run_test())
# 验证文件存在
self.assertTrue(os.path.exists(temp_path))
# 验证文件不为空
self.assertGreater(os.path.getsize(temp_path), 0)
# 读取并验证文件内容
with open(temp_path, "r", encoding="utf-8") as f:
written_lines = f.read().strip().split("\n")
# 验证行数匹配
self.assertEqual(len(written_lines), 2)
# 验证每行都是有效的 JSON
for i, line in enumerate(written_lines):
data = json.loads(line)
self.assertIn("id", data)
self.assertIn("custom_id", data)
self.assertIn("response", data)
self.assertIn("error", data)
# 验证关键字段
self.assertEqual(data["custom_id"], f"req-0000{i+1}")
self.assertEqual(data["response"]["status_code"], 200)
self.assertIn("body", data["response"])
self.assertIn("choices", data["response"]["body"])
print("✓ 基础功能测试通过")
finally:
# 清理临时文件
if os.path.exists(temp_path):
os.unlink(temp_path)
def test_write_local_file_content_integrity(self):
"""测试内容完整性:验证写入的内容与原始数据一致"""
# 创建测试数据
batch_outputs = self.create_batch_outputs_from_jsonl(BATCH_RESPONSE)
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as temp_file:
temp_path = temp_file.name
try:
# 异步调用被测函数
async def run_test():
await write_local_file(temp_path, batch_outputs)
self.loop.run_until_complete(run_test())
# 读取写入的文件内容
with open(temp_path, "r", encoding="utf-8") as f:
written_content = f.read().strip()
# 解析原始数据
original_lines = BATCH_RESPONSE.strip().split("\n")
written_lines = written_content.split("\n")
# 验证行数一致
self.assertEqual(len(original_lines), len(written_lines))
# 验证每行的关键字段一致
for i, (orig_line, written_line) in enumerate(zip(original_lines, written_lines)):
orig_data = json.loads(orig_line)
written_data = json.loads(written_line)
# 比较关键标识字段
self.assertEqual(orig_data["id"], written_data["id"])
self.assertEqual(orig_data["custom_id"], written_data["custom_id"])
self.assertEqual(orig_data["response"]["status_code"], written_data["response"]["status_code"])
# 比较响应内容
orig_content = orig_data["response"]["body"]["choices"][0]["message"]["content"]
written_content = written_data["response"]["body"]["choices"][0]["message"]["content"]
# 内容应该一致
self.assertEqual(orig_content, written_content)
print("✓ 内容完整性测试通过")
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
def test_write_local_file_empty_list(self):
"""测试空列表处理"""
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as temp_file:
temp_path = temp_file.name
try:
# 异步调用函数写入空列表
async def run_test():
await write_local_file(temp_path, [])
self.loop.run_until_complete(run_test())
# 验证文件存在但为空
self.assertTrue(os.path.exists(temp_path))
with open(temp_path, "r", encoding="utf-8") as f:
content = f.read()
self.assertEqual(content, "")
print("✓ 空列表处理测试通过")
finally:
if os.path.exists(temp_path):
os.unlink(temp_path)
@patch("builtins.open", new_callable=mock_open, read_data="Local content")
async def test_read_file_local(self, mock_file):
"""测试从本地文件读取"""
result = await read_file("/local/path/file.txt")
self.assertEqual(result, "Local content")
mock_file.assert_called_once_with("/local/path/file.txt", encoding="utf-8")
@patch("builtins.open", new_callable=mock_open)
async def test_write_local_file(self, mock_file):
"""测试写入本地文件"""
# 创建模拟的batch outputs
mock_outputs = [
Mock(spec=BatchRequestOutput, model_dump_json=Mock(return_value='{"id": 1}')),
Mock(spec=BatchRequestOutput, model_dump_json=Mock(return_value='{"id": 2}')),
]
await write_local_file("/output/path.json", mock_outputs)
mock_file.assert_called_once_with("/output/path.json", "w", encoding="utf-8")
# 检查写入调用
handle = mock_file()
expected_calls = [unittest.mock.call.write('{"id": 1}\n'), unittest.mock.call.write('{"id": 2}\n')]
handle.write.assert_has_calls(expected_calls)
@patch("aiohttp.ClientSession")
async def test_upload_data_success(self, mock_session):
"""测试成功上传数据"""
mock_resp = Mock(status=200, text=Mock(return_value="OK"))
mock_session.return_value.__aenter__.return_value.put.return_value.__aenter__.return_value = mock_resp
# 测试从文件上传
with patch("builtins.open", mock_open(read_data=b"file content")):
await upload_data("https://example.com/upload", "/path/to/file", from_file=True)
# 测试直接上传数据
await upload_data("https://example.com/upload", "raw data", from_file=False)
self.assertEqual(mock_session.call_count, 2)
@patch("aiohttp.ClientSession")
@patch("asyncio.sleep", new_callable=AsyncMock)
async def test_upload_data_retry(self, mock_sleep, mock_session):
"""测试上传失败重试逻辑"""
# 模拟前两次失败,第三次成功
mock_resp_fail = Mock(status=500, text=Mock(return_value="Server Error"))
mock_resp_success = Mock(status=200, text=Mock(return_value="OK"))
mock_session.return_value.__aenter__.return_value.put.side_effect = [
Exception("First failure"),
mock_resp_fail,
mock_resp_success,
]
# 这次应该成功,经过两次重试
with patch("builtins.open", mock_open(read_data=b"content")):
await upload_data("https://example.com/upload", "/path/to/file", from_file=True)
# 检查重试次数
self.assertEqual(mock_sleep.call_count, 2)
self.assertEqual(mock_session.return_value.__aenter__.return_value.put.call_count, 3)
@patch("aiohttp.ClientSession")
async def test_upload_data_failure(self, mock_session):
"""测试上传最终失败"""
mock_session.return_value.__aenter__.return_value.put.side_effect = Exception("Persistent failure")
with patch("builtins.open", mock_open(read_data=b"content")):
with self.assertRaises(Exception) as context:
await upload_data("https://example.com/upload", "/path/to/file", from_file=True)
self.assertIn("Failed to upload data", str(context.exception))
@patch("fastdeploy.entrypoints.openai.run_batch.upload_data")
@patch("fastdeploy.entrypoints.openai.run_batch.write_local_file")
async def test_write_file_http_with_buffer(self, mock_write_local, mock_upload):
"""测试HTTP输出写入到内存缓冲区"""
mock_outputs = [Mock(spec=BatchRequestOutput)]
await write_file("https://example.com/output", mock_outputs, output_tmp_dir=None)
# 应该调用upload_data而不是write_local_file
mock_upload.assert_called_once()
mock_write_local.assert_not_called()
@patch("fastdeploy.entrypoints.openai.run_batch.upload_data")
@patch("tempfile.NamedTemporaryFile")
@patch("fastdeploy.entrypoints.openai.run_batch.write_local_file")
async def test_write_file_http_with_tempfile(self, mock_write_local, mock_tempfile, mock_upload):
"""测试HTTP输出写入到临时文件"""
# 模拟临时文件
mock_file = Mock()
mock_file.name = "/tmp/tempfile.json"
mock_tempfile.return_value.__enter__.return_value = mock_file
mock_outputs = [Mock(spec=BatchRequestOutput)]
await write_file("https://example.com/output", mock_outputs, output_tmp_dir="/tmp")
mock_tempfile.assert_called_once()
mock_write_local.assert_called_once_with(mock_file.name, mock_outputs)
mock_upload.assert_called_once_with("https://example.com/output", mock_file.name, from_file=True)
@patch("fastdeploy.entrypoints.openai.run_batch.write_local_file")
async def test_write_file_local(self, mock_write_local):
"""测试本地文件输出"""
mock_outputs = [Mock(spec=BatchRequestOutput)]
await write_file("/local/output.json", mock_outputs, output_tmp_dir="/tmp")
mock_write_local.assert_called_once_with("/local/output.json", mock_outputs)
class TestUtilityFunctions(unittest.TestCase):
"""测试工具函数"""
def test_random_uuid(self):
"""测试生成随机UUID"""
uuid1 = random_uuid()
uuid2 = random_uuid()
self.assertEqual(len(uuid1), 32)
self.assertTrue(all(c in "0123456789abcdef" for c in uuid1))
self.assertNotEqual(uuid1, uuid2)
class TestBatchProgressTracker(unittest.TestCase):
def test_submitted_increments_total(self):
tracker = BatchProgressTracker()
self.assertEqual(tracker._total, 0)
tracker.submitted()
self.assertEqual(tracker._total, 1)
tracker.submitted()
self.assertEqual(tracker._total, 2)
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
def test_completed_increments_completed_and_logs(self, mock_logger):
tracker = BatchProgressTracker()
tracker._total = 20
# 调用 10 次 -> 应该触发一次日志 (log_interval=2)
for _ in range(10):
tracker.completed()
self.assertEqual(tracker._completed, 10)
mock_logger.info.assert_called() # 至少被调用一次
args, _ = mock_logger.info.call_args
self.assertIn("Progress: 10/20", args[0])
@patch("fastdeploy.entrypoints.openai.run_batch.tqdm")
def test_completed_updates_pbar(self, mock_tqdm):
mock_pbar = MagicMock()
mock_tqdm.return_value = mock_pbar
tracker = BatchProgressTracker()
tracker._total = 5
tracker.pbar() # 初始化 pbar
tracker.completed()
mock_pbar.update.assert_called_once()
@patch("fastdeploy.entrypoints.openai.run_batch.tqdm")
def test_pbar_returns_tqdm(self, mock_tqdm):
mock_pbar = MagicMock(spec=tqdm)
mock_tqdm.return_value = mock_pbar
tracker = BatchProgressTracker()
tracker._total = 3
result = tracker.pbar()
self.assertIs(result, mock_pbar)
mock_tqdm.assert_called_once_with(
total=3,
unit="req",
desc="Running batch",
mininterval=10,
bar_format=_BAR_FORMAT,
)
class TestBatchProgressTrackerExtended(unittest.TestCase):
"""扩展的BatchProgressTracker测试"""
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
def test_completed_with_pbar_no_log(self, mock_logger):
"""测试有进度条时的completed方法不触发日志记录"""
tracker = BatchProgressTracker()
tracker._total = 100 # 设置较大的总数,使得第一次完成不会触发日志
tracker._pbar = Mock()
tracker.completed() # 完成1个1/100=1%,不会触发日志记录
tracker._pbar.update.assert_called_once()
mock_logger.info.assert_not_called() # 不应该记录日志
@patch("fastdeploy.entrypoints.openai.run_batch.console_logger")
def test_completed_log_interval(self, mock_logger):
"""测试日志间隔"""
tracker = BatchProgressTracker()
tracker._total = 100
tracker._last_log_count = 0
# 触发日志记录每10个记录一次
for i in range(1, 21):
tracker.completed()
if i % 10 == 0:
mock_logger.info.assert_called_with(f"Progress: {i}/100 requests completed")
class TestFastDeployBatch(unittest.TestCase):
"""测试 FastDeploy 批处理功能的 unittest 测试类"""
def setUp(self):
"""每个测试方法执行前的准备工作"""
self.model_path = "baidu/ERNIE-4.5-0.3B-PT"
self.base_command = ["fastdeploy", "run-batch"]
self.run_batch_command = ["python", "fastdeploy/entrypoints/openai/run_batch.py"]
def run_fastdeploy_command(self, input_content, port=None):
"""运行 FastDeploy 命令的辅助方法"""
if port is None:
port = "1231"
with tempfile.NamedTemporaryFile("w") as input_file, tempfile.NamedTemporaryFile("r") as output_file:
input_file.write(input_content)
input_file.flush()
param = [
"-i",
input_file.name,
"-o",
output_file.name,
"--model",
self.model_path,
"--cache-queue-port",
port,
"--tensor-parallel-size",
"1",
"--quantization",
"wint4",
"--max-model-len",
"4192",
"--max-num-seqs",
"64",
"--load-choices",
"default_v1",
"--engine-worker-queue-port",
"3672",
]
# command = self.base_command + param
run_batch_command = self.run_batch_command + param
proc = subprocess.Popen(run_batch_command)
proc.communicate()
return_code = proc.wait()
# 读取输出文件内容
output_file.seek(0)
contents = output_file.read()
return return_code, contents, proc
def test_completions(self):
"""测试正常的批量chat请求"""
return_code, contents, proc = self.run_fastdeploy_command(INPUT_BATCH, port="2235")
self.assertEqual(return_code, 0, f"进程返回非零码: {return_code}, 进程信息: {proc}")
# 验证每行输出都符合 OpenAI API 格式
lines = contents.strip().split("\n")
for line in lines:
if line: # 跳过空行
# 验证应该抛出异常如果 schema 错误
try:
BatchRequestOutput.model_validate_json(line)
except Exception as e:
self.fail(f"输出格式验证失败: {e}\n行内容: {line}")
def test_vaild_input(self):
"""测试输入数据格式的正确性"""
return_code, contents, proc = self.run_fastdeploy_command(INVALID_INPUT_BATCH)
self.assertNotEqual(return_code, 0, f"进程返回非零码: {return_code}, 进程信息: {proc}")
if __name__ == "__main__":
unittest.main(verbosity=2)