From 97189079b9624861cc0983fd948d964ca948e7c9 Mon Sep 17 00:00:00 2001 From: kxz2002 <115912648+kxz2002@users.noreply.github.com> Date: Tue, 18 Nov 2025 20:01:33 +0800 Subject: [PATCH] [BugFix] unify max_tokens (#4968) * unify max tokens * modify and add unit test * modify and add unit test * modify and add unit tests --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> --- fastdeploy/entrypoints/openai/serving_chat.py | 18 +- .../entrypoints/openai/serving_completion.py | 20 +- .../openai/test_completion_echo.py | 4 + .../entrypoints/openai/test_finish_reason.py | 648 ++++++++++++++++++ .../openai/test_max_streaming_tokens.py | 7 + .../openai/test_serving_completion.py | 1 + tests/utils/test_custom_chat_template.py | 4 +- 7 files changed, 690 insertions(+), 12 deletions(-) create mode 100644 tests/entrypoints/openai/test_finish_reason.py diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 25bd6661c..208461312 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -126,6 +126,7 @@ class OpenAIServingChat: request_id = f"chatcmpl-{uuid.uuid4()}" api_server_logger.info(f"create chat completion request: {request_id}") prompt_tokens = None + max_tokens = None try: current_req_dict = request.to_dict_for_infer(f"{request_id}_0") if "chat_template" not in current_req_dict: @@ -134,6 +135,7 @@ class OpenAIServingChat: # preprocess the req_dict prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict) prompt_tokens = current_req_dict.get("prompt_tokens") + max_tokens = current_req_dict.get("max_tokens") if isinstance(prompt_token_ids, np.ndarray): prompt_token_ids = prompt_token_ids.tolist() except ParameterError as e: @@ -151,12 +153,12 @@ class OpenAIServingChat: if request.stream: return self.chat_completion_stream_generator( - request, request_id, request.model, prompt_token_ids, prompt_tokens + request, request_id, request.model, prompt_token_ids, prompt_tokens, max_tokens ) else: try: return await self.chat_completion_full_generator( - request, request_id, request.model, prompt_token_ids, prompt_tokens + request, request_id, request.model, prompt_token_ids, prompt_tokens, max_tokens ) except Exception as e: error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}" @@ -184,6 +186,7 @@ class OpenAIServingChat: model_name: str, prompt_token_ids: list(), prompt_tokens: str, + max_tokens: int, ): """ Streaming chat completion generator. @@ -382,9 +385,7 @@ class OpenAIServingChat: work_process_metrics.e2e_request_latency.observe( time.time() - res["metrics"]["request_start_time"] ) - has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None - max_tokens = request.max_completion_tokens or request.max_tokens - if has_no_token_limit or previous_num_tokens[idx] != max_tokens: + if previous_num_tokens[idx] != max_tokens: choice.finish_reason = "stop" if tool_called[idx]: choice.finish_reason = "tool_calls" @@ -461,6 +462,7 @@ class OpenAIServingChat: model_name: str, prompt_token_ids: list(), prompt_tokens: str, + max_tokens: int, ): """ Full chat completion generator. 
@@ -570,6 +572,7 @@ class OpenAIServingChat: num_image_tokens=num_image_tokens, logprob_contents=logprob_contents, response_processor=response_processor, + max_tokens=max_tokens, ) choices.append(choice) finally: @@ -620,6 +623,7 @@ class OpenAIServingChat: num_image_tokens: list, logprob_contents: list, response_processor: ChatResponseProcessor, + max_tokens: int, ) -> ChatCompletionResponseChoice: idx = int(data["request_id"].split("_")[-1]) output = data["outputs"] @@ -646,15 +650,13 @@ class OpenAIServingChat: if logprob_contents[idx]: logprobs_full_res = LogProbs(content=logprob_contents[idx]) - has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None - max_tokens = request.max_completion_tokens or request.max_tokens num_cached_tokens[idx] = data.get("num_cached_tokens", 0) num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0) num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0) num_image_tokens[idx] = output.get("num_image_tokens", 0) finish_reason = "stop" - if has_no_token_limit or previous_num_tokens != max_tokens: + if previous_num_tokens != max_tokens: finish_reason = "stop" if output.get("tool_call"): finish_reason = "tool_calls" diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 1f3a064bf..9bf242cd0 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -141,6 +141,7 @@ class OpenAIServingCompletion: api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}") prompt_batched_token_ids = [] prompt_tokens_list = [] + max_tokens_list = [] try: if self.max_waiting_time < 0: await self.engine_client.semaphore.acquire() @@ -167,6 +168,7 @@ class OpenAIServingCompletion: prompt_token_ids = prompt_token_ids.tolist() prompt_tokens_list.append(current_req_dict.get("prompt_tokens")) prompt_batched_token_ids.append(prompt_token_ids) + max_tokens_list.append(current_req_dict.get("max_tokens")) del current_req_dict except ParameterError as e: api_server_logger.error(f"OpenAIServingCompletion format error: {e}, {e.message}") @@ -191,6 +193,7 @@ class OpenAIServingCompletion: model_name=request.model, prompt_batched_token_ids=prompt_batched_token_ids, prompt_tokens_list=prompt_tokens_list, + max_tokens_list=max_tokens_list, ) else: try: @@ -202,6 +205,7 @@ class OpenAIServingCompletion: model_name=request.model, prompt_batched_token_ids=prompt_batched_token_ids, prompt_tokens_list=prompt_tokens_list, + max_tokens_list=max_tokens_list, ) except Exception as e: error_msg = ( @@ -224,6 +228,7 @@ class OpenAIServingCompletion: model_name: str, prompt_batched_token_ids: list(), prompt_tokens_list: list(), + max_tokens_list: list(), ): """ Process the full completion request with multiple choices. @@ -312,6 +317,7 @@ class OpenAIServingCompletion: prompt_batched_token_ids=prompt_batched_token_ids, completion_batched_token_ids=completion_batched_token_ids, prompt_tokens_list=prompt_tokens_list, + max_tokens_list=max_tokens_list, ) api_server_logger.info(f"Completion response: {res.model_dump_json()}") return res @@ -365,6 +371,7 @@ class OpenAIServingCompletion: model_name: str, prompt_batched_token_ids: list(), prompt_tokens_list: list(), + max_tokens_list: list(), ): """ Process the stream completion request. 
@@ -501,7 +508,10 @@ class OpenAIServingCompletion: if res["finished"]: choices[-1].finish_reason = self.calc_finish_reason( - request.max_tokens, output_tokens[idx], output, tool_called[idx] + max_tokens_list[idx // (1 if request.n is None else request.n)], + output_tokens[idx], + output, + tool_called[idx], ) send_idx = output.get("send_idx") @@ -571,6 +581,7 @@ class OpenAIServingCompletion: prompt_batched_token_ids: list(), completion_batched_token_ids: list(), prompt_tokens_list: list(), + max_tokens_list: list(), ) -> CompletionResponse: choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 @@ -607,7 +618,12 @@ class OpenAIServingCompletion: else: token_ids = output["token_ids"] output_text = output["text"] - finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output, False) + finish_reason = self.calc_finish_reason( + max_tokens_list[idx // (1 if request.n is None else request.n)], + final_res["output_token_ids"], + output, + False, + ) choice_data = CompletionResponseChoice( token_ids=token_ids, diff --git a/tests/entrypoints/openai/test_completion_echo.py b/tests/entrypoints/openai/test_completion_echo.py index a2d4313db..679f6d8ec 100644 --- a/tests/entrypoints/openai/test_completion_echo.py +++ b/tests/entrypoints/openai/test_completion_echo.py @@ -58,6 +58,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase): prompt_batched_token_ids=[[1, 2]], completion_batched_token_ids=[[3, 4, 5]], prompt_tokens_list=["test prompt"], + max_tokens_list=[100], ) self.assertEqual(response.choices[0].text, "test prompt generated text") @@ -91,6 +92,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase): prompt_batched_token_ids=[[1, 2]], completion_batched_token_ids=[[3, 4, 5]], prompt_tokens_list=["test prompt"], + max_tokens_list=[100], ) self.assertEqual(response.choices[0].text, "decoded_[1, 2, 3] generated text") @@ -124,6 +126,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase): prompt_batched_token_ids=[[1], [2]], completion_batched_token_ids=[[1, 2], [3, 4]], prompt_tokens_list=["prompt1", "prompt2"], + max_tokens_list=[100, 100], ) self.assertEqual(len(response.choices), 2) @@ -160,6 +163,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase): prompt_batched_token_ids=[[1], [2]], completion_batched_token_ids=[[1, 2], [3, 4]], prompt_tokens_list=["prompt1", "prompt2"], + max_tokens_list=[100, 100], ) self.assertEqual(len(response.choices), 2) diff --git a/tests/entrypoints/openai/test_finish_reason.py b/tests/entrypoints/openai/test_finish_reason.py new file mode 100644 index 000000000..4bdb3feef --- /dev/null +++ b/tests/entrypoints/openai/test_finish_reason.py @@ -0,0 +1,648 @@ +import json +from typing import Any, Dict, List +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import numpy as np + +from fastdeploy.entrypoints.openai.protocol import ( + ChatCompletionRequest, + CompletionRequest, + CompletionResponse, + DeltaMessage, + UsageInfo, +) +from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat +from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion +from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor +from fastdeploy.utils import data_processor_logger + + +class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase): + async def asyncSetUp(self): + with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None): + self.multi_modal_processor = 
Ernie4_5_VLProcessor("model_path") + self.multi_modal_processor.tokenizer = MagicMock() + self.multi_modal_processor.tokenizer.eos_token_id = 102 + self.multi_modal_processor.tokenizer.pad_token_id = 0 + self.multi_modal_processor.eos_token_ids = [102] + self.multi_modal_processor.eos_token_id_len = 1 + self.multi_modal_processor.generation_config = MagicMock() + self.multi_modal_processor.decode_status = {} + self.multi_modal_processor.tool_parser_dict = {} + self.multi_modal_processor.ernie4_5_processor = MagicMock() + self.multi_modal_processor.ernie4_5_processor.request2ids.return_value = { + "input_ids": np.array([101, 9012, 3456, 102]) + } + self.multi_modal_processor.ernie4_5_processor.text2ids.return_value = { + "input_ids": np.array([101, 1234, 5678, 102]) + } + self.multi_modal_processor._apply_default_parameters = lambda x: x + self.multi_modal_processor.update_stop_seq = Mock(return_value=([], [])) + self.multi_modal_processor.update_bad_words = Mock(return_value=[]) + self.multi_modal_processor._check_mm_limits = Mock() + self.multi_modal_processor.append_completion_tokens = Mock() + self.multi_modal_processor.pack_outputs = lambda x: x + + self.engine_client = Mock() + self.engine_client.connection_initialized = False + self.engine_client.connection_manager = AsyncMock() + self.engine_client.semaphore = Mock() + self.engine_client.semaphore.acquire = AsyncMock() + self.engine_client.semaphore.release = Mock() + self.engine_client.is_master = True + self.engine_client.check_model_weight_status = Mock(return_value=False) + self.engine_client.enable_mm = True + self.engine_client.enable_prefix_caching = False + self.engine_client.max_model_len = 20 + self.engine_client.data_processor = self.multi_modal_processor + + async def mock_add_data(current_req_dict): + if current_req_dict.get("max_tokens") is None: + current_req_dict["max_tokens"] = self.engine_client.max_model_len - 1 + current_req_dict["max_tokens"] = min( + self.engine_client.max_model_len - 4, max(0, current_req_dict.get("max_tokens")) + ) + + self.engine_client.add_requests = AsyncMock(side_effect=mock_add_data) + + self.chat_serving = OpenAIServingChat( + engine_client=self.engine_client, + models=None, + pid=123, + ips=None, + max_waiting_time=30, + chat_template="default", + enable_mm_output=True, + tokenizer_base_url=None, + ) + self.completion_serving = OpenAIServingCompletion( + engine_client=self.engine_client, models=None, pid=123, ips=None, max_waiting_time=30 + ) + + def _generate_inference_response( + self, request_id: str, output_token_num: int, tool_call: Any = None + ) -> List[Dict]: + outputs = { + "text": "这是一张风景图"[:output_token_num], + "token_ids": list(range(output_token_num)), + "reasoning_content": "推理过程", + "num_image_tokens": 0, + "num_cached_tokens": 0, + "top_logprobs": None, + "draft_top_logprobs": None, + "tool_call": None, + } + + if tool_call: + outputs["tool_call"] = [ + {"index": 0, "type": "function", "function": {"name": tool_call["name"], "arguments": json.dumps({})}} + ] + + return [ + { + "request_id": request_id, + "outputs": outputs, + "metrics": {"request_start_time": 0.1}, + "finished": True, + "error_msg": None, + "output_token_ids": output_token_num, + } + ] + + def _generate_stream_inference_response( + self, request_id: str, total_token_num: int, tool_call: Any = None + ) -> List[Dict]: + stream_responses = [] + for i in range(total_token_num): + metrics = {} + if i == 0: + metrics["first_token_time"] = 0.1 + metrics["inference_start_time"] = 0.1 + else: + 
metrics["arrival_time"] = 0.1 * (i + 1) + metrics["first_token_time"] = None + + if i == total_token_num - 1: + metrics["request_start_time"] = 0.1 + + outputs = { + "text": chr(ord("a") + i), + "token_ids": [i + 1], + "top_logprobs": None, + "draft_top_logprobs": None, + "reasoning_token_num": 0, + } + + if tool_call and isinstance(tool_call, dict) and i == total_token_num - 2: + delta_msg = DeltaMessage( + content="", + reasoning_content="", + tool_calls=[ + { + "index": 0, + "type": "function", + "function": {"name": tool_call["name"], "arguments": json.dumps({})}, + } + ], + prompt_token_ids=None, + completion_token_ids=None, + ) + outputs["delta_message"] = delta_msg + + frame = [ + { + "request_id": f"{request_id}_0", + "error_code": 200, + "outputs": outputs, + "metrics": metrics, + "finished": (i == total_token_num - 1), + "error_msg": None, + } + ] + stream_responses.append(frame) + return stream_responses + + @patch.object(data_processor_logger, "info") + @patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor") + @patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger") + async def test_chat_full_max_tokens(self, mock_data_logger, mock_processor_class, mock_api_logger): + test_cases = [ + { + "name": "用户传max_tokens=5,生成数=5→length", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=False, + max_tokens=5, + return_token_ids=True, + ), + "output_token_num": 5, + "tool_call": [], + "expected_finish_reason": "length", + }, + { + "name": "用户未传max_tokens,生成数=10→stop", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=False, + return_token_ids=True, + ), + "output_token_num": 10, + "tool_call": [], + "expected_finish_reason": "stop", + }, + { + "name": "用户未传max_tokens,生成数=16→length", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=False, + return_token_ids=True, + ), + "output_token_num": 16, + "tool_call": [], + "expected_finish_reason": "length", + }, + { + "name": "用户传max_tokens,生成数=10→stop", + "request": ChatCompletionRequest( + model="ernie4.5-vl", + messages=[{"role": "user", "content": "描述这张图片"}], + stream=False, + max_tokens=50, + return_token_ids=True, + ), + "output_token_num": 10, + "tool_call": [], + "expected_finish_reason": "stop", + }, + { + "name": "生成数 0: + finish_reason = chunk_dict["choices"][0].get("finish_reason") + if finish_reason: + final_finish_reason = finish_reason + break + except (json.JSONDecodeError, KeyError, IndexError): + continue + + self.assertEqual(final_finish_reason, case["expected_finish_reason"]) + + @patch.object(data_processor_logger, "info") + @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger") + async def test_completion_stream_max_tokens(self, mock_api_logger, mock_data_logger): + test_cases = [ + { + "name": "流式-生成数=7(等于max_tokens)→length", + "request": CompletionRequest( + model="ernie4.5-vl", + prompt=["描述这张图片:xxx"], + stream=True, + max_tokens=7, + return_token_ids=True, + ), + "total_token_num": 7, + "expected_finish_reason": "length", + }, + { + "name": "流式-生成数=9(小于max_tokens)→stop", + "request": CompletionRequest( + model="ernie4.5-vl", + prompt=["描述这张图片:xxx"], + stream=True, + max_tokens=15, + return_token_ids=True, + ), + "total_token_num": 9, + "expected_finish_reason": "stop", + }, + ] + + mock_dealer = Mock() + self.engine_client.connection_manager.get_connection = 
AsyncMock(return_value=(mock_dealer, AsyncMock())) + + for case in test_cases: + with self.subTest(case=case["name"]): + request_dict = { + "prompt": case["request"].prompt, + "multimodal_data": {"image": ["xxx"]}, + "request_id": "test_completion_stream_0", + "max_tokens": case["request"].max_tokens, + } + await self.engine_client.add_requests(request_dict) + processed_req = self.multi_modal_processor.process_request_dict( + request_dict, self.engine_client.max_model_len + ) + self.engine_client.data_processor.process_response_dict = ( + lambda data, stream, include_stop_str_in_output: data + ) + + mock_response_queue = AsyncMock() + stream_responses = self._generate_stream_inference_response( + request_id="test_completion_stream_0", total_token_num=case["total_token_num"] + ) + mock_response_queue.get.side_effect = stream_responses + self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue) + + generator = self.completion_serving.completion_stream_generator( + request=case["request"], + num_choices=1, + created_time=0, + request_id="test_completion_stream", + model_name="ernie4.5-vl", + prompt_batched_token_ids=[processed_req["prompt_token_ids"]], + prompt_tokens_list=case["request"].prompt, + max_tokens_list=[processed_req["max_tokens"]], + ) + + final_finish_reason = None + chunks = [] + async for chunk in generator: + chunks.append(chunk) + if "[DONE]" in chunk: + break + + for chunk_str in chunks: + if chunk_str.startswith("data: ") and "[DONE]" not in chunk_str: + try: + json_part = chunk_str.strip().lstrip("data: ") + chunk_dict = json.loads(json_part) + if chunk_dict["choices"][0].get("finish_reason"): + final_finish_reason = chunk_dict["choices"][0]["finish_reason"] + break + except (json.JSONDecodeError, KeyError, IndexError): + continue + + self.assertEqual(final_finish_reason, case["expected_finish_reason"], f"场景 {case['name']} 失败") + + @patch.object(data_processor_logger, "info") + @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger") + async def test_completion_create_max_tokens_list_basic(self, mock_api_logger, mock_data_logger): + test_cases = [ + { + "name": "单prompt → max_tokens_list长度1", + "request": CompletionRequest( + request_id="test_single_prompt", + model="ernie4.5-vl", + prompt="请介绍人工智能的应用", + stream=False, + max_tokens=10, + ), + "mock_max_tokens": 8, + "expected_max_tokens_list_len": 1, + "expected_max_tokens_list": [8], + }, + { + "name": "多prompt → max_tokens_list长度2", + "request": CompletionRequest( + request_id="test_multi_prompt", + model="ernie4.5-vl", + prompt=["请介绍Python语言", "请说明机器学习的步骤"], + stream=False, + max_tokens=15, + ), + "mock_max_tokens": [12, 10], + "expected_max_tokens_list_len": 2, + "expected_max_tokens_list": [12, 10], + }, + ] + + async def mock_format_and_add_data(current_req_dict): + req_idx = int(current_req_dict["request_id"].split("_")[-1]) + if isinstance(case["mock_max_tokens"], list): + current_req_dict["max_tokens"] = case["mock_max_tokens"][req_idx] + else: + current_req_dict["max_tokens"] = case["mock_max_tokens"] + return [101, 102, 103, 104] + + self.engine_client.format_and_add_data = AsyncMock(side_effect=mock_format_and_add_data) + + async def intercept_generator(**kwargs): + actual_max_tokens_list = kwargs["max_tokens_list"] + self.assertEqual( + len(actual_max_tokens_list), + case["expected_max_tokens_list_len"], + f"列表长度不匹配:实际{len(actual_max_tokens_list)},预期{case['expected_max_tokens_list_len']}", + ) + self.assertEqual( + actual_max_tokens_list, + 
case["expected_max_tokens_list"], + f"列表元素不匹配:实际{actual_max_tokens_list},预期{case['expected_max_tokens_list']}", + ) + return CompletionResponse( + id=kwargs["request_id"], + object="text_completion", + created=kwargs["created_time"], + model=kwargs["model_name"], + choices=[], + usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + + self.completion_serving.completion_full_generator = AsyncMock(side_effect=intercept_generator) + + for case in test_cases: + with self.subTest(case=case["name"]): + result = await self.completion_serving.create_completion(request=case["request"]) + self.assertIsInstance(result, CompletionResponse) diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py index 698ff368b..1e728aa3b 100644 --- a/tests/entrypoints/openai/test_max_streaming_tokens.py +++ b/tests/entrypoints/openai/test_max_streaming_tokens.py @@ -169,6 +169,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase): model_name="test-model", prompt_token_ids=[1, 2, 3], prompt_tokens="Hello", + max_tokens=10, ) chunks = [] @@ -251,6 +252,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase): created_time=11, prompt_batched_token_ids=[[1, 2, 3]], prompt_tokens_list=["Hello"], + max_tokens_list=[100], ) chunks = [] @@ -352,6 +354,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase): model_name=model_name, prompt_batched_token_ids=prompt_batched_token_ids, prompt_tokens_list=prompt_tokens_list, + max_tokens_list=[100, 100], ) self.assertEqual(actual_response, expected_completion_response) @@ -449,6 +452,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase): num_input_image_tokens = [0, 0] num_input_video_tokens = [0, 0] num_image_tokens = [0, 0] + max_tokens_list = [10, 1] for idx, case in enumerate(test_cases): actual_choice = await self.chat_serving._create_chat_completion_choice( @@ -464,6 +468,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase): num_image_tokens=num_image_tokens, logprob_contents=logprob_contents, response_processor=mock_response_processor, + max_tokens=max_tokens_list[idx], ) expected = case["expected"] @@ -554,6 +559,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase): model_name="test-model", prompt_token_ids=[10, 20, 30], prompt_tokens="Hello", + max_tokens=10, ) chunks = [] @@ -661,6 +667,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase): model_name="test-model", prompt_batched_token_ids=[[10, 20, 30]], prompt_tokens_list=["Hello"], + max_tokens_list=[100], ) chunks = [] diff --git a/tests/entrypoints/openai/test_serving_completion.py b/tests/entrypoints/openai/test_serving_completion.py index c2a36d185..d1aded4f3 100644 --- a/tests/entrypoints/openai/test_serving_completion.py +++ b/tests/entrypoints/openai/test_serving_completion.py @@ -159,6 +159,7 @@ class TestOpenAIServingCompletion(unittest.TestCase): prompt_batched_token_ids=prompt_batched_token_ids, completion_batched_token_ids=completion_batched_token_ids, prompt_tokens_list=["1", "1"], + max_tokens_list=[10, 10], ) assert completion_response.id == request_id diff --git a/tests/utils/test_custom_chat_template.py b/tests/utils/test_custom_chat_template.py index 8079ba203..bf53bc93b 100644 --- a/tests/utils/test_custom_chat_template.py +++ b/tests/utils/test_custom_chat_template.py @@ -78,7 +78,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase): ) async def mock_chat_completion_full_generator( - request, request_id, 
model_name, prompt_token_ids, prompt_tokens + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens ): return prompt_token_ids @@ -106,7 +106,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase): ) async def mock_chat_completion_full_generator( - request, request_id, model_name, prompt_token_ids, prompt_tokens + request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens ): return prompt_token_ids
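
For reference, below is a minimal standalone Python sketch of the behaviour this patch converges on, assuming preprocessing resolves max_tokens the way the test's mock_add_data does (default to max_model_len - 1, then clamp to the room left by the prompt). The helper names resolve_max_tokens and pick_finish_reason are illustrative only, not FastDeploy APIs; in the real code the resolved value is threaded into the chat/completion generators and calc_finish_reason instead of re-reading request.max_tokens.

from typing import Optional


def resolve_max_tokens(requested: Optional[int], max_model_len: int, prompt_len: int) -> int:
    # Illustrative: default to the context limit minus one, then clamp to the
    # budget remaining after the prompt (mirrors the test's mock_add_data).
    if requested is None:
        requested = max_model_len - 1
    return min(max_model_len - prompt_len, max(0, requested))


def pick_finish_reason(max_tokens: int, generated: int, tool_called: bool) -> str:
    # Simplified rule from the patch: exhausting the resolved budget means
    # "length"; otherwise tool calls win over the default "stop".
    if generated == max_tokens:
        return "length"
    if tool_called:
        return "tool_calls"
    return "stop"


if __name__ == "__main__":
    budget = resolve_max_tokens(None, max_model_len=20, prompt_len=4)       # 16, as in the tests
    print(pick_finish_reason(budget, generated=16, tool_called=False))      # length
    print(pick_finish_reason(budget, generated=10, tool_called=False))      # stop
    print(pick_finish_reason(resolve_max_tokens(5, 20, 4), 5, False))       # length

With max_model_len=20 and a 4-token prompt, an unbounded request resolves to 16 tokens, matching the new tests' expectation that 16 generated tokens yields "length" while 10 yields "stop".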