[Fix] X1 reasoning parser, skip parsing of \n around special tokens (#4241)

This commit is contained in:
zhuzixuan
2025-09-24 17:04:59 +08:00
committed by GitHub
parent d40a1046de
commit dc600010de
4 changed files with 466 additions and 114 deletions

View File

@@ -1,3 +1,4 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License" # Licensed under the Apache License, Version 2.0 (the "License"
@@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""
import json import json
import re import re
@@ -97,29 +99,29 @@ class ErnieX1ToolParser(ToolParser):
remaining_text = model_output remaining_text = model_output
while True: while True:
# 查找下一个tool_call # Find the next <tool_call>
tool_call_pos = remaining_text.find("<tool_call>") tool_call_pos = remaining_text.find("<tool_call>")
if tool_call_pos == -1: if tool_call_pos == -1:
break break
# 提取tool_call开始位置后的内容 # Extract content after <tool_call>
tool_content_start = tool_call_pos + len("<tool_call>") tool_content_start = tool_call_pos + len("<tool_call>")
tool_content_end = remaining_text.find("</tool_call>", tool_content_start) tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
tool_json = "" tool_json = ""
if tool_content_end == -1: if tool_content_end == -1:
# 处理未闭合的tool_call块截断情况 # Processing unclosed tool_call block (truncated case)
tool_json = remaining_text[tool_content_start:].strip() tool_json = remaining_text[tool_content_start:].strip()
remaining_text = "" # 没有更多内容需要处理 remaining_text = "" # No more content to process
else: else:
# 处理完整的tool_call # Processing closed </tool_call> block
tool_json = remaining_text[tool_content_start:tool_content_end].strip() tool_json = remaining_text[tool_content_start:tool_content_end].strip()
remaining_text = remaining_text[tool_content_end + len("</tool_call>") :] remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
if not tool_json: if not tool_json:
continue continue
# 处理JSON内容 # Process tool_json
tool_json = tool_json.strip() tool_json = tool_json.strip()
if not tool_json.startswith("{"): if not tool_json.startswith("{"):
tool_json = "{" + tool_json tool_json = "{" + tool_json
@@ -127,7 +129,7 @@ class ErnieX1ToolParser(ToolParser):
tool_json = tool_json + "}" tool_json = tool_json + "}"
try: try:
# 首先尝试标准JSON解析 # Parsing strategy: First try standard json.loads
try: try:
tool_data = json.loads(tool_json) tool_data = json.loads(tool_json)
@@ -136,26 +138,26 @@ class ErnieX1ToolParser(ToolParser):
{ {
"name": tool_data["name"], "name": tool_data["name"],
"arguments": tool_data["arguments"], "arguments": tool_data["arguments"],
"_is_complete": True, # 明确标记为完整解析 "_is_complete": True, # Mark as complete
} }
) )
continue continue
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
# 标准解析失败时尝试partial_json_parser # Try partial_json_parser when standard parsing fails
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
try: try:
tool_data = {} tool_data = {}
flags = Allow.ALL & ~Allow.STR flags = Allow.ALL & ~Allow.STR
# 解析name字段 # Parse the name field
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json) name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
if name_match: if name_match:
tool_data["name"] = name_match.group(1) tool_data["name"] = name_match.group(1)
# 解析arguments字段 # Parse the arguments field
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json) args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
if args_match: if args_match:
try: try:
@@ -168,7 +170,7 @@ class ErnieX1ToolParser(ToolParser):
{ {
"name": tool_data.get("name", ""), "name": tool_data.get("name", ""),
"arguments": tool_data.get("arguments", {}), "arguments": tool_data.get("arguments", {}),
"_is_partial": True, # 标记为部分解析 "_is_partial": True, # Mark as partial
} }
) )
except Exception as e: except Exception as e:
@@ -183,18 +185,18 @@ class ErnieX1ToolParser(ToolParser):
return ExtractedToolCallInformation(tools_called=False, content=model_output) return ExtractedToolCallInformation(tools_called=False, content=model_output)
tool_calls = [] tool_calls = []
all_complete = True # 初始设为True只要有一个不完整就变为False all_complete = True # Initialize as all complete
for tool_call in function_call_arr: for tool_call in function_call_arr:
# 记录工具调用解析状态 # Set flags
is_complete = tool_call.get("_is_complete", False) is_complete = tool_call.get("_is_complete", False)
is_partial = tool_call.get("_is_partial", False) is_partial = tool_call.get("_is_partial", False)
# 只要有一个不完整就认为整体不完整 # If any tool call is incomplete or partial, mark all_complete as False
if not is_complete or is_partial: if not is_complete or is_partial:
all_complete = False all_complete = False
# 处理参数序列化 # Process arguments
tool_args = tool_call.get("arguments", {}) tool_args = tool_call.get("arguments", {})
if not isinstance(tool_args, dict): if not isinstance(tool_args, dict):
tool_args = {} tool_args = {}
@@ -215,7 +217,7 @@ class ErnieX1ToolParser(ToolParser):
) )
) )
# 只有当所有工具调用都明确标记为complete时才返回tools_called=True # Only return tools_called=True if all tool calls are complete
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content="" tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
) )
@@ -237,16 +239,16 @@ class ErnieX1ToolParser(ToolParser):
if self.tool_call_start_token_id not in current_token_ids: if self.tool_call_start_token_id not in current_token_ids:
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
# 忽略空chunk # Skip empty chunks
if len(delta_text.strip()) == 0: if len(delta_text.strip()) == 0:
return None return None
try: try:
delta = None delta = None
# 使用buffer累积delta_text内容 # Use buffer to accumulate delta_text content
self.buffer += delta_text self.buffer += delta_text
# 处理增量中的新tool_call开始 # Process the buffer content
if "<tool_call>" in delta_text: if "<tool_call>" in delta_text:
self.current_tool_id = ( self.current_tool_id = (
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1 max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
@@ -256,7 +258,7 @@ class ErnieX1ToolParser(ToolParser):
self.streamed_args_for_tool.append("") self.streamed_args_for_tool.append("")
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}") data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
# 1. 尝试解析name字段 # 1. Try to parse the name field
if not self.current_tool_name_sent and '"name"' in self.buffer: if not self.current_tool_name_sent and '"name"' in self.buffer:
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer) name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
if name_match: if name_match:
@@ -272,19 +274,18 @@ class ErnieX1ToolParser(ToolParser):
) )
] ]
) )
# 删除已处理的name部分 # Delete the processed name part from the buffer
self.buffer = self.buffer[name_match.end() :] self.buffer = self.buffer[name_match.end() :]
self.current_tool_name_sent = True self.current_tool_name_sent = True
return delta return delta
# 2. 尝试解析arguments字段 # 2. Processing arguments field
if '"arguments"' in self.buffer: if '"arguments"' in self.buffer:
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer) args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
if args_match: if args_match:
args_content = args_match.group(1) args_content = args_match.group(1)
try: try:
# 检查是否到达arguments结尾(括号完全匹配) # Check if arguments field is complete by bracket matching
if "}}" in args_content: if "}}" in args_content:
# 逐个字符检查括号匹配状态
matched_pos = -1 matched_pos = -1
for i, ch in enumerate(delta_text): for i, ch in enumerate(delta_text):
if ch == "{": if ch == "{":
@@ -292,12 +293,12 @@ class ErnieX1ToolParser(ToolParser):
elif ch == "}": elif ch == "}":
self.bracket_counts["total_r"] += 1 self.bracket_counts["total_r"] += 1
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配 if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:
matched_pos = i matched_pos = i
break break
if matched_pos >= 0: if matched_pos >= 0:
# 找到匹配点清理buffer并返回 # Clean up bracket counts for next tool call
truncate_text = delta_text[: matched_pos + 1] truncate_text = delta_text[: matched_pos + 1]
delta = DeltaMessage( delta = DeltaMessage(
tool_calls=[ tool_calls=[
@@ -312,10 +313,10 @@ class ErnieX1ToolParser(ToolParser):
self.buffer = self.buffer[args_match.end() :] self.buffer = self.buffer[args_match.end() :]
return delta return delta
else: else:
# 没有完全匹配,继续累积 # No complete match yet
return None return None
else: else:
# 增量返回当前可解析的部分 # Return partial arguments
for ch in delta_text: for ch in delta_text:
if ch == "{": if ch == "{":
self.bracket_counts["total_l"] += 1 self.bracket_counts["total_l"] += 1
@@ -337,7 +338,6 @@ class ErnieX1ToolParser(ToolParser):
end_pos = self.buffer.find("</tool_call>") end_pos = self.buffer.find("</tool_call>")
self.buffer = self.buffer[end_pos + len("</tool_call>") :] self.buffer = self.buffer[end_pos + len("</tool_call>") :]
# 完成当前工具调用处理
self.streamed_args_for_tool.append("") self.streamed_args_for_tool.append("")
return delta return delta

View File

@@ -1,36 +1,19 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
#
from collections.abc import Sequence from collections.abc import Sequence
from typing import Tuple, Union from typing import Tuple, Union
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
#
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ReasoningParserManager.register_module("ernie_x1") @ReasoningParserManager.register_module("ernie_x1")
class ErnieX1ReasoningParser(ReasoningParser): class ErnieX1ReasoningParser(ReasoningParser):
""" """
Reasoning parser for ernie_x1 model with stricter boundary checking. Reasoning parser for ernie_x1 model with stricter boundary checking.
This implementation follows the user's proposed approach: Unified rules:
1. For thinking content: waits for \n then checks for </think> tag - Do not strip newline before </think>
2. For response content: checks for <response> tag first, then waits for \n - Do not strip newline after <response>
3. Handles newlines in content more precisely - Do not strip newline before </response>
""" """
def __init__(self, tokenizer): def __init__(self, tokenizer):
@@ -49,9 +32,6 @@ class ErnieX1ReasoningParser(ReasoningParser):
raise RuntimeError("Could not find think end token id in tokenizer vocabulary") raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
self.tool_call_start_token_id = self.vocab.get("<tool_call>") self.tool_call_start_token_id = self.vocab.get("<tool_call>")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.tool_call_start_token_id in input_ids
def extract_reasoning_content_streaming( def extract_reasoning_content_streaming(
self, self,
previous_text: str, previous_text: str,
@@ -61,102 +41,68 @@ class ErnieX1ReasoningParser(ReasoningParser):
current_token_ids: Sequence[int], current_token_ids: Sequence[int],
delta_token_ids: Sequence[int], delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]: ) -> Union[DeltaMessage, None]:
""" # Ignore the single </think> token
根据用户需求实现的流式解析方法:
1. 初始内容都视为思考内容返回delta_text,""
2. 当遇到\n时检查后续是否是</think>
3. 如果直接遇到</think>也结束思考
4. 思考结束后检查是<response>还是<tool_call>
5. 对于<response>内容,处理各种边界条件
"""
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id: if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
return None return None
# 思考阶段处理
# --- Thinking stage handling ---
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text: if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
# 如果遇到\n暂时不返回等待下一个delta_text # If delta is </think>, stop thinking, do not return
if delta_text == "\n": if delta_text.startswith(self.think_end_token):
return None return None
# 如果前一个是\n且当前是</think>,结束思考 # Otherwise, return thinking content (keep \n as-is)
elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
return None
# 如果直接遇到</think>也结束思考
elif delta_text.startswith(self.think_end_token):
return None
# 否则继续返回思考内容
return DeltaMessage(reasoning_content=delta_text) return DeltaMessage(reasoning_content=delta_text)
# 思考结束后检查是tool_call还是response # --- After thinking ends, check tool_call or response ---
remaining_text = previous_text + delta_text remaining_text = previous_text + delta_text
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :] after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
after_think = after_think.lstrip("\n") # 跳过think后的换行 after_think = after_think.lstrip("\n")
# 处理tool_call情况 # Handle tool_call case: skip it
if after_think.startswith(self.tool_call_start_token): if after_think.startswith(self.tool_call_start_token):
return None return None
# 处理response情况 # Handle response case
if after_think.startswith(self.response_start_token): if after_think.startswith(self.response_start_token):
# 遇到<response>标签时不立即返回 # Do not return when <response> tag itself appears
if delta_text == self.response_start_token: if delta_text == self.response_start_token:
return None return None
# 遇到<response>后的换行符也不立即返回 # Do not return </response> itself
elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
return None
# 处理回复内容中的换行符
if delta_text == "\n":
return None
# 如果前一个是\n且当前是</response>,结束回复
elif previous_text.endswith("\n") and delta_text == self.response_end_token:
return None
# 如果直接遇到</response>也结束回复
elif delta_text == self.response_end_token: elif delta_text == self.response_end_token:
return None return None
# 其他情况返回实际内容 # Otherwise, return response content (keep \n as-is)
else: else:
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
# 默认情况不返回内容 # Default case: return nothing
return None return None
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]: def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
"""
Batch version of the enhanced parser.
Modified to preserve newlines in both reasoning and response content,
only removing the single newline before closing tags.
"""
reasoning_content = "" reasoning_content = ""
response_content = "" response_content = ""
think_end_pos = model_output.find(self.think_end_token) think_end_pos = model_output.find(self.think_end_token)
if think_end_pos != -1: if think_end_pos != -1:
# Extract thinking content - only remove the last newline before </think>
reasoning_content = model_output[:think_end_pos] reasoning_content = model_output[:think_end_pos]
if think_end_pos > 0 and reasoning_content[-1] == "\n":
reasoning_content = reasoning_content[:-1]
remaining = model_output[think_end_pos + len(self.think_end_token) :] remaining = model_output[think_end_pos + len(self.think_end_token) :]
# Skip newlines after </think> # find <response> or <tool>
remaining = remaining.lstrip("\n") response_pos = remaining.find(self.response_start_token)
tool_pos = remaining.find(self.tool_call_start_token)
# Check for response or tool_call # <response> first
if remaining.startswith(self.response_start_token): if response_pos != -1 and (tool_pos == -1 or response_pos < tool_pos):
response_pos = len(self.response_start_token) # The content after the response_start position
remaining = remaining[response_pos:].lstrip("\n") remaining_response = remaining[response_pos + len(self.response_start_token) :]
response_end_pos = remaining.find(self.response_end_token) response_end_pos = remaining_response.find(self.response_end_token)
if response_end_pos != -1: if response_end_pos != -1:
# Only strip the last newline before </response>, not all response_content = remaining_response[:response_end_pos]
if response_end_pos > 0 and remaining[response_end_pos - 1] == "\n":
response_content = remaining[: response_end_pos - 1]
else:
response_content = remaining[:response_end_pos]
else: else:
# If no </response> found, return the rest as response content response_content = remaining_response
response_content = remaining # The content after the response_start position is tool_call
elif remaining.startswith(self.tool_call_start_token):
pass # No response content
else: else:
# No thinking content found, return the whole input as reasoning
reasoning_content = model_output reasoning_content = model_output
response_content = "" response_content = ""
return reasoning_content, response_content return reasoning_content, response_content

View File

@@ -0,0 +1,141 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import patch
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser import (
ErnieX1ToolParser,
)
class DummyTokenizer:
    """Dummy tokenizer with minimal vocab for testing.

    Provides the two-piece tokenizer interface the tool parser relies on —
    a ``vocab`` mapping plus ``get_vocab()`` — kept consistent with the
    DummyTokenizer defined inside ``TestErnieX1ToolParser.setUp``, which
    the original module-level copy was missing.
    """

    def __init__(self):
        # Only the tool-call delimiter tokens are needed by the parser.
        self.vocab = {"<tool_call>": 1, "</tool_call>": 2}

    def get_vocab(self):
        """Return the vocab dict, matching the real tokenizer API."""
        return self.vocab
class TestErnieX1ToolParser(unittest.TestCase):
    """Tests for ErnieX1ToolParser covering batch and streaming extraction."""

    def setUp(self):
        class DummyTokenizer:
            def __init__(self):
                self.vocab = {"<tool_call>": 1, "</tool_call>": 2}

            def get_vocab(self):
                return self.vocab

        self.tokenizer = DummyTokenizer()
        self.parser = ErnieX1ToolParser(tokenizer=self.tokenizer)
        self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])

    # -- helpers so each test body reads as a single call --------------------
    def _batch(self, text):
        """Run non-streaming extraction on *text* with the dummy request."""
        return self.parser.extract_tool_calls(text, self.dummy_request)

    def _stream(self, previous_text, current_text, delta_text, prev_ids, cur_ids, delta_ids):
        """Run streaming extraction with the dummy request's dump."""
        return self.parser.extract_tool_calls_streaming(
            previous_text, current_text, delta_text, prev_ids, cur_ids, delta_ids, self.dummy_request.model_dump()
        )

    # ---------------- Batch extraction tests ----------------
    def test_extract_tool_calls_complete(self):
        """A closed tool_call holding valid JSON is parsed as complete."""
        res = self._batch('<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>')
        self.assertTrue(res.tools_called)
        self.assertEqual(res.tool_calls[0].function.name, "get_weather")

    def test_extract_tool_calls_partial_arguments(self):
        """Truncated arguments fall back to the partial parser."""
        res = self._batch('<tool_call>{"name": "get_weather", "arguments": {"location": ""</tool_call>')
        self.assertFalse(res.tools_called)
        self.assertEqual(res.tool_calls[0].function.name, "get_weather")

    def test_extract_tool_calls_invalid_response_before_toolcall(self):
        """A <response> block appearing before <tool_call> is rejected."""
        res = self._batch('<response>hello</response><tool_call>{"name": "get_weather", "arguments": {}}</tool_call>')
        self.assertFalse(res.tools_called)
        self.assertIn("<response>", res.content)

    def test_extract_tool_calls_no_toolcall(self):
        """Plain text without tool_call tags produces no tool calls."""
        res = self._batch("no tool call here")
        self.assertFalse(res.tools_called)

    def test_extract_tool_calls_invalid_json(self):
        """Badly formatted JSON exercises the regex fallback parser."""
        res = self._batch('<tool_call>"name": "get_weather", "arguments": {</tool_call>')
        self.assertFalse(res.tools_called)
        self.assertEqual(res.tool_calls[0].function.name, "get_weather")

    def test_extract_tool_calls_exception(self):
        """A crash inside json.loads is caught and reported as no tools."""
        with patch(
            "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.json.loads", side_effect=Exception("boom")
        ):
            res = self._batch('<tool_call>{"name": "get_weather", "arguments": {}}</tool_call>')
            self.assertFalse(res.tools_called)

    # ---------------- Streaming extraction tests ----------------
    def test_streaming_no_toolcall(self):
        """Without <tool_call> the delta passes through as plain content."""
        res = self._stream("", "abc", "abc", [], [], [])
        self.assertIsInstance(res, DeltaMessage)
        self.assertEqual(res.content, "abc")

    def test_streaming_skip_empty_chunk(self):
        """Whitespace-only chunks are skipped entirely."""
        res = self._stream("", "<tool_call>", " ", [], [1], [1])
        self.assertIsNone(res)

    def test_streaming_new_toolcall_and_name(self):
        """A fresh tool_call is detected and its name is streamed first."""
        msg = self._stream("", "<tool_call>", '<tool_call>{"name": "get_weather"', [], [1], [1])
        self.assertIsNotNone(msg)
        self.assertEqual(msg.tool_calls[0].function.name, "get_weather")

    def test_streaming_partial_arguments(self):
        """Unfinished arguments are emitted as incremental deltas."""
        chunk = '"arguments": {"location":'
        msg = self._stream("", "<tool_call>" + chunk, chunk, [], [1], [1])
        self.assertIsInstance(msg, DeltaMessage)
        self.assertIn("arguments", msg.tool_calls[0].function.arguments)

    def test_streaming_complete_arguments_and_end(self):
        """Balanced brackets complete the arguments; </tool_call> closes it."""
        chunk = '"arguments": {"location": "北京"}}'
        msg = self._stream("", "<tool_call>" + chunk, chunk, [], [1], [1])
        self.assertIsInstance(msg, DeltaMessage)
        # Also simulate the closing tag arriving as its own chunk.
        end_msg = self._stream("", "</tool_call>", "</tool_call>", [], [2], [2])
        self.assertIsNotNone(end_msg)
        self.assertEqual(end_msg.content, "</tool_call>")
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,265 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
from fastdeploy.reasoning.ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
class DummyTokenizer:
    """Minimal stand-in tokenizer exposing only a vocab for the tests."""

    def __init__(self):
        # Assign consecutive ids starting at 100 to every special tag the
        # reasoning parser looks up: 100..104 in declaration order.
        special_tags = [
            "</think>",
            "<tool_call>",
            "</tool_call>",
            "<response>",
            "</response>",
        ]
        self.vocab = {tag: idx for idx, tag in enumerate(special_tags, start=100)}

    def get_vocab(self):
        """Expose the vocab dict, mirroring the real tokenizer API."""
        return self.vocab
class TestReasoningParser(ReasoningParser):
    """Trivial ReasoningParser subclass used to exercise the registry."""

    def is_reasoning_end(self, input_ids):
        """Always report that reasoning has ended."""
        return True

    def extract_content_ids(self, input_ids):
        """Pass the token ids through unchanged."""
        return input_ids

    def extract_reasoning_content(self, model_output, request):
        """Echo the model output as both reasoning and content (batch path)."""
        return model_output, model_output

    def extract_reasoning_content_streaming(
        self, previous_text, current_text, delta_text, previous_token_ids, current_token_ids, delta_token_ids
    ):
        """Minimal streaming implementation: emit nothing."""
        return None
class TestReasoningParserManager(unittest.TestCase):
    """Unit tests for ReasoningParserManager registration behaviour."""

    def setUp(self):
        # Snapshot the registry so each test starts from a clean state.
        self.original_parsers = ReasoningParserManager.reasoning_parsers.copy()

    def tearDown(self):
        # Restore the snapshot so tests cannot pollute one another.
        ReasoningParserManager.reasoning_parsers = self.original_parsers.copy()

    def test_register_and_get_parser(self):
        """Registering a parser makes it retrievable by name."""
        ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser", force=True)
        retrieved = ReasoningParserManager.get_reasoning_parser("test_parser")
        self.assertIs(retrieved, TestReasoningParser)

    def test_register_duplicate_without_force_raises(self):
        """Re-registering an existing name without force raises KeyError."""
        ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=True)
        with self.assertRaises(KeyError):
            ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=False)

    def test_register_non_subclass_raises(self):
        """Registering a class that is not a ReasoningParser raises TypeError."""

        class NotParser:
            pass

        with self.assertRaises(TypeError):
            ReasoningParserManager.register_module(module=NotParser, name="not_parser")

    def test_get_unregistered_parser_raises(self):
        """Looking up a name that was never registered raises KeyError."""
        with self.assertRaises(KeyError):
            ReasoningParserManager.get_reasoning_parser("nonexistent_parser")
class TestErnieX1ReasoningParser(unittest.TestCase):
    """Tests for ErnieX1ReasoningParser streaming and batch extraction."""

    def setUp(self):
        self.parser = ErnieX1ReasoningParser(DummyTokenizer())
        self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
        self.tokenizer = DummyTokenizer()

    def _stream(self, previous_text, current_text, delta_text, delta_token_ids):
        """Helper: every streaming test uses empty previous/current id lists."""
        return self.parser.extract_reasoning_content_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=delta_token_ids,
        )

    # ---- Streaming parsing ----
    def test_streaming_thinking_content(self):
        """Content before </think> streams out as reasoning_content."""
        msg = self._stream("", "a", "a", [200])
        self.assertEqual(msg.reasoning_content, "a")

    def test_streaming_thinking_newline_preserved(self):
        """A bare newline inside thinking is emitted, not swallowed."""
        msg = self._stream("abc", "abc\n", "\n", [201])
        self.assertEqual(msg.reasoning_content, "\n")

    def test_streaming_thinking_end_tag(self):
        """The </think> tag itself produces no delta."""
        msg = self._stream("abc", "abc</think>", "</think>", [self.parser.think_end_token_id])
        self.assertIsNone(msg)

    def test_streaming_response_content(self):
        """Content inside <response> streams out as plain content."""
        msg = self._stream("</think><response>", "</think><response>h", "h", [202])
        self.assertEqual(msg.content, "h")

    def test_streaming_response_newline_preserved(self):
        """A newline inside the response is emitted, not swallowed."""
        msg = self._stream("</think><response>hi", "</think><response>hi\n", "\n", [203])
        self.assertEqual(msg.content, "\n")

    def test_streaming_response_ignore_tags(self):
        """<response>/</response> tags are suppressed but newlines are kept."""
        self.assertIsNone(
            self._stream("</think>", "</think><response>", "<response>", [self.parser.vocab["<response>"]])
        )
        msg = self._stream("</think><response>", "</think><response>\n", "\n", [204])
        self.assertIsInstance(msg, DeltaMessage)
        self.assertEqual(msg.content, "\n")
        self.assertIsNone(
            self._stream(
                "</think><response>\n",
                "</think><response>\n</response>",
                "</response>",
                [self.parser.vocab["</response>"]],
            )
        )

    def test_streaming_tool_call(self):
        """A <tool_call> after thinking is skipped by the reasoning parser."""
        msg = self._stream("</think>", "</think><tool_call>", "<tool_call>", [self.parser.vocab["<tool_call>"]])
        self.assertIsNone(msg)

    # ---- Batch parsing ----
    def test_batch_reasoning_and_response(self):
        """Both reasoning and response are extracted; newlines preserved."""
        reasoning, response = self.parser.extract_reasoning_content(
            "abc\n</think>\n<response>hello\nworld</response>", self.request
        )
        self.assertEqual(reasoning, "abc\n")
        self.assertEqual(response, "hello\nworld")

    def test_batch_reasoning_and_tool_call(self):
        """A tool_call after thinking yields empty response content."""
        reasoning, response = self.parser.extract_reasoning_content("abc</think><tool_call>call_here", self.request)
        self.assertEqual(reasoning, "abc")
        self.assertEqual(response, "")

    def test_batch_no_thinking_tag(self):
        """Without </think> the whole output counts as reasoning."""
        reasoning, response = self.parser.extract_reasoning_content("no_thinking_here", self.request)
        self.assertEqual(reasoning, "no_thinking_here")
        self.assertEqual(response, "")

    def test_batch_response_without_end_tag(self):
        """A response missing </response> is returned as-is."""
        reasoning, response = self.parser.extract_reasoning_content("abc</think><response>partial response", self.request)
        self.assertEqual(reasoning, "abc")
        self.assertEqual(response, "partial response")

    def test_batch_preserve_all_newlines(self):
        """No newline stripping happens around the special tokens."""
        reasoning, response = self.parser.extract_reasoning_content(
            "abc\n</think>\n<response>line1\nline2\n</response>", self.request
        )
        self.assertEqual(reasoning, "abc\n")
        self.assertEqual(response, "line1\nline2\n")
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()