Files
FastDeploy/tests/entrypoints/openai/test_wrap_streaming_generator.py
kxz2002 87911b7cf1 [Feature] Enable FastDeploy to support adding the “--api-key” authentication parameter. (#4806)
* add api key initial commit

* add unit test

* modify unit test

* move middleware to a single file and add unit tests
2025-11-08 18:24:02 +08:00

142 lines
4.0 KiB
Python

import time
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
# Stub out argument parsing and model retrieval first: api_server is only
# imported while both patches below are active.
mock_args = SimpleNamespace(
    workers=1,
    model="test-model",
    revision=None,
    chat_template=None,
    tool_parser_plugin=None,
    max_concurrency=100,  # required attribute
    max_num_seqs=100,
    tensor_parallel_size=1,
    data_parallel_size=1,
    enable_expert_parallel=False,
    enable_logprob=False,
    enable_early_stop=False,
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
    max_num_partial_prefills=0,
    max_long_partial_prefills=0,
    long_prefill_token_threshold=0,
    cache_transfer_protocol=None,
    scheduler_name=None,
    scheduler_host=None,
    scheduler_port=None,
    scheduler_db=None,
    scheduler_password=None,
    scheduler_topic=None,
    api_key=None,
)

with (
    patch("fastdeploy.utils.FlexibleArgumentParser.parse_args") as mock_parse_args,
    patch("fastdeploy.utils.retrive_model_from_server") as mock_retrive_model,
):
    # Argument parsing yields the canned namespace built above.
    mock_parse_args.return_value = mock_args
    # Hand the model name straight back so nothing is downloaded.
    mock_retrive_model.return_value = "test-model"

    from fastdeploy.entrypoints.openai.api_server import wrap_streaming_generator
@pytest.mark.asyncio
async def test_wrap_streaming_generator_normal_flow():
    """Chunks from the source generator come out of the wrapper unchanged and in order."""

    async def source():
        yield "chunk1"
        yield "chunk2"

    # wrap_streaming_generator returns a callable; calling it gives the async iterator.
    stream_factory = wrap_streaming_generator(source())
    received = [item async for item in stream_factory()]

    assert received == ["chunk1", "chunk2"]
@pytest.mark.asyncio
async def test_wrap_streaming_generator_exception_handling():
    """An error raised mid-stream reaches the consumer after the earlier chunk was delivered."""

    async def failing_source():
        yield "chunk1"
        raise ValueError("test error")

    stream_factory = wrap_streaming_generator(failing_source())
    received = []

    with pytest.raises(ValueError, match="test error"):
        async for item in stream_factory():
            received.append(item)

    # Checked after the raises-block so the assertion actually executes:
    # only the chunk yielded before the failure should have been seen.
    assert received == ["chunk1"]
@pytest.mark.asyncio
async def test_wrap_streaming_generator_semaphore_release():
    """Draining the wrapped stream calls release() on the connection semaphore exactly once."""
    fake_semaphore = MagicMock()

    async def source():
        yield "chunk"

    # Swap the module-level connection_semaphore in api_server for the mock.
    semaphore_target = "fastdeploy.entrypoints.openai.api_server.connection_semaphore"
    with patch(semaphore_target, fake_semaphore):
        stream_factory = wrap_streaming_generator(source())
        async for _ in stream_factory():
            pass

        fake_semaphore.release.assert_called_once()
@pytest.mark.asyncio
async def test_wrap_streaming_generator_span_recording():
    """Check span events and chunk pass-through when the current span is recording.

    With ``opentelemetry.trace.get_current_span`` patched to return a mock span
    whose ``is_recording()`` is True, the wrapper is expected to:

    * still deliver every chunk from the source generator, in order;
    * call ``add_event("first_chunk", {"time": ...})``;
    * call ``add_event("last_chunk", {"time": ..., "total_chunk": 2})``.

    The ``time`` values are matched against the current wall-clock time with a
    one-second tolerance.
    """
    mock_span = MagicMock()
    mock_span.is_recording.return_value = True

    async def mock_generator():
        yield "chunk1"
        yield "chunk2"

    # Patch trace.get_current_span so the wrapper sees the recording mock span.
    with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
        wrapped = wrap_streaming_generator(mock_generator())
        chunks = [chunk async for chunk in wrapped()]

    # Recording span events must not alter the streamed payload.
    assert chunks == ["chunk1", "chunk2"]

    # One matcher serves both events: each timestamp only needs to be within
    # a second of "now".
    recent = pytest.approx(time.time(), abs=1)
    mock_span.add_event.assert_any_call("first_chunk", {"time": recent})
    mock_span.add_event.assert_any_call("last_chunk", {"time": recent, "total_chunk": 2})
@pytest.mark.asyncio
async def test_wrap_streaming_generator_no_span():
    """The wrapper still streams chunks when get_current_span returns None."""

    async def source():
        yield "chunk"

    # Make trace.get_current_span report that no span is active.
    with patch("opentelemetry.trace.get_current_span", return_value=None):
        stream_factory = wrap_streaming_generator(source())
        received = [item async for item in stream_factory()]

    assert received == ["chunk"]