[BugFix] unify max_tokens (#4968)

* unify max tokens

* modify and add unit test

* modify and add unit test

* modify and add unit tests

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Authored by kxz2002 on 2025-11-18 20:01:33 +08:00, committed by GitHub
parent 3d7f1a843e
commit 97189079b9
7 changed files with 690 additions and 12 deletions
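
In short, the fix stops the serving layer from re-deriving the token limit from the raw request.max_tokens / request.max_completion_tokens fields and instead threads through the effective max_tokens that preprocessing writes back into the request dict, so finish_reason is judged against the limit the engine actually enforced. Below is a minimal sketch of that clamping, modeled on the mock_add_data helper in the new unit tests; the function name and exact bounds are illustrative, not the engine's real preprocessing code.

```python
# Illustrative sketch only: mirrors the behaviour the new tests emulate
# (default to max_model_len - 1, then clamp to the room left after the prompt).
from typing import Optional


def effective_max_tokens(requested: Optional[int], max_model_len: int, prompt_len: int) -> int:
    if requested is None:
        # No limit supplied by the client: default close to the context window.
        requested = max_model_len - 1
    # Never allow more new tokens than the context window leaves after the prompt.
    return min(max_model_len - prompt_len, max(0, requested))


# With max_model_len=20 and a 4-token prompt (the tests' setup), a requested
# max_tokens of 20 is clamped to 16, which is why generating 16 tokens yields
# finish_reason == "length" in the tests below.
```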

View File

@@ -126,6 +126,7 @@ class OpenAIServingChat:
request_id = f"chatcmpl-{uuid.uuid4()}"
api_server_logger.info(f"create chat completion request: {request_id}")
prompt_tokens = None
max_tokens = None
try:
current_req_dict = request.to_dict_for_infer(f"{request_id}_0")
if "chat_template" not in current_req_dict:
@@ -134,6 +135,7 @@ class OpenAIServingChat:
# preprocess the req_dict
prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict)
prompt_tokens = current_req_dict.get("prompt_tokens")
max_tokens = current_req_dict.get("max_tokens")
if isinstance(prompt_token_ids, np.ndarray):
prompt_token_ids = prompt_token_ids.tolist()
except ParameterError as e:
@@ -151,12 +153,12 @@ class OpenAIServingChat:
if request.stream:
return self.chat_completion_stream_generator(
- request, request_id, request.model, prompt_token_ids, prompt_tokens
+ request, request_id, request.model, prompt_token_ids, prompt_tokens, max_tokens
)
else:
try:
return await self.chat_completion_full_generator(
- request, request_id, request.model, prompt_token_ids, prompt_tokens
+ request, request_id, request.model, prompt_token_ids, prompt_tokens, max_tokens
)
except Exception as e:
error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
@@ -184,6 +186,7 @@ class OpenAIServingChat:
model_name: str,
prompt_token_ids: list(),
prompt_tokens: str,
max_tokens: int,
):
"""
Streaming chat completion generator.
@@ -382,9 +385,7 @@ class OpenAIServingChat:
work_process_metrics.e2e_request_latency.observe(
time.time() - res["metrics"]["request_start_time"]
)
- has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
- max_tokens = request.max_completion_tokens or request.max_tokens
- if has_no_token_limit or previous_num_tokens[idx] != max_tokens:
+ if previous_num_tokens[idx] != max_tokens:
choice.finish_reason = "stop"
if tool_called[idx]:
choice.finish_reason = "tool_calls"
@@ -461,6 +462,7 @@ class OpenAIServingChat:
model_name: str,
prompt_token_ids: list(),
prompt_tokens: str,
max_tokens: int,
):
"""
Full chat completion generator.
@@ -570,6 +572,7 @@ class OpenAIServingChat:
num_image_tokens=num_image_tokens,
logprob_contents=logprob_contents,
response_processor=response_processor,
max_tokens=max_tokens,
)
choices.append(choice)
finally:
@@ -620,6 +623,7 @@ class OpenAIServingChat:
num_image_tokens: list,
logprob_contents: list,
response_processor: ChatResponseProcessor,
max_tokens: int,
) -> ChatCompletionResponseChoice:
idx = int(data["request_id"].split("_")[-1])
output = data["outputs"]
@@ -646,15 +650,13 @@ class OpenAIServingChat:
if logprob_contents[idx]:
logprobs_full_res = LogProbs(content=logprob_contents[idx])
- has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
- max_tokens = request.max_completion_tokens or request.max_tokens
num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0)
num_image_tokens[idx] = output.get("num_image_tokens", 0)
finish_reason = "stop"
- if has_no_token_limit or previous_num_tokens != max_tokens:
+ if previous_num_tokens != max_tokens:
finish_reason = "stop"
if output.get("tool_call"):
finish_reason = "tool_calls"
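
The hunks above remove the has_no_token_limit special case: both the streaming and the full chat generators now compare the number of generated tokens against the single preprocessed max_tokens value. A hedged sketch of the resulting decision rule, written as a standalone helper for illustration (the committed code sets choice.finish_reason / finish_reason inline rather than calling such a function):

```python
# Sketch of the unified finish-reason rule exercised by the new tests.
def resolve_finish_reason(previous_num_tokens: int, max_tokens: int, tool_called: bool) -> str:
    # Generating exactly max_tokens tokens means the output was cut off.
    reason = "length" if previous_num_tokens == max_tokens else "stop"
    # A tool call takes precedence over the token-count heuristic.
    if tool_called:
        reason = "tool_calls"
    return reason
```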

View File

@@ -141,6 +141,7 @@ class OpenAIServingCompletion:
api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}")
prompt_batched_token_ids = []
prompt_tokens_list = []
max_tokens_list = []
try:
if self.max_waiting_time < 0:
await self.engine_client.semaphore.acquire()
@@ -167,6 +168,7 @@ class OpenAIServingCompletion:
prompt_token_ids = prompt_token_ids.tolist()
prompt_tokens_list.append(current_req_dict.get("prompt_tokens"))
prompt_batched_token_ids.append(prompt_token_ids)
max_tokens_list.append(current_req_dict.get("max_tokens"))
del current_req_dict
except ParameterError as e:
api_server_logger.error(f"OpenAIServingCompletion format error: {e}, {e.message}")
@@ -191,6 +193,7 @@ class OpenAIServingCompletion:
model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids,
prompt_tokens_list=prompt_tokens_list,
max_tokens_list=max_tokens_list,
)
else:
try:
@@ -202,6 +205,7 @@ class OpenAIServingCompletion:
model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids,
prompt_tokens_list=prompt_tokens_list,
max_tokens_list=max_tokens_list,
)
except Exception as e:
error_msg = (
@@ -224,6 +228,7 @@ class OpenAIServingCompletion:
model_name: str,
prompt_batched_token_ids: list(),
prompt_tokens_list: list(),
max_tokens_list: list(),
):
"""
Process the full completion request with multiple choices.
@@ -312,6 +317,7 @@ class OpenAIServingCompletion:
prompt_batched_token_ids=prompt_batched_token_ids,
completion_batched_token_ids=completion_batched_token_ids,
prompt_tokens_list=prompt_tokens_list,
max_tokens_list=max_tokens_list,
)
api_server_logger.info(f"Completion response: {res.model_dump_json()}")
return res
@@ -365,6 +371,7 @@ class OpenAIServingCompletion:
model_name: str,
prompt_batched_token_ids: list(),
prompt_tokens_list: list(),
max_tokens_list: list(),
):
"""
Process the stream completion request.
@@ -501,7 +508,10 @@ class OpenAIServingCompletion:
if res["finished"]:
choices[-1].finish_reason = self.calc_finish_reason(
- request.max_tokens, output_tokens[idx], output, tool_called[idx]
+ max_tokens_list[idx // (1 if request.n is None else request.n)],
+ output_tokens[idx],
+ output,
+ tool_called[idx],
)
send_idx = output.get("send_idx")
@@ -571,6 +581,7 @@ class OpenAIServingCompletion:
prompt_batched_token_ids: list(),
completion_batched_token_ids: list(),
prompt_tokens_list: list(),
max_tokens_list: list(),
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@@ -607,7 +618,12 @@ class OpenAIServingCompletion:
else:
token_ids = output["token_ids"]
output_text = output["text"]
- finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output, False)
+ finish_reason = self.calc_finish_reason(
+ max_tokens_list[idx // (1 if request.n is None else request.n)],
+ final_res["output_token_ids"],
+ output,
+ False,
+ )
choice_data = CompletionResponseChoice(
token_ids=token_ids,
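
The completion path keeps one preprocessed limit per prompt, so with n > 1 several choices share a single max_tokens_list entry; the idx // (1 if request.n is None else request.n) expression above maps a flat choice index back to its originating prompt. A small standalone illustration of that mapping (hypothetical helper, not part of the commit):

```python
from typing import Optional

# Choices are laid out prompt-major (n choices per prompt), so integer
# division by n recovers the prompt index that owns a given choice.
def prompt_index(choice_idx: int, n: Optional[int]) -> int:
    return choice_idx // (1 if n is None else n)

# e.g. 2 prompts with n=3: choices 0-2 use max_tokens_list[0], choices 3-5 use max_tokens_list[1]
assert [prompt_index(i, 3) for i in range(6)] == [0, 0, 0, 1, 1, 1]
```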

View File

@@ -58,6 +58,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
prompt_batched_token_ids=[[1, 2]],
completion_batched_token_ids=[[3, 4, 5]],
prompt_tokens_list=["test prompt"],
max_tokens_list=[100],
)
self.assertEqual(response.choices[0].text, "test prompt generated text")
@@ -91,6 +92,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
prompt_batched_token_ids=[[1, 2]],
completion_batched_token_ids=[[3, 4, 5]],
prompt_tokens_list=["test prompt"],
max_tokens_list=[100],
)
self.assertEqual(response.choices[0].text, "decoded_[1, 2, 3] generated text")
@@ -124,6 +126,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
prompt_batched_token_ids=[[1], [2]],
completion_batched_token_ids=[[1, 2], [3, 4]],
prompt_tokens_list=["prompt1", "prompt2"],
max_tokens_list=[100, 100],
)
self.assertEqual(len(response.choices), 2)
@@ -160,6 +163,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
prompt_batched_token_ids=[[1], [2]],
completion_batched_token_ids=[[1, 2], [3, 4]],
prompt_tokens_list=["prompt1", "prompt2"],
max_tokens_list=[100, 100],
)
self.assertEqual(len(response.choices), 2)

View File

@@ -0,0 +1,648 @@
import json
from typing import Any, Dict, List
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import numpy as np
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
DeltaMessage,
UsageInfo,
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
from fastdeploy.utils import data_processor_logger
class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
async def asyncSetUp(self):
with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
self.multi_modal_processor = Ernie4_5_VLProcessor("model_path")
self.multi_modal_processor.tokenizer = MagicMock()
self.multi_modal_processor.tokenizer.eos_token_id = 102
self.multi_modal_processor.tokenizer.pad_token_id = 0
self.multi_modal_processor.eos_token_ids = [102]
self.multi_modal_processor.eos_token_id_len = 1
self.multi_modal_processor.generation_config = MagicMock()
self.multi_modal_processor.decode_status = {}
self.multi_modal_processor.tool_parser_dict = {}
self.multi_modal_processor.ernie4_5_processor = MagicMock()
self.multi_modal_processor.ernie4_5_processor.request2ids.return_value = {
"input_ids": np.array([101, 9012, 3456, 102])
}
self.multi_modal_processor.ernie4_5_processor.text2ids.return_value = {
"input_ids": np.array([101, 1234, 5678, 102])
}
self.multi_modal_processor._apply_default_parameters = lambda x: x
self.multi_modal_processor.update_stop_seq = Mock(return_value=([], []))
self.multi_modal_processor.update_bad_words = Mock(return_value=[])
self.multi_modal_processor._check_mm_limits = Mock()
self.multi_modal_processor.append_completion_tokens = Mock()
self.multi_modal_processor.pack_outputs = lambda x: x
self.engine_client = Mock()
self.engine_client.connection_initialized = False
self.engine_client.connection_manager = AsyncMock()
self.engine_client.semaphore = Mock()
self.engine_client.semaphore.acquire = AsyncMock()
self.engine_client.semaphore.release = Mock()
self.engine_client.is_master = True
self.engine_client.check_model_weight_status = Mock(return_value=False)
self.engine_client.enable_mm = True
self.engine_client.enable_prefix_caching = False
self.engine_client.max_model_len = 20
self.engine_client.data_processor = self.multi_modal_processor
async def mock_add_data(current_req_dict):
if current_req_dict.get("max_tokens") is None:
current_req_dict["max_tokens"] = self.engine_client.max_model_len - 1
current_req_dict["max_tokens"] = min(
self.engine_client.max_model_len - 4, max(0, current_req_dict.get("max_tokens"))
)
self.engine_client.add_requests = AsyncMock(side_effect=mock_add_data)
self.chat_serving = OpenAIServingChat(
engine_client=self.engine_client,
models=None,
pid=123,
ips=None,
max_waiting_time=30,
chat_template="default",
enable_mm_output=True,
tokenizer_base_url=None,
)
self.completion_serving = OpenAIServingCompletion(
engine_client=self.engine_client, models=None, pid=123, ips=None, max_waiting_time=30
)
def _generate_inference_response(
self, request_id: str, output_token_num: int, tool_call: Any = None
) -> List[Dict]:
outputs = {
"text": "这是一张风景图"[:output_token_num],
"token_ids": list(range(output_token_num)),
"reasoning_content": "推理过程",
"num_image_tokens": 0,
"num_cached_tokens": 0,
"top_logprobs": None,
"draft_top_logprobs": None,
"tool_call": None,
}
if tool_call:
outputs["tool_call"] = [
{"index": 0, "type": "function", "function": {"name": tool_call["name"], "arguments": json.dumps({})}}
]
return [
{
"request_id": request_id,
"outputs": outputs,
"metrics": {"request_start_time": 0.1},
"finished": True,
"error_msg": None,
"output_token_ids": output_token_num,
}
]
def _generate_stream_inference_response(
self, request_id: str, total_token_num: int, tool_call: Any = None
) -> List[Dict]:
stream_responses = []
for i in range(total_token_num):
metrics = {}
if i == 0:
metrics["first_token_time"] = 0.1
metrics["inference_start_time"] = 0.1
else:
metrics["arrival_time"] = 0.1 * (i + 1)
metrics["first_token_time"] = None
if i == total_token_num - 1:
metrics["request_start_time"] = 0.1
outputs = {
"text": chr(ord("a") + i),
"token_ids": [i + 1],
"top_logprobs": None,
"draft_top_logprobs": None,
"reasoning_token_num": 0,
}
if tool_call and isinstance(tool_call, dict) and i == total_token_num - 2:
delta_msg = DeltaMessage(
content="",
reasoning_content="",
tool_calls=[
{
"index": 0,
"type": "function",
"function": {"name": tool_call["name"], "arguments": json.dumps({})},
}
],
prompt_token_ids=None,
completion_token_ids=None,
)
outputs["delta_message"] = delta_msg
frame = [
{
"request_id": f"{request_id}_0",
"error_code": 200,
"outputs": outputs,
"metrics": metrics,
"finished": (i == total_token_num - 1),
"error_msg": None,
}
]
stream_responses.append(frame)
return stream_responses
@patch.object(data_processor_logger, "info")
@patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor")
@patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger")
async def test_chat_full_max_tokens(self, mock_data_logger, mock_processor_class, mock_api_logger):
test_cases = [
{
"name": "user sets max_tokens=5, generated tokens=5 → length",
"request": ChatCompletionRequest(
model="ernie4.5-vl",
messages=[{"role": "user", "content": "描述这张图片"}],
stream=False,
max_tokens=5,
return_token_ids=True,
),
"output_token_num": 5,
"tool_call": [],
"expected_finish_reason": "length",
},
{
"name": "no user max_tokens, generated tokens=10 → stop",
"request": ChatCompletionRequest(
model="ernie4.5-vl",
messages=[{"role": "user", "content": "描述这张图片"}],
stream=False,
return_token_ids=True,
),
"output_token_num": 10,
"tool_call": [],
"expected_finish_reason": "stop",
},
{
"name": "no user max_tokens, generated tokens=16 → length",
"request": ChatCompletionRequest(
model="ernie4.5-vl",
messages=[{"role": "user", "content": "描述这张图片"}],
stream=False,
return_token_ids=True,
),
"output_token_num": 16,
"tool_call": [],
"expected_finish_reason": "length",
},
{
"name": "user sets max_tokens, generated tokens=10 → stop",
"request": ChatCompletionRequest(
model="ernie4.5-vl",
messages=[{"role": "user", "content": "描述这张图片"}],
stream=False,
max_tokens=50,
return_token_ids=True,
),
"output_token_num": 10,
"tool_call": [],
"expected_finish_reason": "stop",
},
{
"name": "generated tokens < max_tokens with tool_call → tool_calls",
"request": ChatCompletionRequest(
model="ernie4.5-vl",
messages=[{"role": "user", "content": "描述这张图片"}],
stream=False,
max_tokens=10,
return_token_ids=True,
),
"output_token_num": 8,
"tool_call": {"name": "test_tool"},
"expected_finish_reason": "tool_calls",
},
]
mock_response_queue = AsyncMock()
mock_dealer = Mock()
self.engine_client.connection_manager.get_connection = AsyncMock(
return_value=(mock_dealer, mock_response_queue)
)
mock_processor_instance = Mock()
mock_processor_instance.enable_multimodal_content.return_value = True
async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
yield response
mock_processor_instance.process_response_chat = mock_process_response_chat_async
mock_processor_class.return_value = mock_processor_instance
for case in test_cases:
with self.subTest(case=case["name"]):
request_dict = {
"messages": case["request"].messages,
"chat_template": "default",
"request_id": "test_chat_0",
"max_tokens": case["request"].max_tokens,
}
await self.engine_client.add_requests(request_dict)
processed_req = self.multi_modal_processor.process_request_dict(
request_dict, self.engine_client.max_model_len
)
mock_response_queue.get.side_effect = self._generate_inference_response(
request_id="test_chat_0", output_token_num=case["output_token_num"], tool_call=case["tool_call"]
)
result = await self.chat_serving.chat_completion_full_generator(
request=case["request"],
request_id="test_chat",
model_name="ernie4.5-vl",
prompt_token_ids=processed_req["prompt_token_ids"],
prompt_tokens="描述这张图片",
max_tokens=processed_req["max_tokens"],
)
self.assertEqual(
result.choices[0].finish_reason, case["expected_finish_reason"], f"case {case['name']} failed"
)
@patch.object(data_processor_logger, "info")
@patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger")
async def test_completion_full_max_tokens(self, mock_api_logger, mock_data_logger):
test_cases = [
{
"name": "user sets max_tokens=6, generated tokens=6 → length",
"request": CompletionRequest(
request_id="test_completion",
model="ernie4.5-vl",
prompt="描述这张图片:<image>xxx</image>",
stream=False,
max_tokens=6,
return_token_ids=True,
),
"output_token_num": 6,
"expected_finish_reason": "length",
},
{
"name": "no user max_tokens, generated tokens=12 → stop",
"request": CompletionRequest(
request_id="test_completion",
model="ernie4.5-vl",
prompt="描述这张图片:<image>xxx</image>",
stream=False,
return_token_ids=True,
),
"output_token_num": 12,
"expected_finish_reason": "stop",
},
{
"name": "user sets max_tokens=20, clamped to 16, generated tokens=16 → length",
"request": CompletionRequest(
request_id="test_completion",
model="ernie4.5-vl",
prompt="描述这张图片:<image>xxx</image>",
stream=False,
max_tokens=20,
return_token_ids=True,
),
"output_token_num": 16,
"expected_finish_reason": "length",
},
]
mock_dealer = Mock()
self.engine_client.connection_manager.get_connection = AsyncMock(return_value=(mock_dealer, AsyncMock()))
for case in test_cases:
with self.subTest(case=case["name"]):
request_dict = {
"prompt": case["request"].prompt,
"request_id": "test_completion",
"multimodal_data": {"image": ["xxx"]},
"max_tokens": case["request"].max_tokens,
}
await self.engine_client.add_requests(request_dict)
processed_req = self.multi_modal_processor.process_request_dict(
request_dict, self.engine_client.max_model_len
)
self.engine_client.data_processor.process_response_dict = (
lambda data, stream, include_stop_str_in_output: data
)
mock_response_queue = AsyncMock()
mock_response_queue.get.side_effect = lambda: [
{
"request_id": "test_completion_0",
"error_code": 200,
"outputs": {
"text": "这是一张风景图"[: case["output_token_num"]],
"token_ids": list(range(case["output_token_num"])),
"top_logprobs": None,
"draft_top_logprobs": None,
},
"metrics": {"request_start_time": 0.1},
"finished": True,
"error_msg": None,
"output_token_ids": case["output_token_num"],
}
]
self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue)
result = await self.completion_serving.completion_full_generator(
request=case["request"],
num_choices=1,
request_id="test_completion",
created_time=1699999999,
model_name="ernie4.5-vl",
prompt_batched_token_ids=[processed_req["prompt_token_ids"]],
prompt_tokens_list=[case["request"].prompt],
max_tokens_list=[processed_req["max_tokens"]],
)
self.assertIsInstance(result, CompletionResponse)
self.assertEqual(result.choices[0].finish_reason, case["expected_finish_reason"])
@patch.object(data_processor_logger, "info")
@patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor")
@patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger")
async def test_chat_stream_max_tokens(self, mock_api_logger, mock_processor_class, mock_data_logger):
test_cases = [
{
"name": "stream: generated tokens=8 equal to max_tokens → length",
"request": ChatCompletionRequest(
model="ernie4.5-vl",
messages=[{"role": "user", "content": "描述这张图片"}],
stream=True,
max_tokens=8,
return_token_ids=True,
),
"total_token_num": 8,
"tool_call": None,
"expected_finish_reason": "length",
},
{
"name": "stream: generated tokens=6 below max_tokens with tool_call → tool_calls",
"request": ChatCompletionRequest(
model="ernie4.5-vl",
messages=[{"role": "user", "content": "描述这张图片"}],
stream=True,
max_tokens=10,
return_token_ids=True,
),
"total_token_num": 3,
"tool_call": {"name": "test_tool"},
"expected_finish_reason": "tool_calls",
},
{
"name": "stream: generated tokens=7 below max_tokens, no tool_call → stop",
"request": ChatCompletionRequest(
model="ernie4.5-vl",
messages=[{"role": "user", "content": "描述这张图片"}],
stream=True,
max_tokens=10,
return_token_ids=True,
),
"total_token_num": 7,
"tool_call": None,
"expected_finish_reason": "stop",
},
]
mock_dealer = Mock()
self.engine_client.connection_manager.get_connection = AsyncMock(return_value=(mock_dealer, AsyncMock()))
mock_processor_instance = Mock()
mock_processor_instance.enable_multimodal_content.return_value = False
async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
if isinstance(response, list):
for res in response:
yield res
else:
yield response
mock_processor_instance.process_response_chat = mock_process_response_chat_async
mock_processor_class.return_value = mock_processor_instance
for case in test_cases:
with self.subTest(case=case["name"]):
request_dict = {
"messages": case["request"].messages,
"chat_template": "default",
"request_id": "test_chat_stream_0",
"max_tokens": case["request"].max_tokens,
}
await self.engine_client.add_requests(request_dict)
processed_req = self.multi_modal_processor.process_request_dict(
request_dict, self.engine_client.max_model_len
)
self.engine_client.data_processor.process_response_dict = (
lambda data, stream, include_stop_str_in_output: data
)
mock_response_queue = AsyncMock()
stream_responses = self._generate_stream_inference_response(
request_id="test_chat_stream_0_0",
total_token_num=case["total_token_num"],
tool_call=case["tool_call"],
)
mock_response_queue.get.side_effect = stream_responses
self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue)
generator = self.chat_serving.chat_completion_stream_generator(
request=case["request"],
request_id="test_chat_stream_0",
model_name="ernie4.5-vl",
prompt_token_ids=processed_req["prompt_token_ids"],
prompt_tokens="描述这张图片",
max_tokens=processed_req["max_tokens"],
)
final_finish_reason = None
chunks = []
async for chunk in generator:
chunks.append(chunk)
if "[DONE]" in chunk:
break
for chunk_str in chunks:
if chunk_str.startswith("data: ") and "[DONE]" not in chunk_str:
try:
json_part = chunk_str.strip().lstrip("data: ").rstrip("\n\n")
chunk_dict = json.loads(json_part)
if chunk_dict.get("choices") and len(chunk_dict["choices"]) > 0:
finish_reason = chunk_dict["choices"][0].get("finish_reason")
if finish_reason:
final_finish_reason = finish_reason
break
except (json.JSONDecodeError, KeyError, IndexError):
continue
self.assertEqual(final_finish_reason, case["expected_finish_reason"])
@patch.object(data_processor_logger, "info")
@patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger")
async def test_completion_stream_max_tokens(self, mock_api_logger, mock_data_logger):
test_cases = [
{
"name": "stream: generated tokens=7 equal to max_tokens → length",
"request": CompletionRequest(
model="ernie4.5-vl",
prompt=["描述这张图片:<image>xxx</image>"],
stream=True,
max_tokens=7,
return_token_ids=True,
),
"total_token_num": 7,
"expected_finish_reason": "length",
},
{
"name": "stream: generated tokens=9 below max_tokens → stop",
"request": CompletionRequest(
model="ernie4.5-vl",
prompt=["描述这张图片:<image>xxx</image>"],
stream=True,
max_tokens=15,
return_token_ids=True,
),
"total_token_num": 9,
"expected_finish_reason": "stop",
},
]
mock_dealer = Mock()
self.engine_client.connection_manager.get_connection = AsyncMock(return_value=(mock_dealer, AsyncMock()))
for case in test_cases:
with self.subTest(case=case["name"]):
request_dict = {
"prompt": case["request"].prompt,
"multimodal_data": {"image": ["xxx"]},
"request_id": "test_completion_stream_0",
"max_tokens": case["request"].max_tokens,
}
await self.engine_client.add_requests(request_dict)
processed_req = self.multi_modal_processor.process_request_dict(
request_dict, self.engine_client.max_model_len
)
self.engine_client.data_processor.process_response_dict = (
lambda data, stream, include_stop_str_in_output: data
)
mock_response_queue = AsyncMock()
stream_responses = self._generate_stream_inference_response(
request_id="test_completion_stream_0", total_token_num=case["total_token_num"]
)
mock_response_queue.get.side_effect = stream_responses
self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue)
generator = self.completion_serving.completion_stream_generator(
request=case["request"],
num_choices=1,
created_time=0,
request_id="test_completion_stream",
model_name="ernie4.5-vl",
prompt_batched_token_ids=[processed_req["prompt_token_ids"]],
prompt_tokens_list=case["request"].prompt,
max_tokens_list=[processed_req["max_tokens"]],
)
final_finish_reason = None
chunks = []
async for chunk in generator:
chunks.append(chunk)
if "[DONE]" in chunk:
break
for chunk_str in chunks:
if chunk_str.startswith("data: ") and "[DONE]" not in chunk_str:
try:
json_part = chunk_str.strip().lstrip("data: ")
chunk_dict = json.loads(json_part)
if chunk_dict["choices"][0].get("finish_reason"):
final_finish_reason = chunk_dict["choices"][0]["finish_reason"]
break
except (json.JSONDecodeError, KeyError, IndexError):
continue
self.assertEqual(final_finish_reason, case["expected_finish_reason"], f"case {case['name']} failed")
@patch.object(data_processor_logger, "info")
@patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger")
async def test_completion_create_max_tokens_list_basic(self, mock_api_logger, mock_data_logger):
test_cases = [
{
"name": "single prompt → max_tokens_list of length 1",
"request": CompletionRequest(
request_id="test_single_prompt",
model="ernie4.5-vl",
prompt="请介绍人工智能的应用",
stream=False,
max_tokens=10,
),
"mock_max_tokens": 8,
"expected_max_tokens_list_len": 1,
"expected_max_tokens_list": [8],
},
{
"name": "multiple prompts → max_tokens_list of length 2",
"request": CompletionRequest(
request_id="test_multi_prompt",
model="ernie4.5-vl",
prompt=["请介绍Python语言", "请说明机器学习的步骤"],
stream=False,
max_tokens=15,
),
"mock_max_tokens": [12, 10],
"expected_max_tokens_list_len": 2,
"expected_max_tokens_list": [12, 10],
},
]
async def mock_format_and_add_data(current_req_dict):
req_idx = int(current_req_dict["request_id"].split("_")[-1])
if isinstance(case["mock_max_tokens"], list):
current_req_dict["max_tokens"] = case["mock_max_tokens"][req_idx]
else:
current_req_dict["max_tokens"] = case["mock_max_tokens"]
return [101, 102, 103, 104]
self.engine_client.format_and_add_data = AsyncMock(side_effect=mock_format_and_add_data)
async def intercept_generator(**kwargs):
actual_max_tokens_list = kwargs["max_tokens_list"]
self.assertEqual(
len(actual_max_tokens_list),
case["expected_max_tokens_list_len"],
f"list length mismatch: actual {len(actual_max_tokens_list)}, expected {case['expected_max_tokens_list_len']}",
)
self.assertEqual(
actual_max_tokens_list,
case["expected_max_tokens_list"],
f"list contents mismatch: actual {actual_max_tokens_list}, expected {case['expected_max_tokens_list']}",
)
return CompletionResponse(
id=kwargs["request_id"],
object="text_completion",
created=kwargs["created_time"],
model=kwargs["model_name"],
choices=[],
usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0),
)
self.completion_serving.completion_full_generator = AsyncMock(side_effect=intercept_generator)
for case in test_cases:
with self.subTest(case=case["name"]):
result = await self.completion_serving.create_completion(request=case["request"])
self.assertIsInstance(result, CompletionResponse)

View File

@@ -169,6 +169,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
model_name="test-model",
prompt_token_ids=[1, 2, 3],
prompt_tokens="Hello",
max_tokens=10,
)
chunks = []
@@ -251,6 +252,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
created_time=11,
prompt_batched_token_ids=[[1, 2, 3]],
prompt_tokens_list=["Hello"],
max_tokens_list=[100],
)
chunks = []
@@ -352,6 +354,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
model_name=model_name,
prompt_batched_token_ids=prompt_batched_token_ids,
prompt_tokens_list=prompt_tokens_list,
max_tokens_list=[100, 100],
)
self.assertEqual(actual_response, expected_completion_response)
@@ -449,6 +452,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
num_input_image_tokens = [0, 0]
num_input_video_tokens = [0, 0]
num_image_tokens = [0, 0]
max_tokens_list = [10, 1]
for idx, case in enumerate(test_cases):
actual_choice = await self.chat_serving._create_chat_completion_choice(
@@ -464,6 +468,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
num_image_tokens=num_image_tokens,
logprob_contents=logprob_contents,
response_processor=mock_response_processor,
max_tokens=max_tokens_list[idx],
)
expected = case["expected"]
@@ -554,6 +559,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
model_name="test-model",
prompt_token_ids=[10, 20, 30],
prompt_tokens="Hello",
max_tokens=10,
)
chunks = []
@@ -661,6 +667,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
model_name="test-model",
prompt_batched_token_ids=[[10, 20, 30]],
prompt_tokens_list=["Hello"],
max_tokens_list=[100],
)
chunks = []

View File

@@ -159,6 +159,7 @@ class TestOpenAIServingCompletion(unittest.TestCase):
prompt_batched_token_ids=prompt_batched_token_ids,
completion_batched_token_ids=completion_batched_token_ids,
prompt_tokens_list=["1", "1"],
max_tokens_list=[10, 10],
)
assert completion_response.id == request_id

View File

@@ -78,7 +78,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
)
async def mock_chat_completion_full_generator(
- request, request_id, model_name, prompt_token_ids, prompt_tokens
+ request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens_list
):
return prompt_token_ids
@@ -106,7 +106,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
)
async def mock_chat_completion_full_generator(
- request, request_id, model_name, prompt_token_ids, prompt_tokens
+ request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens_list
):
return prompt_token_ids