# Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00).
# Last commit: add API key initial commit; add unit test; modify unit test;
# move middleware to a single file and add unit tests.
# File: Python, 142 lines, 4.0 KiB.
import time
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

# Stand-in for the parsed command-line arguments. api_server parses its CLI
# arguments when it is imported, so the namespace must carry every attribute
# that import reads; a missing one would surface as an AttributeError.
mock_args = SimpleNamespace(
    workers=1,
    model="test-model",
    revision=None,
    chat_template=None,
    tool_parser_plugin=None,
    max_concurrency=100,  # required attribute
    max_num_seqs=100,
    tensor_parallel_size=1,
    data_parallel_size=1,
    enable_expert_parallel=False,
    enable_logprob=False,
    enable_early_stop=False,
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
    max_num_partial_prefills=0,
    max_long_partial_prefills=0,
    long_prefill_token_threshold=0,
    cache_transfer_protocol=None,
    scheduler_name=None,
    scheduler_host=None,
    scheduler_port=None,
    scheduler_db=None,
    scheduler_password=None,
    scheduler_topic=None,
    api_key=None,
)

# Stub out argument parsing and model resolution only while api_server is
# being imported. The model resolver simply echoes the model name so nothing
# is downloaded. Note: "retrive" is spelled exactly as the patched attribute
# name and must not be "corrected".
with (
    patch(
        "fastdeploy.utils.FlexibleArgumentParser.parse_args",
        return_value=mock_args,
    ) as mock_parse_args,
    patch(
        "fastdeploy.utils.retrive_model_from_server",
        return_value="test-model",
    ) as mock_retrive_model,
):
    from fastdeploy.entrypoints.openai.api_server import wrap_streaming_generator


@pytest.mark.asyncio
async def test_wrap_streaming_generator_normal_flow():
    """Chunks from the source generator pass through the wrapper unchanged and in order."""

    async def source():
        for piece in ("chunk1", "chunk2"):
            yield piece

    stream = wrap_streaming_generator(source())
    received = [piece async for piece in stream()]

    assert received == ["chunk1", "chunk2"]


@pytest.mark.asyncio
async def test_wrap_streaming_generator_exception_handling():
    """An error raised mid-stream reaches the consumer after earlier chunks were delivered."""

    async def failing_source():
        yield "chunk1"
        raise ValueError("test error")

    stream = wrap_streaming_generator(failing_source())

    received = []
    with pytest.raises(ValueError, match="test error"):
        async for piece in stream():
            received.append(piece)

    # Only the chunk produced before the failure should have been delivered.
    assert received == ["chunk1"]


@pytest.mark.asyncio
async def test_wrap_streaming_generator_semaphore_release():
    """The module-level ``connection_semaphore`` is released exactly once once the stream is drained."""
    fake_semaphore = MagicMock()

    async def single_chunk_source():
        yield "chunk"

    # Swap the global semaphore in api_server for a mock we can inspect.
    target = "fastdeploy.entrypoints.openai.api_server.connection_semaphore"
    with patch(target, fake_semaphore):
        stream = wrap_streaming_generator(single_chunk_source())
        async for _chunk in stream():
            pass

    fake_semaphore.release.assert_called_once()


@pytest.mark.asyncio
async def test_wrap_streaming_generator_span_recording():
    """Test span recording functionality.

    With a recording span active, the wrapper is expected to emit a
    ``first_chunk`` event and a ``last_chunk`` event carrying the total chunk
    count, while still passing every chunk through unchanged.
    """
    mock_span = MagicMock()
    mock_span.is_recording.return_value = True

    async def mock_generator():
        yield "chunk1"
        yield "chunk2"

    # Patch trace.get_current_span so the wrapper sees our recording span.
    with patch("opentelemetry.trace.get_current_span", return_value=mock_span):
        wrapped = wrap_streaming_generator(mock_generator())

        chunks = []
        async for chunk in wrapped():
            chunks.append(chunk)

    # Span instrumentation must not alter the stream itself; previously the
    # collected chunks were never checked.
    assert chunks == ["chunk1", "chunk2"]

    # Both events were recorded during the short stream above, so a single
    # reference timestamp is enough for the 1-second tolerance check.
    now = time.time()
    mock_span.add_event.assert_any_call("first_chunk", {"time": pytest.approx(now, abs=1)})
    mock_span.add_event.assert_any_call(
        "last_chunk", {"time": pytest.approx(now, abs=1), "total_chunk": 2}
    )


@pytest.mark.asyncio
async def test_wrap_streaming_generator_no_span():
    """Streaming still works when ``trace.get_current_span`` returns ``None``."""

    async def single_chunk_source():
        yield "chunk"

    with patch("opentelemetry.trace.get_current_span", return_value=None):
        stream = wrap_streaming_generator(single_chunk_source())
        received = [piece async for piece in stream()]

    assert received == ["chunk"]