Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[BugFix] unify max_tokens (#4968)
* unify max tokens
* modify and add unit test
* modify and add unit test
* modify and add unit tests

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
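The change below resolves the effective max_tokens once during request preprocessing (the input processor writes it into the request dict) and then passes that single value into the chat and completion generators, instead of re-deriving it from request.max_tokens / request.max_completion_tokens at every finish_reason check. A minimal sketch of the flow the new tests exercise; the helper names and the exact clamping rule are illustrative assumptions, not the FastDeploy implementation:

def resolve_max_tokens(req_dict: dict, max_model_len: int) -> int:
    # Assumed defaulting/clamping; in FastDeploy this happens inside the
    # input processor while the request is being formatted.
    requested = req_dict.get("max_tokens")
    if requested is None:
        requested = max_model_len - 1
    return max(1, min(requested, max_model_len - 1))


def unified_finish_reason(max_tokens: int, generated_tokens: int, tool_called: bool) -> str:
    # The rule the new tests assert: exhausting the budget means "length";
    # otherwise "tool_calls" if a tool call was emitted, else "stop".
    if generated_tokens >= max_tokens:
        return "length"
    return "tool_calls" if tool_called else "stop"

For example, a request whose preprocessed max_tokens is 5 and that generates 5 tokens finishes with "length", while the same request generating only 4 tokens finishes with "stop".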
@@ -126,6 +126,7 @@ class OpenAIServingChat:
request_id = f"chatcmpl-{uuid.uuid4()}"
api_server_logger.info(f"create chat completion request: {request_id}")
prompt_tokens = None
+ max_tokens = None
try:
    current_req_dict = request.to_dict_for_infer(f"{request_id}_0")
    if "chat_template" not in current_req_dict:
@@ -134,6 +135,7 @@
# preprocess the req_dict
prompt_token_ids = await self.engine_client.format_and_add_data(current_req_dict)
prompt_tokens = current_req_dict.get("prompt_tokens")
+ max_tokens = current_req_dict.get("max_tokens")
if isinstance(prompt_token_ids, np.ndarray):
    prompt_token_ids = prompt_token_ids.tolist()
except ParameterError as e:
@@ -151,12 +153,12 @@

if request.stream:
    return self.chat_completion_stream_generator(
-       request, request_id, request.model, prompt_token_ids, prompt_tokens
+       request, request_id, request.model, prompt_token_ids, prompt_tokens, max_tokens
    )
else:
    try:
        return await self.chat_completion_full_generator(
-           request, request_id, request.model, prompt_token_ids, prompt_tokens
+           request, request_id, request.model, prompt_token_ids, prompt_tokens, max_tokens
        )
    except Exception as e:
        error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
@@ -184,6 +186,7 @@ class OpenAIServingChat:
model_name: str,
prompt_token_ids: list(),
prompt_tokens: str,
+ max_tokens: int,
):
"""
Streaming chat completion generator.
@@ -382,9 +385,7 @@
work_process_metrics.e2e_request_latency.observe(
    time.time() - res["metrics"]["request_start_time"]
)
- has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
- max_tokens = request.max_completion_tokens or request.max_tokens
- if has_no_token_limit or previous_num_tokens[idx] != max_tokens:
+ if previous_num_tokens[idx] != max_tokens:
    choice.finish_reason = "stop"
    if tool_called[idx]:
        choice.finish_reason = "tool_calls"
@@ -461,6 +462,7 @@
model_name: str,
prompt_token_ids: list(),
prompt_tokens: str,
+ max_tokens: int,
):
"""
Full chat completion generator.
@@ -570,6 +572,7 @@
num_image_tokens=num_image_tokens,
logprob_contents=logprob_contents,
response_processor=response_processor,
+ max_tokens=max_tokens,
)
choices.append(choice)
finally:
@@ -620,6 +623,7 @@
num_image_tokens: list,
logprob_contents: list,
response_processor: ChatResponseProcessor,
+ max_tokens: int,
) -> ChatCompletionResponseChoice:
idx = int(data["request_id"].split("_")[-1])
output = data["outputs"]
@@ -646,15 +650,13 @@
if logprob_contents[idx]:
    logprobs_full_res = LogProbs(content=logprob_contents[idx])

- has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
- max_tokens = request.max_completion_tokens or request.max_tokens
num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
num_input_video_tokens[idx] = data.get("num_input_video_tokens", 0)
num_image_tokens[idx] = output.get("num_image_tokens", 0)

finish_reason = "stop"
- if has_no_token_limit or previous_num_tokens != max_tokens:
+ if previous_num_tokens != max_tokens:
    finish_reason = "stop"
    if output.get("tool_call"):
        finish_reason = "tool_calls"

@@ -141,6 +141,7 @@ class OpenAIServingCompletion:
api_server_logger.info(f"Start preprocessing request: req_id={request_id}), num_choices={num_choices}")
prompt_batched_token_ids = []
prompt_tokens_list = []
+ max_tokens_list = []
try:
    if self.max_waiting_time < 0:
        await self.engine_client.semaphore.acquire()
@@ -167,6 +168,7 @@
    prompt_token_ids = prompt_token_ids.tolist()
prompt_tokens_list.append(current_req_dict.get("prompt_tokens"))
prompt_batched_token_ids.append(prompt_token_ids)
+ max_tokens_list.append(current_req_dict.get("max_tokens"))
del current_req_dict
except ParameterError as e:
    api_server_logger.error(f"OpenAIServingCompletion format error: {e}, {e.message}")
@@ -191,6 +193,7 @@
model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids,
prompt_tokens_list=prompt_tokens_list,
+ max_tokens_list=max_tokens_list,
)
else:
    try:
@@ -202,6 +205,7 @@
model_name=request.model,
prompt_batched_token_ids=prompt_batched_token_ids,
prompt_tokens_list=prompt_tokens_list,
+ max_tokens_list=max_tokens_list,
)
except Exception as e:
    error_msg = (
@@ -224,6 +228,7 @@
model_name: str,
prompt_batched_token_ids: list(),
prompt_tokens_list: list(),
+ max_tokens_list: list(),
):
"""
Process the full completion request with multiple choices.
@@ -312,6 +317,7 @@
prompt_batched_token_ids=prompt_batched_token_ids,
completion_batched_token_ids=completion_batched_token_ids,
prompt_tokens_list=prompt_tokens_list,
+ max_tokens_list=max_tokens_list,
)
api_server_logger.info(f"Completion response: {res.model_dump_json()}")
return res
@@ -365,6 +371,7 @@
model_name: str,
prompt_batched_token_ids: list(),
prompt_tokens_list: list(),
+ max_tokens_list: list(),
):
"""
Process the stream completion request.
@@ -501,7 +508,10 @@

if res["finished"]:
    choices[-1].finish_reason = self.calc_finish_reason(
-       request.max_tokens, output_tokens[idx], output, tool_called[idx]
+       max_tokens_list[idx // (1 if request.n is None else request.n)],
+       output_tokens[idx],
+       output,
+       tool_called[idx],
    )

send_idx = output.get("send_idx")
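The index arithmetic above maps a flat choice index back to the prompt it came from: when request.n choices are generated per prompt, choice idx belongs to prompt idx // n (with n treated as 1 when unset), so each choice is checked against its own prompt's max_tokens. A small illustrative check, with values invented for the example:

# Illustrative only: flat choice index -> per-prompt max_tokens.
max_tokens_list = [8, 12]   # one entry per prompt
n = 2                       # choices generated per prompt
for idx in range(len(max_tokens_list) * n):
    prompt_idx = idx // n   # choices 0-1 -> prompt 0, choices 2-3 -> prompt 1
    assert max_tokens_list[prompt_idx] == (8 if idx < 2 else 12)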
@@ -571,6 +581,7 @@
prompt_batched_token_ids: list(),
completion_batched_token_ids: list(),
prompt_tokens_list: list(),
+ max_tokens_list: list(),
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@@ -607,7 +618,12 @@
else:
    token_ids = output["token_ids"]
    output_text = output["text"]
- finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output, False)
+ finish_reason = self.calc_finish_reason(
+     max_tokens_list[idx // (1 if request.n is None else request.n)],
+     final_res["output_token_ids"],
+     output,
+     False,
+ )

choice_data = CompletionResponseChoice(
    token_ids=token_ids,

@@ -58,6 +58,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
    prompt_batched_token_ids=[[1, 2]],
    completion_batched_token_ids=[[3, 4, 5]],
    prompt_tokens_list=["test prompt"],
+   max_tokens_list=[100],
)

self.assertEqual(response.choices[0].text, "test prompt generated text")
@@ -91,6 +92,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
    prompt_batched_token_ids=[[1, 2]],
    completion_batched_token_ids=[[3, 4, 5]],
    prompt_tokens_list=["test prompt"],
+   max_tokens_list=[100],
)
self.assertEqual(response.choices[0].text, "decoded_[1, 2, 3] generated text")

@@ -124,6 +126,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
    prompt_batched_token_ids=[[1], [2]],
    completion_batched_token_ids=[[1, 2], [3, 4]],
    prompt_tokens_list=["prompt1", "prompt2"],
+   max_tokens_list=[100, 100],
)

self.assertEqual(len(response.choices), 2)
@@ -160,6 +163,7 @@ class TestCompletionEcho(unittest.IsolatedAsyncioTestCase):
    prompt_batched_token_ids=[[1], [2]],
    completion_batched_token_ids=[[1, 2], [3, 4]],
    prompt_tokens_list=["prompt1", "prompt2"],
+   max_tokens_list=[100, 100],
)

self.assertEqual(len(response.choices), 2)

tests/entrypoints/openai/test_finish_reason.py (new file, 648 lines)
@@ -0,0 +1,648 @@
import json
from typing import Any, Dict, List
from unittest import IsolatedAsyncioTestCase
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import numpy as np

from fastdeploy.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    CompletionRequest,
    CompletionResponse,
    DeltaMessage,
    UsageInfo,
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
from fastdeploy.input.ernie4_5_vl_processor import Ernie4_5_VLProcessor
from fastdeploy.utils import data_processor_logger


class TestMultiModalProcessorMaxTokens(IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        with patch.object(Ernie4_5_VLProcessor, "__init__", return_value=None):
            self.multi_modal_processor = Ernie4_5_VLProcessor("model_path")
        self.multi_modal_processor.tokenizer = MagicMock()
        self.multi_modal_processor.tokenizer.eos_token_id = 102
        self.multi_modal_processor.tokenizer.pad_token_id = 0
        self.multi_modal_processor.eos_token_ids = [102]
        self.multi_modal_processor.eos_token_id_len = 1
        self.multi_modal_processor.generation_config = MagicMock()
        self.multi_modal_processor.decode_status = {}
        self.multi_modal_processor.tool_parser_dict = {}
        self.multi_modal_processor.ernie4_5_processor = MagicMock()
        self.multi_modal_processor.ernie4_5_processor.request2ids.return_value = {
            "input_ids": np.array([101, 9012, 3456, 102])
        }
        self.multi_modal_processor.ernie4_5_processor.text2ids.return_value = {
            "input_ids": np.array([101, 1234, 5678, 102])
        }
        self.multi_modal_processor._apply_default_parameters = lambda x: x
        self.multi_modal_processor.update_stop_seq = Mock(return_value=([], []))
        self.multi_modal_processor.update_bad_words = Mock(return_value=[])
        self.multi_modal_processor._check_mm_limits = Mock()
        self.multi_modal_processor.append_completion_tokens = Mock()
        self.multi_modal_processor.pack_outputs = lambda x: x

        self.engine_client = Mock()
        self.engine_client.connection_initialized = False
        self.engine_client.connection_manager = AsyncMock()
        self.engine_client.semaphore = Mock()
        self.engine_client.semaphore.acquire = AsyncMock()
        self.engine_client.semaphore.release = Mock()
        self.engine_client.is_master = True
        self.engine_client.check_model_weight_status = Mock(return_value=False)
        self.engine_client.enable_mm = True
        self.engine_client.enable_prefix_caching = False
        self.engine_client.max_model_len = 20
        self.engine_client.data_processor = self.multi_modal_processor

        async def mock_add_data(current_req_dict):
            if current_req_dict.get("max_tokens") is None:
                current_req_dict["max_tokens"] = self.engine_client.max_model_len - 1
            current_req_dict["max_tokens"] = min(
                self.engine_client.max_model_len - 4, max(0, current_req_dict.get("max_tokens"))
            )

        self.engine_client.add_requests = AsyncMock(side_effect=mock_add_data)

        self.chat_serving = OpenAIServingChat(
            engine_client=self.engine_client,
            models=None,
            pid=123,
            ips=None,
            max_waiting_time=30,
            chat_template="default",
            enable_mm_output=True,
            tokenizer_base_url=None,
        )
        self.completion_serving = OpenAIServingCompletion(
            engine_client=self.engine_client, models=None, pid=123, ips=None, max_waiting_time=30
        )

    def _generate_inference_response(
        self, request_id: str, output_token_num: int, tool_call: Any = None
    ) -> List[Dict]:
        outputs = {
            "text": "这是一张风景图"[:output_token_num],
            "token_ids": list(range(output_token_num)),
            "reasoning_content": "推理过程",
            "num_image_tokens": 0,
            "num_cached_tokens": 0,
            "top_logprobs": None,
            "draft_top_logprobs": None,
            "tool_call": None,
        }

        if tool_call:
            outputs["tool_call"] = [
                {"index": 0, "type": "function", "function": {"name": tool_call["name"], "arguments": json.dumps({})}}
            ]

        return [
            {
                "request_id": request_id,
                "outputs": outputs,
                "metrics": {"request_start_time": 0.1},
                "finished": True,
                "error_msg": None,
                "output_token_ids": output_token_num,
            }
        ]

    def _generate_stream_inference_response(
        self, request_id: str, total_token_num: int, tool_call: Any = None
    ) -> List[Dict]:
        stream_responses = []
        for i in range(total_token_num):
            metrics = {}
            if i == 0:
                metrics["first_token_time"] = 0.1
                metrics["inference_start_time"] = 0.1
            else:
                metrics["arrival_time"] = 0.1 * (i + 1)
                metrics["first_token_time"] = None

            if i == total_token_num - 1:
                metrics["request_start_time"] = 0.1

            outputs = {
                "text": chr(ord("a") + i),
                "token_ids": [i + 1],
                "top_logprobs": None,
                "draft_top_logprobs": None,
                "reasoning_token_num": 0,
            }

            if tool_call and isinstance(tool_call, dict) and i == total_token_num - 2:
                delta_msg = DeltaMessage(
                    content="",
                    reasoning_content="",
                    tool_calls=[
                        {
                            "index": 0,
                            "type": "function",
                            "function": {"name": tool_call["name"], "arguments": json.dumps({})},
                        }
                    ],
                    prompt_token_ids=None,
                    completion_token_ids=None,
                )
                outputs["delta_message"] = delta_msg

            frame = [
                {
                    "request_id": f"{request_id}_0",
                    "error_code": 200,
                    "outputs": outputs,
                    "metrics": metrics,
                    "finished": (i == total_token_num - 1),
                    "error_msg": None,
                }
            ]
            stream_responses.append(frame)
        return stream_responses

    @patch.object(data_processor_logger, "info")
    @patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor")
    @patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger")
    async def test_chat_full_max_tokens(self, mock_data_logger, mock_processor_class, mock_api_logger):
        test_cases = [
            {
                "name": "用户传max_tokens=5,生成数=5→length",
                "request": ChatCompletionRequest(
                    model="ernie4.5-vl",
                    messages=[{"role": "user", "content": "描述这张图片"}],
                    stream=False,
                    max_tokens=5,
                    return_token_ids=True,
                ),
                "output_token_num": 5,
                "tool_call": [],
                "expected_finish_reason": "length",
            },
            {
                "name": "用户未传max_tokens,生成数=10→stop",
                "request": ChatCompletionRequest(
                    model="ernie4.5-vl",
                    messages=[{"role": "user", "content": "描述这张图片"}],
                    stream=False,
                    return_token_ids=True,
                ),
                "output_token_num": 10,
                "tool_call": [],
                "expected_finish_reason": "stop",
            },
            {
                "name": "用户未传max_tokens,生成数=16→length",
                "request": ChatCompletionRequest(
                    model="ernie4.5-vl",
                    messages=[{"role": "user", "content": "描述这张图片"}],
                    stream=False,
                    return_token_ids=True,
                ),
                "output_token_num": 16,
                "tool_call": [],
                "expected_finish_reason": "length",
            },
            {
                "name": "用户传max_tokens,生成数=10→stop",
                "request": ChatCompletionRequest(
                    model="ernie4.5-vl",
                    messages=[{"role": "user", "content": "描述这张图片"}],
                    stream=False,
                    max_tokens=50,
                    return_token_ids=True,
                ),
                "output_token_num": 10,
                "tool_call": [],
                "expected_finish_reason": "stop",
            },
            {
                "name": "生成数<max_tokens,触发tool_call→tool_calls",
                "request": ChatCompletionRequest(
                    model="ernie4.5-vl",
                    messages=[{"role": "user", "content": "描述这张图片"}],
                    stream=False,
                    max_tokens=10,
                    return_token_ids=True,
                ),
                "output_token_num": 8,
                "tool_call": {"name": "test_tool"},
                "expected_finish_reason": "tool_calls",
            },
        ]

        mock_response_queue = AsyncMock()
        mock_dealer = Mock()
        self.engine_client.connection_manager.get_connection = AsyncMock(
            return_value=(mock_dealer, mock_response_queue)
        )

        mock_processor_instance = Mock()
        mock_processor_instance.enable_multimodal_content.return_value = True

        async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
            yield response

        mock_processor_instance.process_response_chat = mock_process_response_chat_async
        mock_processor_class.return_value = mock_processor_instance

        for case in test_cases:
            with self.subTest(case=case["name"]):
                request_dict = {
                    "messages": case["request"].messages,
                    "chat_template": "default",
                    "request_id": "test_chat_0",
                    "max_tokens": case["request"].max_tokens,
                }
                await self.engine_client.add_requests(request_dict)
                processed_req = self.multi_modal_processor.process_request_dict(
                    request_dict, self.engine_client.max_model_len
                )
                mock_response_queue.get.side_effect = self._generate_inference_response(
                    request_id="test_chat_0", output_token_num=case["output_token_num"], tool_call=case["tool_call"]
                )

                result = await self.chat_serving.chat_completion_full_generator(
                    request=case["request"],
                    request_id="test_chat",
                    model_name="ernie4.5-vl",
                    prompt_token_ids=processed_req["prompt_token_ids"],
                    prompt_tokens="描述这张图片",
                    max_tokens=processed_req["max_tokens"],
                )
                self.assertEqual(
                    result.choices[0].finish_reason, case["expected_finish_reason"], f"场景 {case['name']} 失败"
                )

    @patch.object(data_processor_logger, "info")
    @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger")
    async def test_completion_full_max_tokens(self, mock_api_logger, mock_data_logger):
        test_cases = [
            {
                "name": "用户传max_tokens=6,生成数=6→length",
                "request": CompletionRequest(
                    request_id="test_completion",
                    model="ernie4.5-vl",
                    prompt="描述这张图片:<image>xxx</image>",
                    stream=False,
                    max_tokens=6,
                    return_token_ids=True,
                ),
                "output_token_num": 6,
                "expected_finish_reason": "length",
            },
            {
                "name": "用户未传max_tokens,生成数=12→stop",
                "request": CompletionRequest(
                    request_id="test_completion",
                    model="ernie4.5-vl",
                    prompt="描述这张图片:<image>xxx</image>",
                    stream=False,
                    return_token_ids=True,
                ),
                "output_token_num": 12,
                "expected_finish_reason": "stop",
            },
            {
                "name": "用户传max_tokens=20(修正为16),生成数=16→length",
                "request": CompletionRequest(
                    request_id="test_completion",
                    model="ernie4.5-vl",
                    prompt="描述这张图片:<image>xxx</image>",
                    stream=False,
                    max_tokens=20,
                    return_token_ids=True,
                ),
                "output_token_num": 16,
                "expected_finish_reason": "length",
            },
        ]

        mock_dealer = Mock()
        self.engine_client.connection_manager.get_connection = AsyncMock(return_value=(mock_dealer, AsyncMock()))

        for case in test_cases:
            with self.subTest(case=case["name"]):
                request_dict = {
                    "prompt": case["request"].prompt,
                    "request_id": "test_completion",
                    "multimodal_data": {"image": ["xxx"]},
                    "max_tokens": case["request"].max_tokens,
                }
                await self.engine_client.add_requests(request_dict)
                processed_req = self.multi_modal_processor.process_request_dict(
                    request_dict, self.engine_client.max_model_len
                )
                self.engine_client.data_processor.process_response_dict = (
                    lambda data, stream, include_stop_str_in_output: data
                )
                mock_response_queue = AsyncMock()
                mock_response_queue.get.side_effect = lambda: [
                    {
                        "request_id": "test_completion_0",
                        "error_code": 200,
                        "outputs": {
                            "text": "这是一张风景图"[: case["output_token_num"]],
                            "token_ids": list(range(case["output_token_num"])),
                            "top_logprobs": None,
                            "draft_top_logprobs": None,
                        },
                        "metrics": {"request_start_time": 0.1},
                        "finished": True,
                        "error_msg": None,
                        "output_token_ids": case["output_token_num"],
                    }
                ]
                self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue)

                result = await self.completion_serving.completion_full_generator(
                    request=case["request"],
                    num_choices=1,
                    request_id="test_completion",
                    created_time=1699999999,
                    model_name="ernie4.5-vl",
                    prompt_batched_token_ids=[processed_req["prompt_token_ids"]],
                    prompt_tokens_list=[case["request"].prompt],
                    max_tokens_list=[processed_req["max_tokens"]],
                )

                self.assertIsInstance(result, CompletionResponse)
                self.assertEqual(result.choices[0].finish_reason, case["expected_finish_reason"])

    @patch.object(data_processor_logger, "info")
    @patch("fastdeploy.entrypoints.openai.serving_chat.ChatResponseProcessor")
    @patch("fastdeploy.entrypoints.openai.serving_chat.api_server_logger")
    async def test_chat_stream_max_tokens(self, mock_api_logger, mock_processor_class, mock_data_logger):
        test_cases = [
            {
                "name": "流式-生成数=8(等于max_tokens)→length",
                "request": ChatCompletionRequest(
                    model="ernie4.5-vl",
                    messages=[{"role": "user", "content": "描述这张图片"}],
                    stream=True,
                    max_tokens=8,
                    return_token_ids=True,
                ),
                "total_token_num": 8,
                "tool_call": None,
                "expected_finish_reason": "length",
            },
            {
                "name": "流式-生成数=6(小于max_tokens)+tool_call→tool_calls",
                "request": ChatCompletionRequest(
                    model="ernie4.5-vl",
                    messages=[{"role": "user", "content": "描述这张图片"}],
                    stream=True,
                    max_tokens=10,
                    return_token_ids=True,
                ),
                "total_token_num": 3,
                "tool_call": {"name": "test_tool"},
                "expected_finish_reason": "tool_calls",
            },
            {
                "name": "流式-生成数=7(小于max_tokens)无tool_call→stop",
                "request": ChatCompletionRequest(
                    model="ernie4.5-vl",
                    messages=[{"role": "user", "content": "描述这张图片"}],
                    stream=True,
                    max_tokens=10,
                    return_token_ids=True,
                ),
                "total_token_num": 7,
                "tool_call": None,
                "expected_finish_reason": "stop",
            },
        ]

        mock_dealer = Mock()
        self.engine_client.connection_manager.get_connection = AsyncMock(return_value=(mock_dealer, AsyncMock()))

        mock_processor_instance = Mock()
        mock_processor_instance.enable_multimodal_content.return_value = False

        async def mock_process_response_chat_async(response, stream, enable_thinking, include_stop_str_in_output):
            if isinstance(response, list):
                for res in response:
                    yield res
            else:
                yield response

        mock_processor_instance.process_response_chat = mock_process_response_chat_async
        mock_processor_class.return_value = mock_processor_instance

        for case in test_cases:
            with self.subTest(case=case["name"]):
                request_dict = {
                    "messages": case["request"].messages,
                    "chat_template": "default",
                    "request_id": "test_chat_stream_0",
                    "max_tokens": case["request"].max_tokens,
                }
                await self.engine_client.add_requests(request_dict)
                processed_req = self.multi_modal_processor.process_request_dict(
                    request_dict, self.engine_client.max_model_len
                )

                self.engine_client.data_processor.process_response_dict = (
                    lambda data, stream, include_stop_str_in_output: data
                )

                mock_response_queue = AsyncMock()
                stream_responses = self._generate_stream_inference_response(
                    request_id="test_chat_stream_0_0",
                    total_token_num=case["total_token_num"],
                    tool_call=case["tool_call"],
                )
                mock_response_queue.get.side_effect = stream_responses
                self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue)

                generator = self.chat_serving.chat_completion_stream_generator(
                    request=case["request"],
                    request_id="test_chat_stream_0",
                    model_name="ernie4.5-vl",
                    prompt_token_ids=processed_req["prompt_token_ids"],
                    prompt_tokens="描述这张图片",
                    max_tokens=processed_req["max_tokens"],
                )

                final_finish_reason = None
                chunks = []
                async for chunk in generator:
                    chunks.append(chunk)
                    if "[DONE]" in chunk:
                        break

                for chunk_str in chunks:
                    if chunk_str.startswith("data: ") and "[DONE]" not in chunk_str:
                        try:
                            json_part = chunk_str.strip().lstrip("data: ").rstrip("\n\n")
                            chunk_dict = json.loads(json_part)
                            if chunk_dict.get("choices") and len(chunk_dict["choices"]) > 0:
                                finish_reason = chunk_dict["choices"][0].get("finish_reason")
                                if finish_reason:
                                    final_finish_reason = finish_reason
                                    break
                        except (json.JSONDecodeError, KeyError, IndexError):
                            continue

                self.assertEqual(final_finish_reason, case["expected_finish_reason"])

    @patch.object(data_processor_logger, "info")
    @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger")
    async def test_completion_stream_max_tokens(self, mock_api_logger, mock_data_logger):
        test_cases = [
            {
                "name": "流式-生成数=7(等于max_tokens)→length",
                "request": CompletionRequest(
                    model="ernie4.5-vl",
                    prompt=["描述这张图片:<image>xxx</image>"],
                    stream=True,
                    max_tokens=7,
                    return_token_ids=True,
                ),
                "total_token_num": 7,
                "expected_finish_reason": "length",
            },
            {
                "name": "流式-生成数=9(小于max_tokens)→stop",
                "request": CompletionRequest(
                    model="ernie4.5-vl",
                    prompt=["描述这张图片:<image>xxx</image>"],
                    stream=True,
                    max_tokens=15,
                    return_token_ids=True,
                ),
                "total_token_num": 9,
                "expected_finish_reason": "stop",
            },
        ]

        mock_dealer = Mock()
        self.engine_client.connection_manager.get_connection = AsyncMock(return_value=(mock_dealer, AsyncMock()))

        for case in test_cases:
            with self.subTest(case=case["name"]):
                request_dict = {
                    "prompt": case["request"].prompt,
                    "multimodal_data": {"image": ["xxx"]},
                    "request_id": "test_completion_stream_0",
                    "max_tokens": case["request"].max_tokens,
                }
                await self.engine_client.add_requests(request_dict)
                processed_req = self.multi_modal_processor.process_request_dict(
                    request_dict, self.engine_client.max_model_len
                )
                self.engine_client.data_processor.process_response_dict = (
                    lambda data, stream, include_stop_str_in_output: data
                )

                mock_response_queue = AsyncMock()
                stream_responses = self._generate_stream_inference_response(
                    request_id="test_completion_stream_0", total_token_num=case["total_token_num"]
                )
                mock_response_queue.get.side_effect = stream_responses
                self.engine_client.connection_manager.get_connection.return_value = (mock_dealer, mock_response_queue)

                generator = self.completion_serving.completion_stream_generator(
                    request=case["request"],
                    num_choices=1,
                    created_time=0,
                    request_id="test_completion_stream",
                    model_name="ernie4.5-vl",
                    prompt_batched_token_ids=[processed_req["prompt_token_ids"]],
                    prompt_tokens_list=case["request"].prompt,
                    max_tokens_list=[processed_req["max_tokens"]],
                )

                final_finish_reason = None
                chunks = []
                async for chunk in generator:
                    chunks.append(chunk)
                    if "[DONE]" in chunk:
                        break

                for chunk_str in chunks:
                    if chunk_str.startswith("data: ") and "[DONE]" not in chunk_str:
                        try:
                            json_part = chunk_str.strip().lstrip("data: ")
                            chunk_dict = json.loads(json_part)
                            if chunk_dict["choices"][0].get("finish_reason"):
                                final_finish_reason = chunk_dict["choices"][0]["finish_reason"]
                                break
                        except (json.JSONDecodeError, KeyError, IndexError):
                            continue

                self.assertEqual(final_finish_reason, case["expected_finish_reason"], f"场景 {case['name']} 失败")

    @patch.object(data_processor_logger, "info")
    @patch("fastdeploy.entrypoints.openai.serving_completion.api_server_logger")
    async def test_completion_create_max_tokens_list_basic(self, mock_api_logger, mock_data_logger):
        test_cases = [
            {
                "name": "单prompt → max_tokens_list长度1",
                "request": CompletionRequest(
                    request_id="test_single_prompt",
                    model="ernie4.5-vl",
                    prompt="请介绍人工智能的应用",
                    stream=False,
                    max_tokens=10,
                ),
                "mock_max_tokens": 8,
                "expected_max_tokens_list_len": 1,
                "expected_max_tokens_list": [8],
            },
            {
                "name": "多prompt → max_tokens_list长度2",
                "request": CompletionRequest(
                    request_id="test_multi_prompt",
                    model="ernie4.5-vl",
                    prompt=["请介绍Python语言", "请说明机器学习的步骤"],
                    stream=False,
                    max_tokens=15,
                ),
                "mock_max_tokens": [12, 10],
                "expected_max_tokens_list_len": 2,
                "expected_max_tokens_list": [12, 10],
            },
        ]

        async def mock_format_and_add_data(current_req_dict):
            req_idx = int(current_req_dict["request_id"].split("_")[-1])
            if isinstance(case["mock_max_tokens"], list):
                current_req_dict["max_tokens"] = case["mock_max_tokens"][req_idx]
            else:
                current_req_dict["max_tokens"] = case["mock_max_tokens"]
            return [101, 102, 103, 104]

        self.engine_client.format_and_add_data = AsyncMock(side_effect=mock_format_and_add_data)

        async def intercept_generator(**kwargs):
            actual_max_tokens_list = kwargs["max_tokens_list"]
            self.assertEqual(
                len(actual_max_tokens_list),
                case["expected_max_tokens_list_len"],
                f"列表长度不匹配:实际{len(actual_max_tokens_list)},预期{case['expected_max_tokens_list_len']}",
            )
            self.assertEqual(
                actual_max_tokens_list,
                case["expected_max_tokens_list"],
                f"列表元素不匹配:实际{actual_max_tokens_list},预期{case['expected_max_tokens_list']}",
            )
            return CompletionResponse(
                id=kwargs["request_id"],
                object="text_completion",
                created=kwargs["created_time"],
                model=kwargs["model_name"],
                choices=[],
                usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0),
            )

        self.completion_serving.completion_full_generator = AsyncMock(side_effect=intercept_generator)

        for case in test_cases:
            with self.subTest(case=case["name"]):
                result = await self.completion_serving.create_completion(request=case["request"])
                self.assertIsInstance(result, CompletionResponse)
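The mock_add_data clamp in asyncSetUp is what makes the "max_tokens=20(修正为16)" case above come out as "length": with max_model_len = 20, a missing max_tokens defaults to 19, and every value is then capped at max_model_len - 4 = 16, so generating 16 tokens hits the budget. Restated as a standalone snippet (test-only behaviour, not the production clamping rule):

# Restates the test mock: default to max_model_len - 1, then cap at max_model_len - 4.
def mocked_effective_max_tokens(requested, max_model_len=20):
    if requested is None:
        requested = max_model_len - 1                 # 19
    return min(max_model_len - 4, max(0, requested))  # capped at 16

assert mocked_effective_max_tokens(5) == 5
assert mocked_effective_max_tokens(20) == 16   # the "corrected to 16" case
assert mocked_effective_max_tokens(None) == 16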
@@ -169,6 +169,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
    model_name="test-model",
    prompt_token_ids=[1, 2, 3],
    prompt_tokens="Hello",
+   max_tokens=10,
)

chunks = []
@@ -251,6 +252,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
    created_time=11,
    prompt_batched_token_ids=[[1, 2, 3]],
    prompt_tokens_list=["Hello"],
+   max_tokens_list=[100],
)

chunks = []
@@ -352,6 +354,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
    model_name=model_name,
    prompt_batched_token_ids=prompt_batched_token_ids,
    prompt_tokens_list=prompt_tokens_list,
+   max_tokens_list=[100, 100],
)

self.assertEqual(actual_response, expected_completion_response)
@@ -449,6 +452,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
num_input_image_tokens = [0, 0]
num_input_video_tokens = [0, 0]
num_image_tokens = [0, 0]
+ max_tokens_list = [10, 1]

for idx, case in enumerate(test_cases):
    actual_choice = await self.chat_serving._create_chat_completion_choice(
@@ -464,6 +468,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
    num_image_tokens=num_image_tokens,
    logprob_contents=logprob_contents,
    response_processor=mock_response_processor,
+   max_tokens=max_tokens_list[idx],
)

expected = case["expected"]
@@ -554,6 +559,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
    model_name="test-model",
    prompt_token_ids=[10, 20, 30],
    prompt_tokens="Hello",
+   max_tokens=10,
)

chunks = []
@@ -661,6 +667,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
    model_name="test-model",
    prompt_batched_token_ids=[[10, 20, 30]],
    prompt_tokens_list=["Hello"],
+   max_tokens_list=[100],
)

chunks = []

@@ -159,6 +159,7 @@ class TestOpenAIServingCompletion(unittest.TestCase):
    prompt_batched_token_ids=prompt_batched_token_ids,
    completion_batched_token_ids=completion_batched_token_ids,
    prompt_tokens_list=["1", "1"],
+   max_tokens_list=[10, 10],
)

assert completion_response.id == request_id

@@ -78,7 +78,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
)

async def mock_chat_completion_full_generator(
-   request, request_id, model_name, prompt_token_ids, prompt_tokens
+   request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens_list
):
    return prompt_token_ids

@@ -106,7 +106,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
)

async def mock_chat_completion_full_generator(
-   request, request_id, model_name, prompt_token_ids, prompt_tokens
+   request, request_id, model_name, prompt_token_ids, prompt_tokens, max_tokens_list
):
    return prompt_token_ids