mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
[Feature] Support include_stop_str_in_output (#2919)
Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
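This commit threads a new per-request flag, include_stop_str_in_output, from the OpenAI-compatible chat endpoint down into the data processors, so the trailing stop string (for example the </s> eos token) can be kept in the returned text instead of being stripped. A minimal client-side sketch, mirroring the new tests at the bottom of this diff; the base URL, port (FD_API_PORT defaults to 8188 in the test file), and API key are placeholders, and older openai SDK versions may need to pass the field through extra_body instead of metadata:

import openai

# Placeholder endpoint/key for a locally running FastDeploy OpenAI-compatible server.
client = openai.OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

# With include_stop_str_in_output=True the trailing stop string (e.g. "</s>") is kept.
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=5,
    metadata={"include_stop_str_in_output": True},
    stream=False,
)
print(response.choices[0].message.content)  # expected to end with "</s>"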
@@ -104,6 +104,7 @@ class OpenAIServingChat:
         num_choices = 1
         max_streaming_response_tokens = 1
         enable_thinking = None
+        include_stop_str_in_output = False
         if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
             max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]

@@ -152,8 +153,9 @@ class OpenAIServingChat:
                     raise ValueError("{}".format(res["error_msg"]))
                 if request.metadata is not None:
                     enable_thinking = request.metadata.get("enable_thinking")
+                    include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
                 self.engine_client.data_processor.process_response_dict(
-                    res, stream=True, enable_thinking=enable_thinking)
+                    res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)

                 if res['metrics']['first_token_time'] is not None:
                     arrival_time = res['metrics']['first_token_time']
@@ -282,6 +284,7 @@ class OpenAIServingChat:
         created_time = int(time.time())
         final_res = None
         enable_thinking = None
+        include_stop_str_in_output = False
         try:
             dealer = await aiozmq.create_zmq_stream(
                 zmq.DEALER,
@@ -312,8 +315,9 @@ class OpenAIServingChat:
                     raise ValueError("{}".format(data["error_msg"]))
                 if request.metadata is not None:
                     enable_thinking = request.metadata.get("enable_thinking")
+                    include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
                 data = self.engine_client.data_processor.process_response_dict(
-                    data, stream=False, enable_thinking=enable_thinking)
+                    data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
                 # api_server_logger.debug(f"Client {request_id} received: {data}")
                 previous_num_tokens += len(data["outputs"]["token_ids"])
                 # The logprob for handling the response
@@ -100,7 +100,6 @@ class ErnieProcessor(BaseDataProcessor):

         if request.prompt_token_ids is None or len(
                 request.prompt_token_ids) == 0:
-            system = request.get("system")
             if request.prompt is None and request.messages is None:
                 raise ValueError(
                     f"The request should have `input_ids`, `text` or `messages`: {request}.")
@@ -149,7 +148,6 @@ class ErnieProcessor(BaseDataProcessor):
             request['stop_token_ids'] = stop_seqs
             request['stop_seqs_len'] = stop_seqs_len

-        system = request.get("system")
         # Process prompt_token_ids
         if not request.get('prompt_token_ids'):
             if request.get('prompt') is None and request.get(
@@ -249,7 +247,7 @@ class ErnieProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
            if token_ids[-1] == self.tokenizer.eos_token_id:
                token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -284,7 +282,7 @@ class ErnieProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]

-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
            if token_ids[-1] == self.tokenizer.eos_token_id:
                token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
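The guard added above is the heart of the feature: once a sequence has finished and the caller did not ask for the stop string, the trailing eos token id is dropped before detokenization. The same condition is repeated in the DataProcessor hunks below. A simplified, standalone sketch of that logic (the function name and explicit eos_token_id argument are illustrative, not the real method signature):

from typing import List

def trim_stop_token(token_ids: List[int], eos_token_id: int, finished: bool,
                    include_stop_str_in_output: bool = False) -> List[int]:
    """Drop a trailing eos token unless the caller asked to keep the stop string."""
    if finished and token_ids and not include_stop_str_in_output:
        if token_ids[-1] == eos_token_id:
            return token_ids[:-1]
    return token_ids

# Example: with the flag set, the eos id (here assumed to be 2) survives decoding.
assert trim_stop_token([10, 11, 2], eos_token_id=2, finished=True) == [10, 11]
assert trim_stop_token([10, 11, 2], eos_token_id=2, finished=True,
                       include_stop_str_in_output=True) == [10, 11, 2]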
@@ -355,7 +355,7 @@ class DataProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
            if token_ids[-1] == self.tokenizer.eos_token_id:
                token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -390,7 +390,7 @@ class DataProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]

-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
            if token_ids[-1] == self.tokenizer.eos_token_id:
                token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
@@ -430,7 +430,7 @@ class DataProcessor(BaseDataProcessor):
                 response_dict, enable_thinking=enable_thinking, **kwargs)
         else:
             return self.process_response_dict_normal(
-                response_dict=response_dict, enable_thinking=enable_thinking)
+                response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)

     def text2ids(self, text, max_model_len, raw_request=True):
         """
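The hunk above also forwards **kwargs to the non-streaming branch; without that change, include_stop_str_in_output would be silently ignored for stream=False requests. A schematic of the dispatch with the handler bodies reduced to stubs (simplified relative to the real class):

class DataProcessorSketch:
    """Schematic of how process_response_dict fans out keyword arguments.

    Both branches now receive **kwargs, so flags such as
    include_stop_str_in_output reach the eos-trimming code in either path.
    """

    def process_response_dict(self, response_dict, stream=True, enable_thinking=None, **kwargs):
        if stream:
            return self.process_response_dict_streaming(
                response_dict, enable_thinking=enable_thinking, **kwargs)
        # Previously kwargs were dropped on this branch, so the new flag never
        # reached the non-streaming handler.
        return self.process_response_dict_normal(
            response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)

    def process_response_dict_streaming(self, response_dict, enable_thinking=None, **kwargs):
        return {"mode": "stream", "include_stop_str_in_output": kwargs.get("include_stop_str_in_output", False)}

    def process_response_dict_normal(self, response_dict, enable_thinking=None, **kwargs):
        return {"mode": "normal", "include_stop_str_in_output": kwargs.get("include_stop_str_in_output", False)}

# Usage: the flag reaches the handler in both modes.
p = DataProcessorSketch()
assert p.process_response_dict({}, stream=False, include_stop_str_in_output=True)["include_stop_str_in_output"] is True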
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import pytest
-import requests
-import time
-import subprocess
-import socket
 import os
 import signal
+import socket
+import subprocess
 import sys
+import time

 import openai
+import pytest
+import requests

 # Read ports from environment variables; use default values if not set
 FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
@@ -313,4 +314,66 @@ def test_streaming(openai_client, capsys):
     output = []
     for chunk in response:
         output.append(chunk.choices[0].text)
     assert len(output) > 0
+
+def test_non_streaming_with_stop_str(openai_client):
+    """
+    Test non-streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert response.choices[0].message.content.endswith("</s>")
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert not response.choices[0].message.content.endswith("</s>")
+
+def test_streaming_with_stop_str(openai_client):
+    """
+    Test streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token == "</s>"
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token != "</s>"