[Feature] Support include_stop_str_in_output (#2930)

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
Author: Jiang-Jia-Jun
Date: 2025-07-21 10:58:32 +08:00 (committed by GitHub)
Parent: b89f083004
Commit: f941124402
4 changed files with 74 additions and 8 deletions
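For context, this is how a client opts in once the feature is deployed: the switch travels in the request metadata, exactly as the tests added at the bottom of this commit do. A minimal client-side sketch, assuming an OpenAI-compatible client pointed at a local FastDeploy endpoint; the base URL, API key, and model name are placeholders:

# Minimal client-side sketch; base_url, api_key and model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=5,
    # New request-level switch: keep the stop string (e.g. "</s>") in the returned text.
    metadata={"include_stop_str_in_output": True},
    stream=False,
)
print(response.choices[0].message.content)  # expected to end with the stop string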


@@ -119,6 +119,7 @@ class OpenAIServingChat:
         num_choices = 1
         max_streaming_response_tokens = 1
         enable_thinking = None
+        include_stop_str_in_output = False
         if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
             max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]
@@ -146,6 +147,7 @@ class OpenAIServingChat:
         current_waiting_time = 0
         if request.metadata is not None:
             enable_thinking = request.metadata.get("enable_thinking")
+            include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
         while num_choices > 0:
             try:
                 raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -169,7 +171,7 @@ class OpenAIServingChat:
                     raise ValueError("{}".format(res["error_msg"]))
                 self.engine_client.data_processor.process_response_dict(
-                    res, stream=True, enable_thinking=enable_thinking)
+                    res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
                 if res['metrics']['first_token_time'] is not None:
                     arrival_time = res['metrics']['first_token_time']
@@ -303,6 +305,7 @@ class OpenAIServingChat:
         created_time = int(time.time())
         final_res = None
         enable_thinking = None
+        include_stop_str_in_output = False
         try:
             dealer = await aiozmq.create_zmq_stream(
                 zmq.DEALER,
@@ -335,8 +338,9 @@ class OpenAIServingChat:
                     raise ValueError("{}".format(data["error_msg"]))
                 if request.metadata is not None:
                     enable_thinking = request.metadata.get("enable_thinking")
+                    include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
                 data = self.engine_client.data_processor.process_response_dict(
-                    data, stream=False, enable_thinking=enable_thinking)
+                    data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
                 # api_server_logger.debug(f"Client {request_id} received: {data}")
                 previous_num_tokens += len(data["outputs"]["token_ids"])
                 # The logprob for handling the response
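Taken together, the serving-layer edits above are plumbing only: the flag is read from request.metadata, defaults to False so existing clients see no change, and is forwarded as a keyword argument to process_response_dict on both the streaming and non-streaming paths. A self-contained sketch of that resolution step; ChatRequestStub is a simplified stand-in, not FastDeploy's real request class:

# Sketch of the metadata resolution shown above; ChatRequestStub is a simplified
# stand-in for the real request object.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ChatRequestStub:
    metadata: Optional[dict] = None

def resolve_include_stop_str(request: ChatRequestStub) -> bool:
    # Default stays False, so omitting the field keeps the old stop-string-stripped output.
    if request.metadata is not None:
        return request.metadata.get("include_stop_str_in_output", False)
    return False

print(resolve_include_stop_str(ChatRequestStub()))  # False
print(resolve_include_stop_str(ChatRequestStub(metadata={"include_stop_str_in_output": True})))  # True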


@@ -248,7 +248,7 @@ class ErnieProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -283,7 +283,7 @@ class ErnieProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(


@@ -355,7 +355,7 @@ class DataProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -390,7 +390,7 @@ class DataProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
@@ -430,7 +430,7 @@ class DataProcessor(BaseDataProcessor):
                 response_dict, enable_thinking=enable_thinking, **kwargs)
         else:
             return self.process_response_dict_normal(
-                response_dict=response_dict, enable_thinking=enable_thinking)
+                response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)

     def text2ids(self, text, max_model_len, raw_request=True):
         """


@@ -314,3 +314,65 @@ def test_streaming(openai_client, capsys):
     for chunk in response:
         output.append(chunk.choices[0].text)
     assert len(output) > 0
+
+
+def test_non_streaming_with_stop_str(openai_client):
+    """
+    Test non-streaming chat with include_stop_str_in_output toggled against the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert response.choices[0].message.content.endswith("</s>")
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert not response.choices[0].message.content.endswith("</s>")
+
+
+def test_streaming_with_stop_str(openai_client):
+    """
+    Test streaming chat with include_stop_str_in_output toggled against the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token == "</s>"
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token != "</s>"