From e421d5100123e6cfec6c6f2ba7a549acbee2db00 Mon Sep 17 00:00:00 2001
From: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Date: Fri, 18 Jul 2025 19:43:19 +0800
Subject: [PATCH] [Feature] Support include_stop_str_in_output (#2919)

Co-authored-by: Jiang-Jia-Jun
---
 fastdeploy/entrypoints/openai/serving_chat.py |  8 +-
 fastdeploy/input/ernie_processor.py           |  6 +-
 fastdeploy/input/text_processor.py            |  6 +-
 test/ci_use/EB_Lite/test_EB_Lite_serving.py   | 75 +++++++++++++++++--
 4 files changed, 80 insertions(+), 15 deletions(-)

diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index 35bff4ec8..31359e728 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -104,6 +104,7 @@ class OpenAIServingChat:
         num_choices = 1
         max_streaming_response_tokens = 1
         enable_thinking = None
+        include_stop_str_in_output = False
         if request.metadata is not None and request.metadata.get("max_streaming_response_tokens", 1) > 1:
             max_streaming_response_tokens = request.metadata["max_streaming_response_tokens"]
 
@@ -152,8 +153,9 @@ class OpenAIServingChat:
                         raise ValueError("{}".format(res["error_msg"]))
                     if request.metadata is not None:
                         enable_thinking = request.metadata.get("enable_thinking")
+                        include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
                     self.engine_client.data_processor.process_response_dict(
-                        res, stream=True, enable_thinking=enable_thinking)
+                        res, stream=True, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
 
                     if res['metrics']['first_token_time'] is not None:
                         arrival_time = res['metrics']['first_token_time']
@@ -282,6 +284,7 @@ class OpenAIServingChat:
         created_time = int(time.time())
         final_res = None
         enable_thinking = None
+        include_stop_str_in_output = False
        try:
             dealer = await aiozmq.create_zmq_stream(
                 zmq.DEALER,
@@ -312,8 +315,9 @@ class OpenAIServingChat:
                     raise ValueError("{}".format(data["error_msg"]))
                 if request.metadata is not None:
                     enable_thinking = request.metadata.get("enable_thinking")
+                    include_stop_str_in_output = request.metadata.get("include_stop_str_in_output", False)
                 data = self.engine_client.data_processor.process_response_dict(
-                    data, stream=False, enable_thinking=enable_thinking)
+                    data, stream=False, enable_thinking=enable_thinking, include_stop_str_in_output=include_stop_str_in_output)
                 # api_server_logger.debug(f"Client {request_id} received: {data}")
                 previous_num_tokens += len(data["outputs"]["token_ids"])
                 # The logprob for handling the response
diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py
index fa241b95b..0867214fe 100644
--- a/fastdeploy/input/ernie_processor.py
+++ b/fastdeploy/input/ernie_processor.py
@@ -100,7 +100,6 @@ class ErnieProcessor(BaseDataProcessor):
 
         if request.prompt_token_ids is None or len(
                 request.prompt_token_ids) == 0:
-            system = request.get("system")
             if request.prompt is None and request.messages is None:
                 raise ValueError(
                     f"The request should have `input_ids`, `text` or `messages`: {request}.")
@@ -149,7 +148,6 @@ class ErnieProcessor(BaseDataProcessor):
                 request['stop_token_ids'] = stop_seqs
                 request['stop_seqs_len'] = stop_seqs_len
 
-        system = request.get("system")
         # 处理prompt_token_ids
         if not request.get('prompt_token_ids'):
             if request.get('prompt') is None and request.get(
@@ -249,7 +247,7 @@ class ErnieProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -284,7 +282,7 @@ class ErnieProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
 
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py
index 9d30dee3e..9c3c615c4 100644
--- a/fastdeploy/input/text_processor.py
+++ b/fastdeploy/input/text_processor.py
@@ -355,7 +355,7 @@ class DataProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -390,7 +390,7 @@ class DataProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]
 
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
@@ -430,7 +430,7 @@ class DataProcessor(BaseDataProcessor):
                 response_dict, enable_thinking=enable_thinking, **kwargs)
         else:
             return self.process_response_dict_normal(
-                response_dict=response_dict, enable_thinking=enable_thinking)
+                response_dict=response_dict, enable_thinking=enable_thinking, **kwargs)
 
     def text2ids(self, text, max_model_len, raw_request=True):
         """
diff --git a/test/ci_use/EB_Lite/test_EB_Lite_serving.py b/test/ci_use/EB_Lite/test_EB_Lite_serving.py
index d0b9e6dd6..042ba8e12 100644
--- a/test/ci_use/EB_Lite/test_EB_Lite_serving.py
+++ b/test/ci_use/EB_Lite/test_EB_Lite_serving.py
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import pytest
-import requests
-import time
-import subprocess
-import socket
 import os
 import signal
+import socket
+import subprocess
 import sys
+import time
+
 import openai
+import pytest
+import requests
 
 # Read ports from environment variables; use default values if not set
 FD_API_PORT = int(os.getenv("FD_API_PORT", 8188))
@@ -313,4 +314,66 @@ def test_streaming(openai_client, capsys):
     output = []
     for chunk in response:
         output.append(chunk.choices[0].text)
-    assert len(output) > 0
\ No newline at end of file
+    assert len(output) > 0
+
+def test_non_streaming_with_stop_str(openai_client):
+    """
+    Test non-streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert response.choices[0].message.content.endswith("</s>")
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=False,
+    )
+    # Assertions to check the response structure
+    assert hasattr(response, 'choices')
+    assert len(response.choices) > 0
+    assert not response.choices[0].message.content.endswith("</s>")
+
+def test_streaming_with_stop_str(openai_client):
+    """
+    Test streaming chat functionality with the local service
+    """
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": True},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token == "</s>"
+
+    response = openai_client.chat.completions.create(
+        model="default",
+        messages=[{"role": "user", "content": "Hello, how are you?"}],
+        temperature=1,
+        max_tokens=5,
+        metadata={"include_stop_str_in_output": False},
+        stream=True,
+    )
+    # Assertions to check the response structure
+    last_token = ""
+    for chunk in response:
+        last_token = chunk.choices[0].delta.content
+    assert last_token != "</s>"
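
For reference, the new flag is driven entirely through the request metadata, as the tests above exercise. A minimal client-side sketch follows, assuming a locally deployed FastDeploy OpenAI-compatible server on the default FD_API_PORT (8188); the base URL, api_key value, and the "</s>" stop string are illustrative placeholders, not part of this patch.

# Usage sketch (assumption: a FastDeploy OpenAI-compatible server is already
# running locally; 8188 mirrors the FD_API_PORT default used by the tests).
import openai

client = openai.Client(base_url="http://127.0.0.1:8188/v1", api_key="EMPTY")

# include_stop_str_in_output=True keeps the stop string (e.g. "</s>") at the
# end of the generated text instead of stripping it from the response.
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    max_tokens=16,
    metadata={"include_stop_str_in_output": True},
    stream=False,
)
print(response.choices[0].message.content)

Because the flag travels in metadata and defaults to False in both generators and processors, clients that never set it keep the previous behaviour of having the stop string stripped from the output.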