mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[BugFix] fix max streaming tokens invalid (#3789)
This commit is contained in:
@@ -183,6 +183,8 @@ class OpenAIServingChat:
|
||||
else (request.metadata or {}).get("max_streaming_response_tokens", 1)
|
||||
) # dierctly passed & passed in metadata
|
||||
|
||||
max_streaming_response_tokens = max(1, max_streaming_response_tokens)
|
||||
|
||||
enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
|
||||
if enable_thinking is None:
|
||||
enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
|
||||
@@ -370,11 +372,6 @@ class OpenAIServingChat:
|
||||
api_server_logger.info(f"Chat Streaming response last send: {chunk.model_dump_json()}")
|
||||
choices = []
|
||||
|
||||
if choices:
|
||||
chunk.choices = choices
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
choices = []
|
||||
|
||||
if include_usage:
|
||||
completion_tokens = previous_num_tokens
|
||||
usage = UsageInfo(
|
||||
|
@@ -331,6 +331,7 @@ class OpenAIServingCompletion:
|
||||
if request.max_streaming_response_tokens is not None
|
||||
else (request.suffix or {}).get("max_streaming_response_tokens", 1)
|
||||
) # dierctly passed & passed in suffix
|
||||
max_streaming_response_tokens = max(1, max_streaming_response_tokens)
|
||||
choices = []
|
||||
chunk = CompletionStreamResponse(
|
||||
id=request_id,
|
||||
@@ -461,10 +462,6 @@ class OpenAIServingCompletion:
|
||||
)
|
||||
yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}")
|
||||
if choices:
|
||||
chunk.choices = choices
|
||||
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
|
||||
choices = []
|
||||
|
||||
except Exception as e:
|
||||
api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}")
|
||||
|
279
tests/entrypoints/openai/test_max_streaming_tokens.py
Normal file
279
tests/entrypoints/openai/test_max_streaming_tokens.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import json
|
||||
import unittest
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
)
|
||||
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
|
||||
|
||||
class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self):
|
||||
self.engine_client = Mock()
|
||||
self.engine_client.connection_initialized = False
|
||||
self.engine_client.connection_manager = AsyncMock()
|
||||
self.engine_client.connection_manager.initialize = AsyncMock()
|
||||
self.engine_client.connection_manager.get_connection = AsyncMock()
|
||||
self.engine_client.connection_manager.cleanup_request = AsyncMock()
|
||||
self.engine_client.semaphore = Mock()
|
||||
self.engine_client.semaphore.acquire = AsyncMock()
|
||||
self.engine_client.semaphore.release = Mock()
|
||||
self.engine_client.data_processor = Mock()
|
||||
self.engine_client.is_master = True
|
||||
|
||||
self.chat_serving = OpenAIServingChat(
|
||||
engine_client=self.engine_client,
|
||||
models=None,
|
||||
pid=123,
|
||||
ips=None,
|
||||
max_waiting_time=30,
|
||||
chat_template="default",
|
||||
enable_mm_output=False,
|
||||
tokenizer_base_url=None,
|
||||
)
|
||||
|
||||
self.completion_serving = OpenAIServingCompletion(
|
||||
engine_client=self.engine_client, models=None, pid=123, ips=None, max_waiting_time=30
|
||||
)
|
||||
|
||||
def test_metadata_parameter_setting(self):
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
stream=True,
|
||||
metadata={"max_streaming_response_tokens": 100},
|
||||
)
|
||||
|
||||
max_tokens = (
|
||||
request.max_streaming_response_tokens
|
||||
if request.max_streaming_response_tokens is not None
|
||||
else (request.metadata or {}).get("max_streaming_response_tokens", 1)
|
||||
)
|
||||
|
||||
self.assertEqual(max_tokens, 100)
|
||||
|
||||
def test_default_value(self):
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model", messages=[{"role": "user", "content": "Hello"}], stream=True
|
||||
)
|
||||
|
||||
max_tokens = (
|
||||
request.max_streaming_response_tokens
|
||||
if request.max_streaming_response_tokens is not None
|
||||
else (request.metadata or {}).get("max_streaming_response_tokens", 1)
|
||||
)
|
||||
|
||||
self.assertEqual(max_tokens, 1)
|
||||
|
||||
def test_edge_case_zero_value(self):
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
stream=True,
|
||||
max_streaming_response_tokens=0,
|
||||
)
|
||||
|
||||
max_streaming_response_tokens = (
|
||||
request.max_streaming_response_tokens
|
||||
if request.max_streaming_response_tokens is not None
|
||||
else (request.metadata or {}).get("max_streaming_response_tokens", 1)
|
||||
)
|
||||
max_streaming_response_tokens = max(1, max_streaming_response_tokens)
|
||||
|
||||
self.assertEqual(max_streaming_response_tokens, 1)
|
||||
|
||||
@patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger")
|
||||
@patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor")
|
||||
async def test_integration_with_chat_stream_generator(self, mock_processor_class, mock_logger):
|
||||
response_data = [
|
||||
{
|
||||
"outputs": {"token_ids": [1], "text": "a", "top_logprobs": None},
|
||||
"metrics": {"first_token_time": 0.1, "inference_start_time": 0.1},
|
||||
"finished": False,
|
||||
},
|
||||
{
|
||||
"outputs": {"token_ids": [2], "text": "b", "top_logprobs": None},
|
||||
"metrics": {"arrival_time": 0.2, "first_token_time": None},
|
||||
"finished": False,
|
||||
},
|
||||
{
|
||||
"outputs": {"token_ids": [3], "text": "c", "top_logprobs": None},
|
||||
"metrics": {"arrival_time": 0.3, "first_token_time": None},
|
||||
"finished": False,
|
||||
},
|
||||
{
|
||||
"outputs": {"token_ids": [4], "text": "d", "top_logprobs": None},
|
||||
"metrics": {"arrival_time": 0.4, "first_token_time": None},
|
||||
"finished": False,
|
||||
},
|
||||
{
|
||||
"outputs": {"token_ids": [5], "text": "e", "top_logprobs": None},
|
||||
"metrics": {"arrival_time": 0.5, "first_token_time": None},
|
||||
"finished": False,
|
||||
},
|
||||
{
|
||||
"outputs": {"token_ids": [6], "text": "f", "top_logprobs": None},
|
||||
"metrics": {"arrival_time": 0.6, "first_token_time": None},
|
||||
"finished": False,
|
||||
},
|
||||
{
|
||||
"outputs": {"token_ids": [7], "text": "g", "top_logprobs": None},
|
||||
"metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1},
|
||||
"finished": True,
|
||||
},
|
||||
]
|
||||
|
||||
mock_response_queue = AsyncMock()
|
||||
mock_response_queue.get.side_effect = response_data
|
||||
|
||||
mock_dealer = Mock()
|
||||
mock_dealer.write = Mock()
|
||||
|
||||
# Mock the connection manager call
|
||||
self.engine_client.connection_manager.get_connection = AsyncMock(
|
||||
return_value=(mock_dealer, mock_response_queue)
|
||||
)
|
||||
|
||||
mock_processor_instance = Mock()
|
||||
|
||||
async def mock_process_response_chat_single(response, stream, enable_thinking, include_stop_str_in_output):
|
||||
yield response
|
||||
|
||||
mock_processor_instance.process_response_chat = mock_process_response_chat_single
|
||||
mock_processor_instance.enable_multimodal_content = Mock(return_value=False)
|
||||
mock_processor_class.return_value = mock_processor_instance
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
model="test-model",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
stream=True,
|
||||
max_streaming_response_tokens=3,
|
||||
)
|
||||
|
||||
generator = self.chat_serving.chat_completion_stream_generator(
|
||||
request=request,
|
||||
request_id="test-request-id",
|
||||
model_name="test-model",
|
||||
prompt_token_ids=[1, 2, 3],
|
||||
text_after_process="Hello",
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for chunk in generator:
|
||||
chunks.append(chunk)
|
||||
|
||||
self.assertGreater(len(chunks), 0, "No chucks!")
|
||||
|
||||
parsed_chunks = []
|
||||
for i, chunk_str in enumerate(chunks):
|
||||
if i == 0:
|
||||
continue
|
||||
if chunk_str.startswith("data: ") and chunk_str.endswith("\n\n"):
|
||||
json_part = chunk_str[6:-2]
|
||||
if json_part == "[DONE]":
|
||||
parsed_chunks.append({"type": "done", "raw": chunk_str})
|
||||
break
|
||||
try:
|
||||
chunk_dict = json.loads(json_part)
|
||||
parsed_chunks.append(chunk_dict)
|
||||
except json.JSONDecodeError as e:
|
||||
self.fail(f"Cannot parser {i+1} chunck, JSON: {e}\n origin string: {repr(chunk_str)}")
|
||||
else:
|
||||
self.fail(f"{i+1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}")
|
||||
for chunk_dict in parsed_chunks:
|
||||
choices_list = chunk_dict["choices"]
|
||||
if choices_list[-1].get("finish_reason") is not None:
|
||||
break
|
||||
else:
|
||||
self.assertEqual(len(choices_list), 3, f"Chunk {chunk_dict} should has three choices")
|
||||
|
||||
found_done = any("[DONE]" in chunk for chunk in chunks)
|
||||
self.assertTrue(found_done, "Not Receive '[DONE]'")
|
||||
|
||||
@patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger")
|
||||
async def test_integration_with_completion_stream_generator(self, mock_logger):
|
||||
response_data = [
|
||||
[
|
||||
{
|
||||
"request_id": "test-request-id-0",
|
||||
"outputs": {"token_ids": [1], "text": "a", "top_logprobs": None},
|
||||
"metrics": {"first_token_time": 0.1, "inference_start_time": 0.1},
|
||||
"finished": False,
|
||||
},
|
||||
{
|
||||
"request_id": "test-request-id-0",
|
||||
"outputs": {"token_ids": [2], "text": "b", "top_logprobs": None},
|
||||
"metrics": {"arrival_time": 0.2, "first_token_time": None},
|
||||
"finished": False,
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"request_id": "test-request-id-0",
|
||||
"outputs": {"token_ids": [7], "text": "g", "top_logprobs": None},
|
||||
"metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1},
|
||||
"finished": True,
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
mock_response_queue = AsyncMock()
|
||||
mock_response_queue.get.side_effect = response_data
|
||||
|
||||
mock_dealer = Mock()
|
||||
mock_dealer.write = Mock()
|
||||
|
||||
# Mock the connection manager call
|
||||
self.engine_client.connection_manager.get_connection = AsyncMock(
|
||||
return_value=(mock_dealer, mock_response_queue)
|
||||
)
|
||||
|
||||
request = CompletionRequest(model="test-model", prompt="Hello", stream=True, max_streaming_response_tokens=3)
|
||||
|
||||
generator = self.completion_serving.completion_stream_generator(
|
||||
request=request,
|
||||
num_choices=1,
|
||||
request_id="test-request-id",
|
||||
model_name="test-model",
|
||||
created_time=11,
|
||||
prompt_batched_token_ids=[[1, 2, 3]],
|
||||
text_after_process_list=["Hello"],
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for chunk in generator:
|
||||
chunks.append(chunk)
|
||||
|
||||
self.assertGreater(len(chunks), 0, "No chucks!")
|
||||
|
||||
parsed_chunks = []
|
||||
for i, chunk_str in enumerate(chunks):
|
||||
if chunk_str.startswith("data: ") and chunk_str.endswith("\n\n"):
|
||||
json_part = chunk_str[6:-2]
|
||||
if json_part == "[DONE]":
|
||||
break
|
||||
try:
|
||||
chunk_dict = json.loads(json_part)
|
||||
parsed_chunks.append(chunk_dict)
|
||||
except json.JSONDecodeError as e:
|
||||
self.fail(f"Cannot parser {i+1} chunck, JSON: {e}\n origin string: {repr(chunk_str)}")
|
||||
else:
|
||||
self.fail(f"{i+1} chunk is unexcepted 'data: JSON\\n\\n': {repr(chunk_str)}")
|
||||
self.assertEqual(len(parsed_chunks), 1)
|
||||
for chunk_dict in parsed_chunks:
|
||||
choices_list = chunk_dict["choices"]
|
||||
self.assertEqual(len(choices_list), 3, f"Chunk {chunk_dict} should has three choices")
|
||||
self.assertEqual(
|
||||
choices_list[-1].get("finish_reason"), "stop", f"Chunk {chunk_dict} should has stop reason"
|
||||
)
|
||||
|
||||
found_done = any("[DONE]" in chunk for chunk in chunks)
|
||||
self.assertTrue(found_done, "Not Receive '[DONE]'")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user