[Feature] Support include_stop_str_in_output (#2919)

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
Author: Jiang-Jia-Jun
Date: 2025-07-18 19:43:19 +08:00
Committed by: GitHub
Parent: c71d955e9c
Commit: e421d51001
4 changed files with 80 additions and 15 deletions

View File

@@ -104,6 +104,7 @@ class OpenAIServingChat:
         num_choices = 1
         max_streaming_response_tokens = 1
         enable_thinking = None
+        include_stop_str_in_output = False
         if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
             max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]
@@ -152,8 +153,9 @@ class OpenAIServingChat:
                         raise ValueError("{}".format(res["error_msg"]))
                     if request.metadata is not None:
                         enable_thinking = request.metadata.get("enable_thinking")
+                        include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
                     self.engine_client.data_processor.process_response_dict(
-                        res, stream=True, enable_thinking=enable_thinking)
+                        res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
                     if res['metrics']['first_token_time'] is not None:
                         arrival_time = res['metrics']['first_token_time']
@@ -282,6 +284,7 @@ class OpenAIServingChat:
         created_time = int(time.time())
         final_res = None
         enable_thinking = None
+        include_stop_str_in_output = False
         try:
             dealer = await aiozmq.create_zmq_stream(
                 zmq.DEALER,
@@ -312,8 +315,9 @@ class OpenAIServingChat:
                     raise ValueError("{}".format(data["error_msg"]))
                 if request.metadata is not None:
                     enable_thinking = request.metadata.get("enable_thinking")
+                    include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
                 data = self.engine_client.data_processor.process_response_dict(
-                    data, stream=False, enable_thinking=enable_thinking)
+                    data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
                 # api_server_logger.debug(f"Client {request_id} received: {data}")
                 previous_num_tokens += len(data["outputs"]["token_ids"])
                 # The logprob for handling the response
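For reference, a minimal client-side sketch of how the new flag is consumed, mirroring the tests added at the end of this commit. The base URL, port, and API key are placeholders for a locally running server, and it assumes an openai client version that accepts the metadata parameter (as the tests below do).

# Minimal usage sketch; base_url/api_key are placeholders for a local server.
import openai

client = openai.OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=5,
    metadata={"include_stop_str_in_output": True},  # flag introduced by this PR
    stream=False,
)
# With the flag set, the trailing stop string (e.g. "</s>") is kept in the content.
print(response.choices[0].message.content)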

View File

@@ -100,7 +100,6 @@ class ErnieProcessor(BaseDataProcessor):
         if request.prompt_token_ids is None or len(
                 request.prompt_token_ids) == 0:
-            system = request.get("system")
             if request.prompt is None and request.messages is None:
                 raise ValueError(
                     f"The request should have `input_ids`, `text` or `messages`: {request}.")
@@ -149,7 +148,6 @@ class ErnieProcessor(BaseDataProcessor):
             request['stop_token_ids'] = stop_seqs
             request['stop_seqs_len'] = stop_seqs_len
-        system = request.get("system")
         # Process prompt_token_ids
         if not request.get('prompt_token_ids'):
             if request.get('prompt') is None and request.get(
@@ -249,7 +247,7 @@ class ErnieProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -284,7 +282,7 @@ class ErnieProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
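Both processors gain the same guard: the trailing eos token is only stripped when the caller did not ask to keep the stop string. A simplified, self-contained sketch of that logic (the function name here is illustrative, not an actual class method):

def strip_trailing_eos(token_ids, eos_token_id, is_end, **kwargs):
    # Strip the final eos token only when include_stop_str_in_output is falsy.
    if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
        if token_ids[-1] == eos_token_id:
            token_ids = token_ids[:-1]
    return token_ids

# eos (id 2) is dropped by default and kept when the flag is set:
assert strip_trailing_eos([5, 7, 2], 2, True) == [5, 7]
assert strip_trailing_eos([5, 7, 2], 2, True, include_stop_str_in_output=True) == [5, 7, 2]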

View File

@@ -355,7 +355,7 @@ class DataProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -390,7 +390,7 @@ class DataProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
@@ -430,7 +430,7 @@ class DataProcessor(BaseDataProcessor):
                 response_dict, enable_thinking=enable_thinking, **kwargs)
         else:
             return self.process_response_dict_normal(
-                response_dict=response_dict, enable_thinking=enable_thinking)
+                response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)

     def text2ids(self, text, max_model_len, raw_request=True):
         """

View File

@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pytest
-import requests
-import time
-import subprocess
-import socket
 import os
 import signal
+import socket
+import subprocess
 import sys
+import time
+
 import openai
+import pytest
+import requests

 # Read ports from environment variables; use default values if not set
 FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
@@ -313,4 +314,66 @@ def test_streaming(openai_client, capsys):
     output = []
     for chunk in response:
         output.append(chunk.choices[0].text)
     assert len(output) > 0
+
+
+def test_non_streaming_with_stop_str(openai_client):
+    """
+    Test non-streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert response.choices[0].message.content.endswith("</s>")
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert not response.choices[0].message.content.endswith("</s>")
+
+
+def test_streaming_with_stop_str(openai_client):
+    """
+    Test streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token == "</s>"
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token != "</s>"
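For completeness, a raw-HTTP sketch of the same request using requests (already imported by this test file). The port assumes the default FD_API_PORT above, and the path assumes the usual OpenAI-compatible /v1/chat/completions route; adjust both for your deployment.

import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello, how are you?"}],
    "max_tokens": 5,
    "metadata": {"include_stop_str_in_output": True},
}
resp = requests.post("http://localhost:8188/v1/chat/completions", json=payload)
# With the flag set, the content should end with the stop string, e.g. "</s>".
print(resp.json()["choices"][0]["message"]["content"])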