Files
FastDeploy/tests/entrypoints/openai/test_wrap_streaming_generator.py
Echo-Nie 1b1bfab341 [CI] Add unittest (#5328)
* add test_worker_eplb

* remove tesnsor_wise_fp8

* add copyright
2025-12-09 19:19:42 +08:00

158 lines
4.6 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import time
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
# Mock the argument parser and model loading before importing api_server.
# The import at the end of this block executes api_server's module-level code,
# which parses CLI arguments and resolves the model from a server; both are
# patched here so the import succeeds without a real command line or download.
# NOTE: the import MUST stay inside this `with` block — the patches are only
# active while the context manager is open.
with (
    patch("fastdeploy.utils.FlexibleArgumentParser.parse_args") as mock_parse_args,
    patch("fastdeploy.utils.retrive_model_from_server") as mock_retrive_model,
):
    # Minimal stand-in for the parsed CLI namespace that api_server reads at
    # import time; every attribute listed is one the module accesses.
    mock_args = SimpleNamespace(
        workers=1,
        model="test-model",
        revision=None,
        chat_template=None,
        tool_parser_plugin=None,
        max_concurrency=100,  # Add required attribute
        max_num_seqs=100,
        tensor_parallel_size=1,
        data_parallel_size=1,
        enable_expert_parallel=False,
        enable_logprob=False,
        enable_early_stop=False,
        enable_prefix_caching=False,
        enable_chunked_prefill=False,
        max_num_partial_prefills=0,
        max_long_partial_prefills=0,
        long_prefill_token_threshold=0,
        cache_transfer_protocol=None,
        scheduler_name=None,
        scheduler_host=None,
        scheduler_port=None,
        scheduler_db=None,
        scheduler_password=None,
        scheduler_topic=None,
        api_key=None,
    )
    mock_parse_args.return_value = mock_args
    mock_retrive_model.return_value = "test-model"  # Just return the model name without downloading
    from fastdeploy.entrypoints.openai.api_server import wrap_streaming_generator
@pytest.mark.asyncio
async def test_wrap_streaming_generator_normal_flow():
    """The wrapped generator yields every chunk of the source stream, in order."""

    async def fake_stream():
        for piece in ("chunk1", "chunk2"):
            yield piece

    wrapped = wrap_streaming_generator(fake_stream())
    received = [item async for item in wrapped()]
    assert received == ["chunk1", "chunk2"]
@pytest.mark.asyncio
async def test_wrap_streaming_generator_exception_handling():
    """An error raised mid-stream propagates; chunks yielded before it still arrive."""

    async def failing_stream():
        yield "chunk1"
        raise ValueError("test error")

    wrapped = wrap_streaming_generator(failing_stream())
    seen = []
    with pytest.raises(ValueError, match="test error"):
        async for item in wrapped():
            seen.append(item)
    # The chunk produced before the failure must have been delivered.
    assert seen == ["chunk1"]
@pytest.mark.asyncio
async def test_wrap_streaming_generator_semaphore_release():
    """The global connection semaphore is released once the stream is exhausted."""
    fake_semaphore = MagicMock()

    async def one_chunk():
        yield "chunk"

    # Swap in a mock for the module-level connection_semaphore so the
    # release call can be observed.
    with patch("fastdeploy.entrypoints.openai.api_server.connection_semaphore", fake_semaphore):
        wrapped = wrap_streaming_generator(one_chunk())
        async for _ in wrapped():
            pass
        fake_semaphore.release.assert_called_once()
@pytest.mark.asyncio
async def test_wrap_streaming_generator_span_recording():
    """First-chunk and last-chunk events are recorded on an active tracing span."""
    span = MagicMock()
    span.is_recording.return_value = True

    async def two_chunks():
        yield "chunk1"
        yield "chunk2"

    # Make the wrapper see our mock as the current OpenTelemetry span.
    with patch("opentelemetry.trace.get_current_span", return_value=span):
        wrapped = wrap_streaming_generator(two_chunks())
        async for _ in wrapped():
            pass
    # pytest.approx absorbs up to 1s of drift between yield time and assertion.
    span.add_event.assert_any_call("first_chunk", {"time": pytest.approx(time.time(), abs=1)})
    span.add_event.assert_any_call(
        "last_chunk", {"time": pytest.approx(time.time(), abs=1), "total_chunk": 2}
    )
@pytest.mark.asyncio
async def test_wrap_streaming_generator_no_span():
    """Streaming proceeds unchanged when no tracing span is active."""

    async def one_chunk():
        yield "chunk"

    # Simulate the absence of any OpenTelemetry span.
    with patch("opentelemetry.trace.get_current_span", return_value=None):
        wrapped = wrap_streaming_generator(one_chunk())
        received = [item async for item in wrapped()]
    assert received == ["chunk"]