[Feature] Support include_stop_str_in_output in chat/completion (#2910)

* [Feature] Support include_stop_str_in_output in chat/completion

* Add CI test for include_stop_str_in_output

* Update version of openai

* Fix ci test

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
Jiang-Jia-Jun
2025-07-18 16:59:18 +08:00
committed by GitHub
parent 6efad14b95
commit fbe3547c95
5 changed files with 82 additions and 16 deletions
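
For context, the new behavior is opted into per request through the metadata field. Below is a minimal usage sketch based on the CI test added in this commit; the base URL path, api_key value, and the "</s>" stop token are illustrative and depend on the deployment (port 8188 matches the FD_API_PORT default used by the test):

import openai

# Point the official client at a locally deployed server.
# The "/v1" path and api_key value are assumptions for illustration.
client = openai.OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=5,
    metadata={"include_stop_str_in_output": True},
    stream=False,
)

# With the flag set, the stop string (e.g. "</s>") is kept in the returned text.
print(response.choices[0].message.content)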

View File

@@ -119,6 +119,7 @@ class OpenAIServingChat:
num_choices = 1
max_streaming_response_tokens = 1
enable_thinking = None
include_stop_str_in_output = False
if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]
@@ -146,6 +147,7 @@ class OpenAIServingChat:
current_waiting_time = 0
if request.metadata is not None:
enable_thinking = request.metadata.get("enable_thinking")
include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
while num_choices > 0:
try:
raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -169,7 +171,7 @@ class OpenAIServingChat:
raise ValueError("{}".format(res["error_msg"]))
self.engine_client.data_processor.process_response_dict(
res, stream=True, enable_thinking=enable_thinking)
res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
if res['metrics']['first_token_time'] is not None:
arrival_time = res['metrics']['first_token_time']
@@ -303,6 +305,7 @@ class OpenAIServingChat:
created_time = int(time.time())
final_res = None
enable_thinking = None
include_stop_str_in_output = False
try:
dealer = await aiozmq.create_zmq_stream(
zmq.DEALER,
@@ -335,8 +338,9 @@ class OpenAIServingChat:
raise ValueError("{}".format(data["error_msg"]))
if request.metadata is not None:
enable_thinking = request.metadata.get("enable_thinking")
include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
data = self.engine_client.data_processor.process_response_dict(
data, stream=False, enable_thinking=enable_thinking)
data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
# api_server_logger.debug(f"Client {request_id} received: {data}")
previous_num_tokens += len(data["outputs"]["token_ids"])
# The logprob for handling the response

View File

@@ -248,7 +248,7 @@ class ErnieProcessor(BaseDataProcessor):
token_ids = response_dict["outputs"]["token_ids"]
is_end = response_dict["finished"]
req_id = response_dict["request_id"]
if is_end and len(token_ids) > 0:
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -283,7 +283,7 @@ class ErnieProcessor(BaseDataProcessor):
req_id = response_dict["request_id"]
token_ids = response_dict["outputs"]["token_ids"]
if is_end and len(token_ids) > 0:
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(
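
The ErnieProcessor hunks above and the DataProcessor hunks below add the same guard: the trailing eos token is only stripped when include_stop_str_in_output was not requested. A simplified standalone sketch of that logic (the helper name is hypothetical):

def _maybe_strip_eos(token_ids, eos_token_id, is_end, **kwargs):
    # Mirror of the guard in the diff: keep the eos/stop token in the
    # output when the caller passed include_stop_str_in_output=True.
    if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
        if token_ids[-1] == eos_token_id:
            token_ids = token_ids[:-1]
    return token_ids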

View File

@@ -355,7 +355,7 @@ class DataProcessor(BaseDataProcessor):
token_ids = response_dict["outputs"]["token_ids"]
is_end = response_dict["finished"]
req_id = response_dict["request_id"]
if is_end and len(token_ids) > 0:
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -390,7 +390,7 @@ class DataProcessor(BaseDataProcessor):
req_id = response_dict["request_id"]
token_ids = response_dict["outputs"]["token_ids"]
if is_end and len(token_ids) > 0:
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(
@@ -430,7 +430,7 @@ class DataProcessor(BaseDataProcessor):
response_dict, enable_thinking=enable_thinking, **kwargs)
else:
return self.process_response_dict_normal(
response_dict=response_dict, enable_thinking=enable_thinking)
response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)
def text2ids(self, text, max_model_len, raw_request=True):
"""

View File

@@ -5,7 +5,7 @@ flake8
ruamel.yaml
zmq
aiozmq
openai
openai>=1.93.0
tqdm
pynvml
uvicorn

View File

@@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import requests
import time
import subprocess
import socket
import os
import signal
import socket
import subprocess
import sys
import time
import openai
import pytest
import requests
# Read ports from environment variables; use default values if not set
FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
@@ -296,7 +297,6 @@ def test_non_streaming(openai_client):
assert hasattr(response, 'choices')
assert len(response.choices) > 0
def test_streaming(openai_client, capsys):
"""
Test streaming functionality with the local service
@@ -314,3 +314,65 @@ def test_streaming(openai_client, capsys):
for chunk in response:
output.append(chunk.choices[0].text)
assert len(output) > 0
def test_non_streaming_with_stop_str(openai_client):
"""
Test non-streaming chat with include_stop_str_in_output against the local service
"""
response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
max_tokens=5,
metadata={"include_stop_str_in_output": True},
stream=False,
)
# Assertions to check the response structure
assert hasattr(response, 'choices')
assert len(response.choices) > 0
assert response.choices[0].message.content.endswith("</s>")
response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
max_tokens=5,
metadata={"include_stop_str_in_output": False},
stream=False,
)
# Assertions to check the response structure
assert hasattr(response, 'choices')
assert len(response.choices) > 0
assert not response.choices[0].message.content.endswith("</s>")
def test_streaming_with_stop_str(openai_client):
"""
Test streaming chat with include_stop_str_in_output against the local service
"""
response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
max_tokens=5,
metadata={"include_stop_str_in_output": True},
stream=True,
)
# Assertions to check the response structure
last_token = ""
for chunk in response:
last_token = chunk.choices[0].delta.content
assert last_token == "</s>"
response = openai_client.chat.completions.create(
model="default",
messages=[{"role": "user", "content": "Hello, how are you?"}],
temperature=1,
max_tokens=5,
metadata={"include_stop_str_in_output": False},
stream=True,
)
# Assertions to check the response structure
last_token = ""
for chunk in response:
last_token = chunk.choices[0].delta.content
assert last_token != "</s>"