[Feature] Support include_stop_str_in_output in chat/completion (#2910)

* [Feature] Support include_stop_str_in_output in chat/completion

* Add ci test for include_stop_str_in_output

* Update version of openai

* Fix ci test

---------

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
Jiang-Jia-Jun authored on 2025-07-18 16:59:18 +08:00 · committed by GitHub
parent 6efad14b95
commit fbe3547c95
5 changed files with 82 additions and 16 deletions
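For context, the new option rides on the request's `metadata` field rather than a new top-level parameter. A minimal client-side sketch (illustrative only: it assumes a FastDeploy OpenAI-compatible server listening on localhost at the CI test's default port 8188, a `/v1` route, and a placeholder API key):

```python
import openai

# Point the OpenAI SDK at a locally running FastDeploy server
# (base URL, port, and api_key are assumptions for illustration).
client = openai.OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

# With include_stop_str_in_output=True the stop string (e.g. "</s>")
# is kept at the end of the generated text instead of being stripped.
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=5,
    metadata={"include_stop_str_in_output": True},
)
print(response.choices[0].message.content)  # expected to end with the stop string
```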


@@ -119,6 +119,7 @@ class OpenAIServingChat:
         num_choices = 1
         max_streaming_response_tokens = 1
         enable_thinking = None
+        include_stop_str_in_output = False
         if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
             max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]
@@ -146,6 +147,7 @@ class OpenAIServingChat:
         current_waiting_time = 0
         if request.metadata is not None:
             enable_thinking = request.metadata.get("enable_thinking")
+            include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
         while num_choices > 0:
             try:
                 raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -169,7 +171,7 @@ class OpenAIServingChat:
                     raise ValueError("{}".format(res["error_msg"]))
                 self.engine_client.data_processor.process_response_dict(
-                    res, stream=True, enable_thinking=enable_thinking)
+                    res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
                 if res['metrics']['first_token_time'] is not None:
                     arrival_time = res['metrics']['first_token_time']
@@ -303,6 +305,7 @@ class OpenAIServingChat:
         created_time = int(time.time())
         final_res = None
         enable_thinking = None
+        include_stop_str_in_output = False
         try:
             dealer = await aiozmq.create_zmq_stream(
                 zmq.DEALER,
@@ -335,8 +338,9 @@ class OpenAIServingChat:
                     raise ValueError("{}".format(data["error_msg"]))
                 if request.metadata is not None:
                     enable_thinking = request.metadata.get("enable_thinking")
+                    include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
                 data = self.engine_client.data_processor.process_response_dict(
-                    data, stream=False, enable_thinking=enable_thinking)
+                    data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
                 # api_server_logger.debug(f"Client {request_id} received: {data}")
                 previous_num_tokens += len(data["outputs"]["token_ids"])
                 # The logprob for handling the response
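The serving-layer change follows one pattern in both the streaming and non-streaming paths: read an optional per-request flag from `request.metadata` with a safe default, then forward it as a keyword argument to the data processor. A self-contained sketch of that plumbing (stub names, not the actual FastDeploy classes):

```python
class StubProcessor:
    """Stand-in for the engine's data processor."""

    def process_response_dict(self, res, stream, enable_thinking=None, **kwargs):
        # Downstream code reads the flag via kwargs.get(...), so an absent
        # flag behaves exactly like include_stop_str_in_output=False.
        res["kept_stop_str"] = kwargs.get("include_stop_str_in_output", False)
        return res


def handle_request(metadata, res, processor, stream=False):
    # Default to False so existing clients see no behaviour change.
    enable_thinking = None
    include_stop_str_in_output = False
    if metadata is not None:
        enable_thinking = metadata.get("enable_thinking")
        include_stop_str_in_output = metadata.get("include_stop_str_in_output", False)
    return processor.process_response_dict(
        res,
        stream=stream,
        enable_thinking=enable_thinking,
        include_stop_str_in_output=include_stop_str_in_output,
    )


print(handle_request({"include_stop_str_in_output": True}, {}, StubProcessor()))
# {'kept_stop_str': True}
```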


@@ -248,7 +248,7 @@ class ErnieProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -283,7 +283,7 @@ class ErnieProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
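What the flag actually toggles in the processors is the final-chunk EOS trim: previously the trailing `eos_token_id` was always dropped before detokenization; with the flag set, it is kept. A standalone sketch of the condition (toy token IDs; `EOS_TOKEN_ID` stands in for `self.tokenizer.eos_token_id`):

```python
EOS_TOKEN_ID = 2  # assumption: stand-in for self.tokenizer.eos_token_id


def trim_final_tokens(token_ids, is_end, **kwargs):
    """Drop the trailing EOS token on the final chunk unless the caller
    asked to keep the stop string in the output."""
    if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
        if token_ids[-1] == EOS_TOKEN_ID:
            token_ids = token_ids[:-1]
    return token_ids


print(trim_final_tokens([10, 11, 2], is_end=True))                                   # [10, 11]
print(trim_final_tokens([10, 11, 2], is_end=True, include_stop_str_in_output=True))  # [10, 11, 2]
```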


@@ -355,7 +355,7 @@ class DataProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -390,7 +390,7 @@ class DataProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
@@ -430,7 +430,7 @@ class DataProcessor(BaseDataProcessor):
                 response_dict, enable_thinking=enable_thinking, **kwargs)
         else:
             return self.process_response_dict_normal(
-                response_dict=response_dict, enable_thinking=enable_thinking)
+                response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)

     def text2ids(self, text, max_model_len, raw_request=True):
         """


@@ -5,7 +5,7 @@ flake8
 ruamel.yaml
 zmq
 aiozmq
-openai
+openai>=1.93.0
 tqdm
 pynvml
 uvicorn
@@ -36,4 +36,4 @@ opentelemetry-instrumentation-redis
 opentelemetry-instrumentation-mysql
-opentelemetry-distro 
+opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi


@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import pytest
-import requests
-import time
-import subprocess
-import socket
 import os
 import signal
+import socket
+import subprocess
 import sys
+import time
 import openai
+import pytest
+import requests

 # Read ports from environment variables; use default values if not set
 FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
@@ -296,7 +297,6 @@ def test_non_streaming(openai_client):
     assert hasattr(response, 'choices')
     assert len(response.choices) > 0


 def test_streaming(openai_client, capsys):
     """
     Test streaming functionality with the local service
@@ -313,4 +313,66 @@ def test_streaming(openai_client, capsys):
     output = []
     for chunk in response:
         output.append(chunk.choices[0].text)
     assert len(output) > 0
+
+
+def test_non_streaming_with_stop_str(openai_client):
+    """
+    Test non-streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert response.choices[0].message.content.endswith("</s>")
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert not response.choices[0].message.content.endswith("</s>")
+
+
+def test_streaming_with_stop_str(openai_client):
+    """
+    Test streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token == "</s>"
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token != "</s>"
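Beyond the assertions above, a client that wants the stop string in a streamed reply simply concatenates the deltas. A short sketch following the same pattern as the new test (assuming a `client` configured like the `openai_client` fixture):

```python
# Collect a streamed reply, keeping the stop string (e.g. "</s>") at the end.
stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=5,
    metadata={"include_stop_str_in_output": True},
    stream=True,
)
text = ""
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:  # final chunks may carry an empty or None delta
        text += delta
print(text)  # expected to end with the model's stop string
```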