From edf1ca07afb47f37e2042489ce02249413bbe78a Mon Sep 17 00:00:00 2001
From: luukunn <83932082+luukunn@users.noreply.github.com>
Date: Fri, 15 Aug 2025 18:33:58 +0800
Subject: [PATCH] Feature/online/vs think 20250813 (#3440)
* add stream
* fix ernie_vl_reasoning_parsers
* fix bug
---
fastdeploy/entrypoints/openai/serving_chat.py | 28 ++--
.../entrypoints/openai/serving_completion.py | 65 +++++----
.../tool_parsers/ernie_x1_tool_parser.py | 88 +++++++----
fastdeploy/input/ernie_processor.py | 33 +++--
fastdeploy/input/ernie_vl_processor.py | 3 +-
fastdeploy/input/text_processor.py | 33 +++--
.../reasoning/ernie_vl_reasoning_parsers.py | 12 +-
.../reasoning/ernie_x1_reasoning_parsers.py | 138 ++++++------------
8 files changed, 206 insertions(+), 194 deletions(-)
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index 1005bae0e..26d7597ad 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -252,22 +252,20 @@ class OpenAIServingChat:
logprobs_res = self._create_chat_logprobs(
output_top_logprobs, request.logprobs, request.top_logprobs
)
- if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
- tool_delta_message = output["tool_delta_message"]
- if tool_delta_message is None:
- continue
- delta_message = tool_delta_message
- delta_message.reasoning_content = output.get("reasoning_content")
- if delta_message.tool_calls:
- tool_called = True
+ if not res["finished"]:
+ if "reasoning_delta_message" in output:
+ delta_message = output["reasoning_delta_message"]
+ elif "tool_delta_message" in output:
+ delta_message = output["tool_delta_message"]
+ if delta_message is not None and delta_message.tool_calls:
+ tool_called = True
+ else:
+ delta_message = DeltaMessage(content=delta_text)
else:
- delta_message = DeltaMessage(
- content=delta_text,
- reasoning_content=output.get("reasoning_content"),
- prompt_token_ids=None,
- completion_token_ids=None,
- tool_calls=None,
- )
+ delta_message = DeltaMessage(content=delta_text)
+
+ if delta_message is None:
+ continue
choice = ChatCompletionResponseStreamChoice(
index=0,
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index 1e8ad0f86..00521341a 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -344,33 +344,46 @@ class OpenAIServingCompletion:
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
output_tokens[idx] += 1
- if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
- tool_delta_message = output["tool_delta_message"]
- if tool_delta_message is None:
- continue
- delta_message = CompletionResponseStreamChoice(
- index=idx,
- text=output["text"],
- completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
- tool_calls=tool_delta_message.tool_calls,
- reasoning_content=output.get("reasoning_content"),
- arrival_time=arrival_time,
- logprobs=logprobs_res,
- )
- if tool_delta_message.tool_calls:
- tool_called = True
+ base_kwargs = {
+ "index": idx,
+ "completion_token_ids": output.get("token_ids") if request.return_token_ids else None,
+ "arrival_time": arrival_time,
+ "logprobs": logprobs_res,
+ }
+ delta_message_kwargs = None
+ if not res["finished"]:
+ if "reasoning_delta_message" in output:
+ reasoning_delta_message = output["reasoning_delta_message"]
+ if reasoning_delta_message is not None:
+ delta_message_kwargs = {
+ **base_kwargs,
+ "text": reasoning_delta_message.content or "",
+ "reasoning_content": reasoning_delta_message.reasoning_content,
+ }
+ elif "tool_delta_message" in output:
+ tool_delta_message = output["tool_delta_message"]
+ if tool_delta_message is not None:
+ delta_message_kwargs = {
+ **base_kwargs,
+ "text": tool_delta_message.content or "",
+ "tool_calls": tool_delta_message.tool_calls,
+ }
+ if tool_delta_message.tool_calls:
+ tool_called = True
+ else:
+ delta_message_kwargs = {
+ **base_kwargs,
+ "text": output["text"],
+ }
else:
- delta_message = CompletionResponseStreamChoice(
- index=idx,
- text=output["text"],
- prompt_token_ids=None,
- completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
- tool_calls=None,
- raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
- reasoning_content=output.get("reasoning_content"),
- arrival_time=arrival_time,
- logprobs=logprobs_res,
- )
+ delta_message_kwargs = {
+ **base_kwargs,
+ "text": output["text"],
+ }
+
+ if delta_message_kwargs is None:
+ continue
+ delta_message = CompletionResponseStreamChoice(**delta_message_kwargs)
choices.append(delta_message)
output_tokens[idx] += 1
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
index cec1f6840..d098ed521 100644
--- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
+++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
@@ -57,6 +57,16 @@ class ErnieX1ToolParser(ToolParser):
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
self.buffer: str = "" # buffer for accumulating unprocessed streaming content
+ self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas
+        self.tool_call_start_token: str = "<tool_call>"
+        self.tool_call_end_token: str = "</tool_call>"
+
+ self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
+ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+ if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
+ raise RuntimeError(
+                "ErnieX1 tool parser could not locate tool call start/end tokens in the tokenizer!"
+ )
if not self.model_tokenizer:
raise ValueError(
@@ -224,6 +234,9 @@ class ErnieX1ToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: dict,
) -> Union[DeltaMessage, None]:
+
+ if self.tool_call_start_token_id not in current_token_ids:
+ return DeltaMessage(content=delta_text)
# 忽略空chunk
if len(delta_text.strip()) == 0:
return None
@@ -234,7 +247,7 @@ class ErnieX1ToolParser(ToolParser):
self.buffer += delta_text
# 处理增量中的新tool_call开始
-        if "<tool_call>" in delta_text and "<tool_call>" not in previous_text:
+        if "<tool_call>" in delta_text:
self.current_tool_id = (
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
)
@@ -243,8 +256,6 @@ class ErnieX1ToolParser(ToolParser):
self.streamed_args_for_tool.append("")
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
- # 增量解析逻辑
-
# 1. 尝试解析name字段
if not self.current_tool_name_sent and '"name"' in self.buffer:
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
@@ -271,46 +282,71 @@ class ErnieX1ToolParser(ToolParser):
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
if args_match:
args_content = args_match.group(1)
- # 处理多余的大括号
- open_braces = args_content.count("{")
- close_braces = args_content.count("}")
- if close_braces > open_braces:
- args_content = args_content[: args_content.rfind("}")]
+ print("args_content:", args_content)
try:
- # 增量解析arguments
- parsed_args = json.loads(args_content)
- if isinstance(parsed_args, dict):
- args_json = json.dumps(parsed_args, ensure_ascii=False)
- if len(args_json) > len(self.streamed_args_for_tool[self.current_tool_id]):
- argument_diff = args_json[len(self.streamed_args_for_tool[self.current_tool_id]) :]
+ # 检查是否到达arguments结尾(括号完全匹配)
+ if "}}" in args_content:
+ print("delta_text (partial):", delta_text)
+ # 逐个字符检查括号匹配状态
+ matched_pos = -1
+ for i, ch in enumerate(delta_text):
+ if ch == "{":
+ self.bracket_counts["total_l"] += 1
+ elif ch == "}":
+ self.bracket_counts["total_r"] += 1
+
+ if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配
+ matched_pos = i
+ break
+
+ if matched_pos >= 0:
+ # 找到匹配点,清理buffer并返回
+ truncate_text = delta_text[: matched_pos + 1]
+ print("truncate_text:", truncate_text)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
- function=DeltaFunctionCall(arguments=argument_diff).model_dump(
+ function=DeltaFunctionCall(arguments=truncate_text).model_dump(
exclude_none=True
),
)
]
)
- print("delta argument:", delta)
- # 删除已处理部分
- processed_pos = args_match.start() + len('"arguments":')
- self.buffer = (
- self.buffer[:processed_pos] + self.buffer[processed_pos + len(args_json) :]
- )
- self.streamed_args_for_tool[self.current_tool_id] = args_json
+ self.buffer = self.buffer[args_match.end() :]
+ print(delta)
return delta
+ else:
+ # 没有完全匹配,继续累积
+ return None
+ else:
+ # 增量返回当前可解析的部分
+ for ch in delta_text:
+ if ch == "{":
+ self.bracket_counts["total_l"] += 1
+ elif ch == "}":
+ self.bracket_counts["total_r"] += 1
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
+ )
+ ]
+ )
+ print("delta argument (partial):", delta)
+ print(
+ f"Current bracket counts - left: {self.bracket_counts['total_l']}, right: {self.bracket_counts['total_r']}"
+ )
+ return delta
except Exception as e:
- data_processor_logger.debug(f"Partial arguments parsing: {str(e)}")
-
+ data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
+ return None
        if "</tool_call>" in self.buffer:
            end_pos = self.buffer.find("</tool_call>")
            self.buffer = self.buffer[end_pos + len("</tool_call>") :]
# 完成当前工具调用处理
- self.current_tool_id += 1
- self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
return delta
diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py
index f91293bc5..69c2ddf63 100644
--- a/fastdeploy/input/ernie_processor.py
+++ b/fastdeploy/input/ernie_processor.py
@@ -50,8 +50,9 @@ class ErnieProcessor(BaseDataProcessor):
self._init_config()
self.decode_status = dict()
- self.tool_parsers = dict()
+ self.tool_parser_dict = dict()
self.thinking_parser_dict = dict()
+ self.reasoning_end_dict = dict()
self._load_tokenizer()
data_processor_logger.info(
f"tokenizer information: bos_token is {self.tokenizer.bos_token} \
@@ -291,10 +292,12 @@ class ErnieProcessor(BaseDataProcessor):
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
response_dict["outputs"]["raw_prediction"] = delta_text
- if self.reasoning_parser and (
- enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
+ if (
+ self.reasoning_parser
+ and req_id not in self.reasoning_end_dict
+ and (enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser")
):
- reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
+ reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
delta_text,
@@ -302,14 +305,13 @@ class ErnieProcessor(BaseDataProcessor):
previous_token_ids + token_ids,
token_ids,
)
- response_dict["outputs"]["text"] = text
- response_dict["outputs"]["reasoning_content"] = reasoning_content
- else:
- response_dict["outputs"]["text"] = delta_text
- if self.tool_parser_obj:
- if req_id not in self.tool_parsers:
- self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
- tool_parser = self.tool_parsers[req_id]
+ response_dict["outputs"]["reasoning_delta_message"] = reasoning_delta_message
+ if self.reasoning_parser.is_reasoning_end(previous_token_ids + token_ids):
+ self.reasoning_end_dict[req_id] = True
+ if self.tool_parser_obj and req_id in self.reasoning_end_dict:
+ if req_id not in self.tool_parser_dict:
+ self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
+ tool_parser = self.tool_parser_dict[req_id]
tool_call = tool_parser.extract_tool_calls_streaming(
previous_texts,
previous_texts + delta_text,
@@ -320,11 +322,14 @@ class ErnieProcessor(BaseDataProcessor):
response_dict,
)
response_dict["outputs"]["tool_delta_message"] = tool_call
+ response_dict["outputs"]["text"] = delta_text
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
- if req_id in self.tool_parsers:
- del self.tool_parsers[req_id]
+ if req_id in self.tool_parser_dict:
+ del self.tool_parser_dict[req_id]
+ if req_id in self.reasoning_end_dict:
+ del self.reasoning_end_dict[req_id]
return response_dict
def messages2ids(self, request_or_messages):
diff --git a/fastdeploy/input/ernie_vl_processor.py b/fastdeploy/input/ernie_vl_processor.py
index 21a96e92b..65d877179 100644
--- a/fastdeploy/input/ernie_vl_processor.py
+++ b/fastdeploy/input/ernie_vl_processor.py
@@ -54,8 +54,9 @@ class ErnieMoEVLProcessor(ErnieProcessor):
self.image_patch_id = self.ernie_processor.image_patch_id
self.spatial_conv_size = self.ernie_processor.spatial_conv_size
- self.tool_parsers = dict()
+ self.tool_parser_dict = dict()
self.decode_status = dict()
+ self.reasoning_end_dict = dict()
self._load_tokenizer()
self.eos_token_ids = [self.tokenizer.eos_token_id]
self.eos_token_id_len = len(self.eos_token_ids)
diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py
index 4bffee280..d9ccf0d20 100644
--- a/fastdeploy/input/text_processor.py
+++ b/fastdeploy/input/text_processor.py
@@ -168,7 +168,8 @@ class DataProcessor(BaseDataProcessor):
self._init_config()
self.decode_status = dict()
- self.tool_parsers = dict()
+ self.tool_parser_dict = dict()
+ self.reasoning_end_dict = dict()
self.tokenizer = self._load_tokenizer()
data_processor_logger.info(
f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
@@ -398,8 +399,12 @@ class DataProcessor(BaseDataProcessor):
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
response_dict["outputs"]["raw_prediction"] = delta_text
- if enable_thinking and self.reasoning_parser:
- reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
+ if (
+ self.reasoning_parser
+ and req_id not in self.reasoning_end_dict
+ and (enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser")
+ ):
+ reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
delta_text,
@@ -407,14 +412,13 @@ class DataProcessor(BaseDataProcessor):
previous_token_ids + token_ids,
token_ids,
)
- response_dict["outputs"]["text"] = text
- response_dict["outputs"]["reasoning_content"] = reasoning_content
- else:
- response_dict["outputs"]["text"] = delta_text
- if self.tool_parser_obj and not is_end:
- if req_id not in self.tool_parsers:
- self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
- tool_parser = self.tool_parsers[req_id]
+ response_dict["outputs"]["reasoning_delta_message"] = reasoning_delta_message
+ if self.reasoning_parser.is_reasoning_end(previous_token_ids + token_ids):
+ self.reasoning_end_dict[req_id] = True
+ if self.tool_parser_obj and req_id in self.reasoning_end_dict:
+ if req_id not in self.tool_parser_dict:
+ self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
+ tool_parser = self.tool_parser_dict[req_id]
tool_call = tool_parser.extract_tool_calls_streaming(
previous_texts,
previous_texts + delta_text,
@@ -425,11 +429,14 @@ class DataProcessor(BaseDataProcessor):
response_dict,
)
response_dict["outputs"]["tool_delta_message"] = tool_call
+ response_dict["outputs"]["text"] = delta_text
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
- if req_id in self.tool_parsers:
- del self.tool_parsers[req_id]
+ if req_id in self.tool_parser_dict:
+ del self.tool_parser_dict[req_id]
+ if req_id in self.reasoning_end_dict:
+ del self.reasoning_end_dict[req_id]
return response_dict
def process_response_dict(self, response_dict, **kwargs):
diff --git a/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py b/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py
index f5762b791..7702664d1 100644
--- a/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py
+++ b/fastdeploy/reasoning/ernie_vl_reasoning_parsers.py
@@ -65,18 +65,16 @@ class ErnieVLReasoningParser(ReasoningParser):
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
- return "", ""
+ return None
if self.think_end_token_id in delta_token_ids:
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
- content = delta_text[end_index + len(self.end_token) :]
+            content = delta_text[end_index + len(self.end_token) :]
+ return DeltaMessage(reasoning_content=reasoning_content, content=content)
elif self.think_end_token_id in previous_token_ids:
- reasoning_content = ""
- content = delta_text
+ return DeltaMessage(content=delta_text)
else:
- reasoning_content = delta_text
- content = ""
- return reasoning_content, content
+ return DeltaMessage(reasoning_content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
diff --git a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
index 458505252..c75182b01 100644
--- a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
+++ b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
@@ -2,9 +2,9 @@
#
#
from collections.abc import Sequence
-from typing import Tuple
+from typing import Tuple, Union
-from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
#
@@ -47,6 +47,10 @@ class ErnieX1ReasoningParser(ReasoningParser):
        self.think_end_token_id = self.vocab.get("</think>")
if self.think_end_token_id is None:
raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
+        self.tool_call_start_token_id = self.vocab.get("<tool_call>")
+
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+ return self.tool_call_start_token_id in input_ids
def extract_reasoning_content_streaming(
self,
@@ -56,50 +60,63 @@ class ErnieX1ReasoningParser(ReasoningParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
- ) -> tuple[str, str]:
+ ) -> Union[DeltaMessage, None]:
"""
根据用户需求实现的流式解析方法:
- 1. 初始内容都视为思考内容
+ 1. 初始内容都视为思考内容,返回delta_text,""
2. 当遇到\n时检查后续是否是
- 3. 思考结束后检查是还是
- 4. 对于内容,处理换行和结束标记
+ 3. 如果直接遇到也结束思考
+ 4. 思考结束后检查是还是
+ 5. 对于内容,处理各种边界条件
"""
- # 如果还在思考阶段
- if not previous_text.endswith(self.think_end_token):
- # 如果遇到\n后接或直接遇到,思考结束
- if (previous_text.endswith("\n") and delta_text == self.think_end_token) or (
- not previous_text.endswith("\n") and delta_text == self.think_end_token
- ):
- return "", ""
+ if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
+ return None
+ # 思考阶段处理
+ if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
+ # 如果遇到\n,暂时不返回,等待下一个delta_text
+ if delta_text == "\n":
+ return None
+ # 如果前一个是\n且当前是,结束思考
+ elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
+ return None
+ # 如果直接遇到也结束思考
+ elif delta_text.startswith(self.think_end_token):
+ return None
# 否则继续返回思考内容
- return delta_text, ""
+ return DeltaMessage(reasoning_content=delta_text)
# 思考结束后检查是tool_call还是response
remaining_text = previous_text + delta_text
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
-
- # 跳过think后的换行
- after_think = after_think.lstrip("\n")
+ after_think = after_think.lstrip("\n") # 跳过think后的换行
# 处理tool_call情况
if after_think.startswith(self.tool_call_start_token):
- return "", ""
+ return None
# 处理response情况
if after_think.startswith(self.response_start_token):
- response_content = after_think[len(self.response_start_token) :]
- # 跳过response后的换行
- response_content = response_content.lstrip("\n")
-
- # 检查response是否结束
- if response_content.endswith(self.response_end_token):
- return "", ""
-
- # 返回response内容(使用delta_text确保流式输出)
- return "", delta_text
+ # 遇到标签时不立即返回
+ if delta_text == self.response_start_token:
+ return None
+ # 遇到后的换行符也不立即返回
+ elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
+ return None
+ # 处理回复内容中的换行符
+ if delta_text == "\n":
+ return None
+ # 如果前一个是\n且当前是,结束回复
+ elif previous_text.endswith("\n") and delta_text == self.response_end_token:
+ return None
+ # 如果直接遇到也结束回复
+ elif delta_text == self.response_end_token:
+ return None
+ # 其他情况返回实际内容
+ else:
+ return DeltaMessage(content=delta_text)
# 默认情况不返回内容
- return "", ""
+ return None
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
"""
@@ -143,66 +160,3 @@ class ErnieX1ReasoningParser(ReasoningParser):
reasoning_content = model_output
response_content = ""
return reasoning_content, response_content
-
-
-import unittest
-from unittest.mock import MagicMock
-
-
-class TestErnieX1ReasoningParser(unittest.TestCase):
- def setUp(self):
- self.tokenizer = MagicMock()
- self.tokenizer.vocab = {
- "\n\n\n": 1001,
- "\n": 1002,
- "\n\n": 1003,
- "\n": 1004,
- "\n\n": 1005,
- }
- self.parser = ErnieX1ReasoningParser(self.tokenizer)
-
- def test_streaming_with_think_and_response(self):
- # 测试标准情况:\n\n\n\ncontent\n\n
- prev_text = "thinking"
- delta_text = "\n\n\n\nanswer\n\n"
- result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
- self.assertEqual(result, ("thinking", "answer"))
-
- def test_streaming_with_think_and_tool_call(self):
- # 测试tool_call情况
- prev_text = "thinking"
- delta_text = "\n\n\n\ndetails\n\n"
- result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
- self.assertEqual(result, ("thinking", ""))
-
- def test_streaming_with_think_no_newline(self):
- # 测试没有前置换行的情况
- prev_text = "thinking"
- delta_text = "\n\nanswer\n"
- result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
- self.assertEqual(result, ("thinking", "answer"))
-
- def test_streaming_response_without_leading_newline(self):
- # 测试response内容没有前置换行
- prev_text = "thinking\n\n\n"
- delta_text = "answer\n\n"
- result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
- self.assertEqual(result, ("thinking", "answer"))
-
- def test_streaming_response_with_middle_newline(self):
- # 测试response内容中间的换行符
- prev_text = "thinking\n\n\n\n"
- delta_text = "line1\nline2\n\n"
- result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
- self.assertEqual(result, ("thinking", "line1\nline2"))
-
- def test_streaming_partial_response(self):
- # 测试不完整的response流式输出
- prev_text = "thinking\n\n\n\n"
- delta_text = "partial answer"
- result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
- self.assertEqual(result, ("thinking", "partial answer"))
-
-
-if __name__ == "__main__":
- unittest.main()