diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index df074771c..186e1a0f5 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -183,6 +183,8 @@ class OpenAIServingChat:
             else (request.metadata or {}).get("max_streaming_response_tokens", 1)
         )  # dierctly passed & passed in metadata
+        max_streaming_response_tokens = max(1, max_streaming_response_tokens)
+
         enable_thinking = request.chat_template_kwargs.get("enable_thinking") if request.chat_template_kwargs else None
         if enable_thinking is None:
             enable_thinking = request.metadata.get("enable_thinking") if request.metadata else None
 
@@ -370,11 +372,6 @@ class OpenAIServingChat:
                             api_server_logger.info(f"Chat Streaming response last send: {chunk.model_dump_json()}")
                         choices = []
 
-            if choices:
-                chunk.choices = choices
-                yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
-                choices = []
-
             if include_usage:
                 completion_tokens = previous_num_tokens
                 usage = UsageInfo(
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index 059c22da3..75812d58c 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -331,6 +331,7 @@ class OpenAIServingCompletion:
             if request.max_streaming_response_tokens is not None
             else (request.suffix or {}).get("max_streaming_response_tokens", 1)
         )  # dierctly passed & passed in suffix
+        max_streaming_response_tokens = max(1, max_streaming_response_tokens)
         choices = []
         chunk = CompletionStreamResponse(
             id=request_id,
@@ -461,10 +462,6 @@ class OpenAIServingCompletion:
                 )
                 yield f"data: {usage_chunk.model_dump_json(exclude_unset=True)}\n\n"
                 api_server_logger.info(f"Completion Streaming response last send: {chunk.model_dump_json()}")
-            if choices:
-                chunk.choices = choices
-                yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
-                choices = []
 
         except Exception as e:
             api_server_logger.error(f"Error in completion_stream_generator: {e}, {str(traceback.format_exc())}")
diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py
new file mode 100644
index 000000000..fe48be8c4
--- /dev/null
+++ b/tests/entrypoints/openai/test_max_streaming_tokens.py
@@ -0,0 +1,279 @@
+import json
+import unittest
+from unittest import IsolatedAsyncioTestCase
+from unittest.mock import AsyncMock, Mock, patch
+
+from fastdeploy.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    CompletionRequest,
+)
+from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
+from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
+
+
+class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        self.engine_client = Mock()
+        self.engine_client.connection_initialized = False
+        self.engine_client.connection_manager = AsyncMock()
+        self.engine_client.connection_manager.initialize = AsyncMock()
+        self.engine_client.connection_manager.get_connection = AsyncMock()
+        self.engine_client.connection_manager.cleanup_request = AsyncMock()
+        self.engine_client.semaphore = Mock()
+        self.engine_client.semaphore.acquire = AsyncMock()
+        self.engine_client.semaphore.release = Mock()
+        self.engine_client.data_processor = Mock()
+        self.engine_client.is_master = True
+
+        self.chat_serving = OpenAIServingChat(
+            engine_client=self.engine_client,
+            models=None,
+            pid=123,
+            ips=None,
+            max_waiting_time=30,
+            chat_template="default",
+            enable_mm_output=False,
+            tokenizer_base_url=None,
+        )
+
+        self.completion_serving = OpenAIServingCompletion(
+            engine_client=self.engine_client, models=None, pid=123, ips=None, max_waiting_time=30
+        )
+
+    def test_metadata_parameter_setting(self):
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[{"role": "user", "content": "Hello"}],
+            stream=True,
+            metadata={"max_streaming_response_tokens": 100},
+        )
+
+        max_tokens = (
+            request.max_streaming_response_tokens
+            if request.max_streaming_response_tokens is not None
+            else (request.metadata or {}).get("max_streaming_response_tokens", 1)
+        )
+
+        self.assertEqual(max_tokens, 100)
+
+    def test_default_value(self):
+        request = ChatCompletionRequest(
+            model="test-model", messages=[{"role": "user", "content": "Hello"}], stream=True
+        )
+
+        max_tokens = (
+            request.max_streaming_response_tokens
+            if request.max_streaming_response_tokens is not None
+            else (request.metadata or {}).get("max_streaming_response_tokens", 1)
+        )
+
+        self.assertEqual(max_tokens, 1)
+
+    def test_edge_case_zero_value(self):
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[{"role": "user", "content": "Hello"}],
+            stream=True,
+            max_streaming_response_tokens=0,
+        )
+
+        max_streaming_response_tokens = (
+            request.max_streaming_response_tokens
+            if request.max_streaming_response_tokens is not None
+            else (request.metadata or {}).get("max_streaming_response_tokens", 1)
+        )
+        max_streaming_response_tokens = max(1, max_streaming_response_tokens)
+
+        self.assertEqual(max_streaming_response_tokens, 1)
+
+    @patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger")
+    @patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor")
+    async def test_integration_with_chat_stream_generator(self, mock_processor_class, mock_logger):
+        response_data = [
+            {
+                "outputs": {"token_ids": [1], "text": "a", "top_logprobs": None},
+                "metrics": {"first_token_time": 0.1, "inference_start_time": 0.1},
+                "finished": False,
+            },
+            {
+                "outputs": {"token_ids": [2], "text": "b", "top_logprobs": None},
+                "metrics": {"arrival_time": 0.2, "first_token_time": None},
+                "finished": False,
+            },
+            {
+                "outputs": {"token_ids": [3], "text": "c", "top_logprobs": None},
+                "metrics": {"arrival_time": 0.3, "first_token_time": None},
+                "finished": False,
+            },
+            {
+                "outputs": {"token_ids": [4], "text": "d", "top_logprobs": None},
+                "metrics": {"arrival_time": 0.4, "first_token_time": None},
+                "finished": False,
+            },
+            {
+                "outputs": {"token_ids": [5], "text": "e", "top_logprobs": None},
+                "metrics": {"arrival_time": 0.5, "first_token_time": None},
+                "finished": False,
+            },
+            {
+                "outputs": {"token_ids": [6], "text": "f", "top_logprobs": None},
+                "metrics": {"arrival_time": 0.6, "first_token_time": None},
+                "finished": False,
+            },
+            {
+                "outputs": {"token_ids": [7], "text": "g", "top_logprobs": None},
+                "metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1},
+                "finished": True,
+            },
+        ]
+
+        mock_response_queue = AsyncMock()
+        mock_response_queue.get.side_effect = response_data
+
+        mock_dealer = Mock()
+        mock_dealer.write = Mock()
+
+        # Mock the connection manager call
+        self.engine_client.connection_manager.get_connection = AsyncMock(
+            return_value=(mock_dealer, mock_response_queue)
+        )
+
+        mock_processor_instance = Mock()
+
+        async def mock_process_response_chat_single(response, stream, enable_thinking, include_stop_str_in_output):
+            yield response
+
+        mock_processor_instance.process_response_chat = mock_process_response_chat_single
+        mock_processor_instance.enable_multimodal_content = Mock(return_value=False)
+        mock_processor_class.return_value = mock_processor_instance
+
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[{"role": "user", "content": "Hello"}],
+            stream=True,
+            max_streaming_response_tokens=3,
+        )
+
+        generator = self.chat_serving.chat_completion_stream_generator(
+            request=request,
+            request_id="test-request-id",
+            model_name="test-model",
+            prompt_token_ids=[1, 2, 3],
+            text_after_process="Hello",
+        )
+
+        chunks = []
+        async for chunk in generator:
+            chunks.append(chunk)
+
+        self.assertGreater(len(chunks), 0, "No chunks received!")
+
+        parsed_chunks = []
+        for i, chunk_str in enumerate(chunks):
+            if i == 0:
+                continue
+            if chunk_str.startswith("data: ") and chunk_str.endswith("\n\n"):
+                json_part = chunk_str[6:-2]
+                if json_part == "[DONE]":
+                    parsed_chunks.append({"type": "done", "raw": chunk_str})
+                    break
+                try:
+                    chunk_dict = json.loads(json_part)
+                    parsed_chunks.append(chunk_dict)
+                except json.JSONDecodeError as e:
+                    self.fail(f"Cannot parse chunk {i+1} as JSON: {e}\n original string: {repr(chunk_str)}")
+            else:
+                self.fail(f"Chunk {i+1} is not in the expected 'data: JSON\\n\\n' format: {repr(chunk_str)}")
+        for chunk_dict in parsed_chunks:
+            choices_list = chunk_dict["choices"]
+            if choices_list[-1].get("finish_reason") is not None:
+                break
+            else:
+                self.assertEqual(len(choices_list), 3, f"Chunk {chunk_dict} should have three choices")
+
+        found_done = any("[DONE]" in chunk for chunk in chunks)
+        self.assertTrue(found_done, "Did not receive '[DONE]'")
+
+    @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger")
+    async def test_integration_with_completion_stream_generator(self, mock_logger):
+        response_data = [
+            [
+                {
+                    "request_id": "test-request-id-0",
+                    "outputs": {"token_ids": [1], "text": "a", "top_logprobs": None},
+                    "metrics": {"first_token_time": 0.1, "inference_start_time": 0.1},
+                    "finished": False,
+                },
+                {
+                    "request_id": "test-request-id-0",
+                    "outputs": {"token_ids": [2], "text": "b", "top_logprobs": None},
+                    "metrics": {"arrival_time": 0.2, "first_token_time": None},
+                    "finished": False,
+                },
+            ],
+            [
+                {
+                    "request_id": "test-request-id-0",
+                    "outputs": {"token_ids": [7], "text": "g", "top_logprobs": None},
+                    "metrics": {"arrival_time": 0.7, "first_token_time": None, "request_start_time": 0.1},
+                    "finished": True,
+                }
+            ],
+        ]
+
+        mock_response_queue = AsyncMock()
+        mock_response_queue.get.side_effect = response_data
+
+        mock_dealer = Mock()
+        mock_dealer.write = Mock()
+
+        # Mock the connection manager call
+        self.engine_client.connection_manager.get_connection = AsyncMock(
+            return_value=(mock_dealer, mock_response_queue)
+        )
+
+        request = CompletionRequest(model="test-model", prompt="Hello", stream=True, max_streaming_response_tokens=3)
+
+        generator = self.completion_serving.completion_stream_generator(
+            request=request,
+            num_choices=1,
+            request_id="test-request-id",
+            model_name="test-model",
+            created_time=11,
+            prompt_batched_token_ids=[[1, 2, 3]],
+            text_after_process_list=["Hello"],
+        )
+
+        chunks = []
+        async for chunk in generator:
+            chunks.append(chunk)
+
+        self.assertGreater(len(chunks), 0, "No chunks received!")
+
+        parsed_chunks = []
+        for i, chunk_str in enumerate(chunks):
+            if chunk_str.startswith("data: ") and chunk_str.endswith("\n\n"):
+                json_part = chunk_str[6:-2]
+                if json_part == "[DONE]":
+                    break
+                try:
+                    chunk_dict = json.loads(json_part)
+                    parsed_chunks.append(chunk_dict)
+                except json.JSONDecodeError as e:
+                    self.fail(f"Cannot parse chunk {i+1} as JSON: {e}\n original string: {repr(chunk_str)}")
+            else:
+                self.fail(f"Chunk {i+1} is not in the expected 'data: JSON\\n\\n' format: {repr(chunk_str)}")
+        self.assertEqual(len(parsed_chunks), 1)
+        for chunk_dict in parsed_chunks:
+            choices_list = chunk_dict["choices"]
+            self.assertEqual(len(choices_list), 3, f"Chunk {chunk_dict} should have three choices")
+            self.assertEqual(
+                choices_list[-1].get("finish_reason"), "stop", f"Chunk {chunk_dict} should have finish_reason 'stop'"
+            )
+
+        found_done = any("[DONE]" in chunk for chunk in chunks)
+        self.assertTrue(found_done, "Did not receive '[DONE]'")
+
+
+if __name__ == "__main__":
+    unittest.main()