[BugFix]Fix finish reason (#4543)

* fix finish reason

* add unit test

* add unit test

* fix unit test

* fix unit test
luukunn
2025-10-23 14:04:43 +08:00
committed by GitHub
parent ac4f5ca272
commit bbf06b9ff7
3 changed files with 163 additions and 11 deletions

View File

@@ -521,8 +521,10 @@ class OpenAIServingChat:
if data["finished"]:
num_choices -= 1
choice = await self._create_chat_completion_choice(
data=data,
output=output,
index=idx,
request=request,
previous_num_tokens=previous_num_tokens[idx],
prompt_token_ids=prompt_token_ids,
prompt_tokens=prompt_tokens,
completion_token_ids=completion_token_ids[idx],
@@ -557,8 +559,10 @@ class OpenAIServingChat:
async def _create_chat_completion_choice(
self,
data: dict,
output: dict,
index: int,
request: ChatCompletionRequest,
previous_num_tokens: int,
prompt_token_ids: list,
prompt_tokens: str,
completion_token_ids: list,
@@ -566,9 +570,6 @@ class OpenAIServingChat:
logprob_contents: list,
response_processor: ChatResponseProcessor,
) -> ChatCompletionResponseChoice:
- idx = int(data["request_id"].split("_")[-1])
- output = data["outputs"]
- previous_num_tokens = len(data["outputs"]["token_ids"])
if output is not None and output.get("metrics") and output["metrics"].get("request_start_time"):
work_process_metrics.e2e_request_latency.observe(
@@ -589,12 +590,12 @@ class OpenAIServingChat:
message.content = output["text"]
logprobs_full_res = None
- if logprob_contents[idx]:
- logprobs_full_res = LogProbs(content=logprob_contents[idx])
+ if logprob_contents[index]:
+ logprobs_full_res = LogProbs(content=logprob_contents[index])
has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
max_tokens = request.max_completion_tokens or request.max_tokens
- num_cached_tokens[idx] = output.get("num_cached_tokens", 0)
+ num_cached_tokens[index] = output.get("num_cached_tokens", 0)
finish_reason = "stop"
if has_no_token_limit or previous_num_tokens != max_tokens:
@@ -607,7 +608,7 @@ class OpenAIServingChat:
finish_reason = "recover_stop"
return ChatCompletionResponseChoice(
- index=idx,
+ index=index,
message=message,
logprobs=logprobs_full_res,
finish_reason=finish_reason,
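
For reference, a minimal standalone sketch of the finish-reason decision that the changed helper now applies per choice. The "length" branch falls in the elided part of the hunk, so it is inferred here from the new tests below; the name select_finish_reason and its argument list are illustrative, not names from the source.

def select_finish_reason(previous_num_tokens, max_tokens, max_completion_tokens):
    # No limit at all: the generation can only end naturally.
    has_no_token_limit = max_tokens is None and max_completion_tokens is None
    limit = max_completion_tokens or max_tokens
    if has_no_token_limit or previous_num_tokens != limit:
        return "stop"    # ended on EOS / stop condition before exhausting the budget
    return "length"      # accumulated token count hit the budget exactly

The key change is that previous_num_tokens is now the accumulated per-choice count passed in by the caller (previous_num_tokens[idx]) instead of being recomputed from the removed line len(data["outputs"]["token_ids"]), which for streamed responses appears to cover only the final chunk; with the accumulated count, a request capped at 5 tokens reports "length" as the tests below assert.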

View File

@@ -287,6 +287,69 @@ def test_non_streaming_chat(openai_client):
assert hasattr(response.choices[0].message, "content")
def test_non_streaming_chat_finish_reason(openai_client):
"""
Test non-streaming chat functionality with the local service
"""
response = openai_client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=1,
max_tokens=5,
stream=False,
)
assert hasattr(response, "choices")
assert response.choices[0].finish_reason == "length"
response = openai_client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=1,
max_completion_tokens=5,
stream=False,
)
assert hasattr(response, "choices")
assert response.choices[0].finish_reason == "length"
response = openai_client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=1,
max_tokens=5,
stream=False,
n=2,
)
assert hasattr(response, "choices")
for choice in response.choices:
assert choice.finish_reason == "length"
response = openai_client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=1,
max_completion_tokens=5,
stream=False,
n=2,
)
assert hasattr(response, "choices")
for choice in response.choices:
assert choice.finish_reason == "length"
# Streaming test
def test_streaming_chat(openai_client, capsys):
"""
@@ -1281,6 +1344,89 @@ def test_streaming_completion_with_bad_words(openai_client, capsys):
assert not any(ids in output_ids_2 for ids in bad_token_ids)
def test_streaming_chat_finish_reason(openai_client):
"""
Test non-streaming chat functionality with the local service
"""
response = openai_client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=1,
max_tokens=5,
stream=True,
)
for chunk in response:
last_token = chunk.choices[0].finish_reason
assert last_token == "length"
response = openai_client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=1,
max_completion_tokens=5,
stream=True,
)
for chunk in response:
last_token = chunk.choices[0].finish_reason
assert last_token == "length"
response = openai_client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=1,
max_completion_tokens=5,
stream=True,
n=2,
)
finish_reason_1 = ""
finish_reason_2 = ""
for chunk in response:
last_token = chunk.choices[0].finish_reason
if last_token:
if chunk.choices[0].index == 0:
finish_reason_1 = last_token
else:
finish_reason_2 = last_token
assert finish_reason_1 == "length"
assert finish_reason_2 == "length"
response = openai_client.chat.completions.create(
model="default",
messages=[
{"role": "system", "content": "You are a helpful AI assistant."},
{"role": "user", "content": "List 3 countries and their capitals."},
],
temperature=1,
max_tokens=5,
stream=True,
n=2,
)
finish_reason_1 = ""
finish_reason_2 = ""
for chunk in response:
last_token = chunk.choices[0].finish_reason
if last_token:
if chunk.choices[0].index == 0:
finish_reason_1 = last_token
else:
finish_reason_2 = last_token
assert finish_reason_1 == "length"
assert finish_reason_2 == "length"
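
For multi-choice streams, the collection loop above can be written more compactly by keying on choice.index. A small sketch under the same openai_client fixture these tests assume; collect_finish_reasons is an illustrative helper, not part of the suite.

def collect_finish_reasons(stream):
    # Record the final finish_reason reported for each choice index.
    reasons = {}
    for chunk in stream:
        for choice in chunk.choices:
            if choice.finish_reason:
                reasons[choice.index] = choice.finish_reason
    return reasons

# e.g. assert all(reason == "length" for reason in collect_finish_reasons(response).values())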
def test_profile_reset_block_num():
"""测试profile reset_block_num功能与baseline diff不能超过5%"""
log_file = "./log/config.log"

View File

@@ -391,6 +391,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
"raw_prediction": "raw_answer_0",
},
"finished": True,
"previous_num_tokens": 2,
},
"mock_request": ChatCompletionRequest(
model="test", messages=[], return_token_ids=True, max_tokens=10, n=2
@@ -417,6 +418,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
"raw_prediction": None,
},
"finished": True,
"previous_num_tokens": 1,
},
"mock_request": ChatCompletionRequest(
model="test", messages=[], return_token_ids=True, max_tokens=5, n=2
@@ -435,7 +437,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
prompt_token_ids = [1, 2]
prompt_tokens = "test_prompt"
- logprob_contents = [[], []]
+ logprob_contents = [[{"token": "hello", "logprob": 0.1}], [{"token": "hello", "logprob": 0.1}]]
mock_response_processor = Mock()
mock_response_processor.enable_multimodal_content.return_value = False
completion_token_ids = [[], []]
@@ -443,8 +445,10 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
for idx, case in enumerate(test_cases):
actual_choice = await self.chat_serving._create_chat_completion_choice(
data=case["test_data"],
output=case["test_data"]["outputs"],
index=idx,
request=case["mock_request"],
previous_num_tokens=case["test_data"]["previous_num_tokens"],
prompt_token_ids=prompt_token_ids,
prompt_tokens=prompt_tokens,
completion_token_ids=completion_token_ids[idx],
@@ -465,6 +469,7 @@ class TestMaxStreamingResponseTokens(IsolatedAsyncioTestCase):
self.assertEqual(num_cached_tokens[expected["index"]], expected["num_cached_tokens"])
self.assertEqual(actual_choice.finish_reason, expected["finish_reason"])
assert actual_choice.logprobs is not None
if __name__ == "__main__":