diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
index 9b0c7b9cb..14a784f17 100644
--- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
+++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
@@ -1,3 +1,4 @@
+"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
@@ -11,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
import json
import re
@@ -97,29 +99,29 @@ class ErnieX1ToolParser(ToolParser):
remaining_text = model_output
while True:
- # 查找下一个tool_call块
+ # Find the next <tool_call> block
tool_call_pos = remaining_text.find("<tool_call>")
if tool_call_pos == -1:
break
- # 提取tool_call开始位置后的内容
+ # Extract content after <tool_call>
tool_content_start = tool_call_pos + len("<tool_call>")
tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
tool_json = ""
if tool_content_end == -1:
- # 处理未闭合的tool_call块(截断情况)
+ # Processing unclosed tool_call block (truncated case)
tool_json = remaining_text[tool_content_start:].strip()
- remaining_text = "" # 没有更多内容需要处理
+ remaining_text = "" # No more content to process
else:
- # 处理完整的tool_call块
+ # Processing closed block
tool_json = remaining_text[tool_content_start:tool_content_end].strip()
remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
if not tool_json:
continue
- # 处理JSON内容
+ # Process tool_json
tool_json = tool_json.strip()
if not tool_json.startswith("{"):
tool_json = "{" + tool_json
@@ -127,7 +129,7 @@ class ErnieX1ToolParser(ToolParser):
tool_json = tool_json + "}"
try:
- # 首先尝试标准JSON解析
+ # Parsing strategy: First try standard json.loads
try:
tool_data = json.loads(tool_json)
@@ -136,26 +138,26 @@ class ErnieX1ToolParser(ToolParser):
{
"name": tool_data["name"],
"arguments": tool_data["arguments"],
- "_is_complete": True, # 明确标记为完整解析
+ "_is_complete": True, # Mark as complete
}
)
continue
except json.JSONDecodeError:
pass
- # 标准解析失败时尝试partial_json_parser
+ # Try partial_json_parser when standard parsing fails
from partial_json_parser.core.options import Allow
try:
tool_data = {}
flags = Allow.ALL & ~Allow.STR
- # 解析name字段
+ # Parse the name field
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
if name_match:
tool_data["name"] = name_match.group(1)
- # 解析arguments字段
+ # Parse the arguments field
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
if args_match:
try:
@@ -168,7 +170,7 @@ class ErnieX1ToolParser(ToolParser):
{
"name": tool_data.get("name", ""),
"arguments": tool_data.get("arguments", {}),
- "_is_partial": True, # 标记为部分解析
+ "_is_partial": True, # Mark as partial
}
)
except Exception as e:
@@ -183,18 +185,18 @@ class ErnieX1ToolParser(ToolParser):
return ExtractedToolCallInformation(tools_called=False, content=model_output)
tool_calls = []
- all_complete = True # 初始设为True,只要有一个不完整就变为False
+ all_complete = True # Initialize as all complete
for tool_call in function_call_arr:
- # 记录工具调用解析状态
+ # Set flags
is_complete = tool_call.get("_is_complete", False)
is_partial = tool_call.get("_is_partial", False)
- # 只要有一个不完整就认为整体不完整
+ # If any tool call is incomplete or partial, mark all_complete as False
if not is_complete or is_partial:
all_complete = False
- # 处理参数序列化
+ # Process arguments
tool_args = tool_call.get("arguments", {})
if not isinstance(tool_args, dict):
tool_args = {}
@@ -215,7 +217,7 @@ class ErnieX1ToolParser(ToolParser):
)
)
- # 只有当所有工具调用都明确标记为complete时才返回tools_called=True
+ # Only return tools_called=True if all tool calls are complete
return ExtractedToolCallInformation(
tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
)
@@ -237,16 +239,16 @@ class ErnieX1ToolParser(ToolParser):
if self.tool_call_start_token_id not in current_token_ids:
return DeltaMessage(content=delta_text)
- # 忽略空chunk
+ # Skip empty chunks
if len(delta_text.strip()) == 0:
return None
try:
delta = None
- # 使用buffer累积delta_text内容
+ # Use buffer to accumulate delta_text content
self.buffer += delta_text
- # 处理增量中的新tool_call开始
+ # Process the buffer content
if "" in delta_text:
self.current_tool_id = (
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
@@ -256,7 +258,7 @@ class ErnieX1ToolParser(ToolParser):
self.streamed_args_for_tool.append("")
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
- # 1. 尝试解析name字段
+ # 1. Try to parse the name field
if not self.current_tool_name_sent and '"name"' in self.buffer:
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
if name_match:
@@ -272,19 +274,18 @@ class ErnieX1ToolParser(ToolParser):
)
]
)
- # 删除已处理的name部分
+ # Delete the processed name part from the buffer
self.buffer = self.buffer[name_match.end() :]
self.current_tool_name_sent = True
return delta
- # 2. 尝试解析arguments字段
+ # 2. Processing arguments field
if '"arguments"' in self.buffer:
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
if args_match:
args_content = args_match.group(1)
try:
- # 检查是否到达arguments结尾(括号完全匹配)
+ # Check if arguments field is complete by bracket matching
if "}}" in args_content:
- # 逐个字符检查括号匹配状态
matched_pos = -1
for i, ch in enumerate(delta_text):
if ch == "{":
@@ -292,12 +293,12 @@ class ErnieX1ToolParser(ToolParser):
elif ch == "}":
self.bracket_counts["total_r"] += 1
- if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配
+ if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:
matched_pos = i
break
if matched_pos >= 0:
- # 找到匹配点,清理buffer并返回
+ # Clean up bracket counts for next tool call
truncate_text = delta_text[: matched_pos + 1]
delta = DeltaMessage(
tool_calls=[
@@ -312,10 +313,10 @@ class ErnieX1ToolParser(ToolParser):
self.buffer = self.buffer[args_match.end() :]
return delta
else:
- # 没有完全匹配,继续累积
+ # No complete match yet
return None
else:
- # 增量返回当前可解析的部分
+ # Return partial arguments
for ch in delta_text:
if ch == "{":
self.bracket_counts["total_l"] += 1
@@ -337,7 +338,6 @@ class ErnieX1ToolParser(ToolParser):
end_pos = self.buffer.find("")
self.buffer = self.buffer[end_pos + len("") :]
- # 完成当前工具调用处理
self.streamed_args_for_tool.append("")
return delta
diff --git a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
index c75182b01..3aa23aee9 100644
--- a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
+++ b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
@@ -1,36 +1,19 @@
-# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
-#
from collections.abc import Sequence
from typing import Tuple, Union
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
-#
-#
-# Licensed under the Apache License, Version 2.0 (the "License"
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
@ReasoningParserManager.register_module("ernie_x1")
class ErnieX1ReasoningParser(ReasoningParser):
"""
Reasoning parser for ernie_x1 model with stricter boundary checking.
- This implementation follows the user's proposed approach:
- 1. For thinking content: waits for \n then checks for tag
- 2. For response content: checks for tag first, then waits for \n
- 3. Handles newlines in content more precisely
+ Unified rules:
+ - Do not strip newline before </think>
+ - Do not strip newline after </think>
+ - Do not strip newline before </response>
"""
def __init__(self, tokenizer):
@@ -49,9 +32,6 @@ class ErnieX1ReasoningParser(ReasoningParser):
raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
self.tool_call_start_token_id = self.vocab.get("<tool_call>")
- def is_reasoning_end(self, input_ids: list[int]) -> bool:
- return self.tool_call_start_token_id in input_ids
-
def extract_reasoning_content_streaming(
self,
previous_text: str,
@@ -61,102 +41,68 @@ class ErnieX1ReasoningParser(ReasoningParser):
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
- """
- 根据用户需求实现的流式解析方法:
- 1. 初始内容都视为思考内容,返回delta_text,""
- 2. 当遇到\n时检查后续是否是
- 3. 如果直接遇到也结束思考
- 4. 思考结束后检查是还是
- 5. 对于内容,处理各种边界条件
- """
+ # Ignore the single </think> token
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
return None
- # 思考阶段处理
+
+ # --- Thinking stage handling ---
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
- # 如果遇到\n,暂时不返回,等待下一个delta_text
- if delta_text == "\n":
+ # If delta is </think>, stop thinking, do not return
+ if delta_text.startswith(self.think_end_token):
return None
- # 如果前一个是\n且当前是,结束思考
- elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
- return None
- # 如果直接遇到也结束思考
- elif delta_text.startswith(self.think_end_token):
- return None
- # 否则继续返回思考内容
+ # Otherwise, return thinking content (keep \n as-is)
return DeltaMessage(reasoning_content=delta_text)
- # 思考结束后检查是tool_call还是response
+ # --- After thinking ends, check tool_call or response ---
remaining_text = previous_text + delta_text
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
- after_think = after_think.lstrip("\n") # 跳过think后的换行
+ after_think = after_think.lstrip("\n")
- # 处理tool_call情况
+ # Handle tool_call case: skip it
if after_think.startswith(self.tool_call_start_token):
return None
- # 处理response情况
+ # Handle response case
if after_think.startswith(self.response_start_token):
- # 遇到标签时不立即返回
+ # Do not return when <response> tag itself appears
if delta_text == self.response_start_token:
return None
- # 遇到后的换行符也不立即返回
- elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
- return None
- # 处理回复内容中的换行符
- if delta_text == "\n":
- return None
- # 如果前一个是\n且当前是,结束回复
- elif previous_text.endswith("\n") and delta_text == self.response_end_token:
- return None
- # 如果直接遇到也结束回复
+ # Do not return </response> itself
elif delta_text == self.response_end_token:
return None
- # 其他情况返回实际内容
+ # Otherwise, return response content (keep \n as-is)
else:
return DeltaMessage(content=delta_text)
- # 默认情况不返回内容
+ # Default case: return nothing
return None
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
- """
- Batch version of the enhanced parser.
- Modified to preserve newlines in both reasoning and response content,
- only removing the single newline before closing tags.
- """
reasoning_content = ""
response_content = ""
think_end_pos = model_output.find(self.think_end_token)
if think_end_pos != -1:
- # Extract thinking content - only remove the last newline before
reasoning_content = model_output[:think_end_pos]
- if think_end_pos > 0 and reasoning_content[-1] == "\n":
- reasoning_content = reasoning_content[:-1]
remaining = model_output[think_end_pos + len(self.think_end_token) :]
- # Skip newlines after
- remaining = remaining.lstrip("\n")
+ # find <response> or <tool_call>
+ response_pos = remaining.find(self.response_start_token)
+ tool_pos = remaining.find(self.tool_call_start_token)
- # Check for response or tool_call
- if remaining.startswith(self.response_start_token):
- response_pos = len(self.response_start_token)
- remaining = remaining[response_pos:].lstrip("\n")
- response_end_pos = remaining.find(self.response_end_token)
+ # <response> first
+ if response_pos != -1 and (tool_pos == -1 or response_pos < tool_pos):
+ # The content after the response_start position
+ remaining_response = remaining[response_pos + len(self.response_start_token) :]
+ response_end_pos = remaining_response.find(self.response_end_token)
if response_end_pos != -1:
- # Only strip the last newline before , not all
- if response_end_pos > 0 and remaining[response_end_pos - 1] == "\n":
- response_content = remaining[: response_end_pos - 1]
- else:
- response_content = remaining[:response_end_pos]
+ response_content = remaining_response[:response_end_pos]
else:
- # If no found, return the rest as response content
- response_content = remaining
- elif remaining.startswith(self.tool_call_start_token):
- pass # No response content
+ response_content = remaining_response
+ # The content after the response_start position is tool_call
else:
- # No thinking content found, return the whole input as reasoning
reasoning_content = model_output
response_content = ""
+
return reasoning_content, response_content
diff --git a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py
new file mode 100644
index 000000000..e818801d9
--- /dev/null
+++ b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py
@@ -0,0 +1,141 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+from unittest.mock import patch
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser import (
+ ErnieX1ToolParser,
+)
+
+
+class DummyTokenizer:
+ """Dummy tokenizer with minimal vocab for testing"""
+
+ def __init__(self):
+ self.vocab = {"": 1, "": 2}
+
+
+class TestErnieX1ToolParser(unittest.TestCase):
+ def setUp(self):
+ class DummyTokenizer:
+ def __init__(self):
+ self.vocab = {"": 1, "": 2}
+
+ def get_vocab(self):
+ return self.vocab
+
+ self.tokenizer = DummyTokenizer()
+ self.parser = ErnieX1ToolParser(tokenizer=self.tokenizer)
+ self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
+
+ # ---------------- Batch extraction tests ----------------
+
+ def test_extract_tool_calls_complete(self):
+ """Test normal extraction of complete tool_call JSON"""
+ output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertTrue(result.tools_called)
+ self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+
+ def test_extract_tool_calls_partial_arguments(self):
+ """Test partial extraction when arguments incomplete"""
+ output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北"'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+ self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+
+ def test_extract_tool_calls_invalid_response_before_toolcall(self):
+ """Test case where before is invalid"""
+ output = 'hello{"name": "get_weather", "arguments": {}}'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+ self.assertIn("", result.content)
+
+ def test_extract_tool_calls_no_toolcall(self):
+ """Test when no tool_call tags are present"""
+ output = "no tool call here"
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+
+ def test_extract_tool_calls_invalid_json(self):
+ """Test tool_call with badly formatted JSON triggers fallback parser"""
+ output = '<tool_call>"name": "get_weather", "arguments": {'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+ self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+
+ def test_extract_tool_calls_exception(self):
+ """Force exception to cover error branch"""
+ with patch(
+ "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.json.loads", side_effect=Exception("boom")
+ ):
+ output = '<tool_call>{"name": "get_weather", "arguments": {}}</tool_call>'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+
+ # ---------------- Streaming extraction tests ----------------
+
+ def test_streaming_no_toolcall(self):
+ """Streaming extraction returns normal DeltaMessage when no """
+ result = self.parser.extract_tool_calls_streaming(
+ "", "abc", "abc", [], [], [], self.dummy_request.model_dump()
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertEqual(result.content, "abc")
+
+ def test_streaming_skip_empty_chunk(self):
+ """Streaming extraction skips empty chunks"""
+ result = self.parser.extract_tool_calls_streaming(
+ "", "", " ", [], [1], [1], self.dummy_request.model_dump()
+ )
+ self.assertIsNone(result)
+
+ def test_streaming_new_toolcall_and_name(self):
+ """Streaming extraction detects new toolcall and extracts name"""
+ delta = self.parser.extract_tool_calls_streaming(
+ "", "", '{"name": "get_weather"', [], [1], [1], self.dummy_request.model_dump()
+ )
+ self.assertIsNotNone(delta)
+ self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
+
+ def test_streaming_partial_arguments(self):
+ """Streaming extraction yields partial arguments deltas"""
+ text = '"arguments": {"location":'
+ delta = self.parser.extract_tool_calls_streaming(
+ "", "" + text, text, [], [1], [1], self.dummy_request.model_dump()
+ )
+ self.assertIsInstance(delta, DeltaMessage)
+ self.assertIn("arguments", delta.tool_calls[0].function.arguments)
+
+ def test_streaming_complete_arguments_and_end(self):
+ """Streaming extraction completes arguments with brackets matched and closes tool_call"""
+ text = '"arguments": {"location": "北京"}}'
+ delta = self.parser.extract_tool_calls_streaming(
+ "", "" + text, text, [], [1], [1], self.dummy_request.model_dump()
+ )
+ self.assertIsInstance(delta, DeltaMessage)
+ # Also simulate closing tag
+ end_delta = self.parser.extract_tool_calls_streaming(
+ "", "", "", [], [2], [2], self.dummy_request.model_dump()
+ )
+ self.assertIsNotNone(end_delta)
+ self.assertEqual(end_delta.content, "")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/reasoning/test_reasoning_parser.py b/tests/reasoning/test_reasoning_parser.py
new file mode 100644
index 000000000..90a48c899
--- /dev/null
+++ b/tests/reasoning/test_reasoning_parser.py
@@ -0,0 +1,265 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
+from fastdeploy.reasoning.ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
+
+
+class DummyTokenizer:
+ """Minimal tokenizer with vocab for testing."""
+
+ def __init__(self):
+ self.vocab = {
+ "<think>": 100,
+ "</think>": 101,
+ "<response>": 102,
+ "</response>": 103,
+ "<tool_call>": 104,
+ }
+
+ def get_vocab(self):
+ """Return vocab dict for testing."""
+ return self.vocab
+
+
+class TestReasoningParser(ReasoningParser):
+ def is_reasoning_end(self, input_ids):
+ """
+ Return True to simulate end of reasoning content.
+ """
+ return True
+
+ def extract_content_ids(self, input_ids):
+ """
+ Return input_ids directly for testing.
+ """
+ return input_ids
+
+ def extract_reasoning_content(self, model_output, request):
+ """
+ Used for testing non-streaming extraction.
+ """
+ return model_output, model_output
+
+ def extract_reasoning_content_streaming(
+ self, previous_text, current_text, delta_text, previous_token_ids, current_token_ids, delta_token_ids
+ ):
+ """
+ Return None for streaming extraction; minimal implementation for testing.
+ """
+ return None
+
+
+class TestReasoningParserManager(unittest.TestCase):
+ """
+ Unit tests for ReasoningParserManager functionality.
+ """
+
+ def setUp(self):
+ """
+ Save original registry to restore after each test.
+ """
+ self.original_parsers = ReasoningParserManager.reasoning_parsers.copy()
+
+ def tearDown(self):
+ """
+ Restore original registry to avoid test pollution.
+ """
+ ReasoningParserManager.reasoning_parsers = self.original_parsers.copy()
+
+ def test_register_and_get_parser(self):
+ """
+ Test that a parser can be registered and retrieved successfully.
+ Verifies normal registration and retrieval functionality.
+ """
+ ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser", force=True)
+ parser_cls = ReasoningParserManager.get_reasoning_parser("test_parser")
+ self.assertIs(parser_cls, TestReasoningParser)
+
+ def test_register_duplicate_without_force_raises(self):
+ """
+ Test that registering a parser with an existing name without force raises KeyError.
+ Ensures duplicate registrations are handled correctly.
+ """
+ ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=True)
+ with self.assertRaises(KeyError):
+ ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=False)
+
+ def test_register_non_subclass_raises(self):
+ """
+ Test that registering a class not inheriting from ReasoningParser raises TypeError.
+ Ensures type safety for registered modules.
+ """
+
+ class NotParser:
+ pass
+
+ with self.assertRaises(TypeError):
+ ReasoningParserManager.register_module(module=NotParser, name="not_parser")
+
+ def test_get_unregistered_parser_raises(self):
+ """
+ Test that retrieving a parser that was not registered raises KeyError.
+ Ensures get_reasoning_parser handles unknown names correctly.
+ """
+ with self.assertRaises(KeyError):
+ ReasoningParserManager.get_reasoning_parser("nonexistent_parser")
+
+
+class TestErnieX1ReasoningParser(unittest.TestCase):
+ def setUp(self):
+ self.parser = ErnieX1ReasoningParser(DummyTokenizer())
+ self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
+ self.tokenizer = DummyTokenizer()
+
+ # ---- Streaming parsing ----
+ def test_streaming_thinking_content(self):
+ msg = self.parser.extract_reasoning_content_streaming(
+ previous_text="",
+ current_text="a",
+ delta_text="a",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[200],
+ )
+ self.assertEqual(msg.reasoning_content, "a")
+
+ def test_streaming_thinking_newline_preserved(self):
+ msg = self.parser.extract_reasoning_content_streaming(
+ previous_text="abc",
+ current_text="abc\n",
+ delta_text="\n",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[201],
+ )
+ self.assertEqual(msg.reasoning_content, "\n")
+
+ def test_streaming_thinking_end_tag(self):
+ msg = self.parser.extract_reasoning_content_streaming(
+ previous_text="abc",
+ current_text="abc",
+ delta_text="",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[self.parser.think_end_token_id],
+ )
+ self.assertIsNone(msg)
+
+ def test_streaming_response_content(self):
+ msg = self.parser.extract_reasoning_content_streaming(
+ previous_text="",
+ current_text="h",
+ delta_text="h",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[202],
+ )
+ self.assertEqual(msg.content, "h")
+
+ def test_streaming_response_newline_preserved(self):
+ msg = self.parser.extract_reasoning_content_streaming(
+ previous_text="hi",
+ current_text="hi\n",
+ delta_text="\n",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[203],
+ )
+ self.assertEqual(msg.content, "\n")
+
+ def test_streaming_response_ignore_tags(self):
+ self.assertIsNone(
+ self.parser.extract_reasoning_content_streaming(
+ previous_text="",
+ current_text="",
+ delta_text="",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[self.parser.vocab[""]],
+ )
+ )
+
+ msg = self.parser.extract_reasoning_content_streaming(
+ previous_text="",
+ current_text="\n",
+ delta_text="\n",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[204],
+ )
+ self.assertIsInstance(msg, DeltaMessage)
+ self.assertEqual(msg.content, "\n")
+
+ self.assertIsNone(
+ self.parser.extract_reasoning_content_streaming(
+ previous_text="\n",
+ current_text="\n",
+ delta_text="",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[self.parser.vocab[""]],
+ )
+ )
+
+ def test_streaming_tool_call(self):
+ msg = self.parser.extract_reasoning_content_streaming(
+ previous_text="",
+ current_text="",
+ delta_text="",
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[self.parser.vocab[""]],
+ )
+ self.assertIsNone(msg)
+
+ # ---- Batch parsing ----
+ def test_batch_reasoning_and_response(self):
+ text = "abc\n\nhello\nworld"
+ reasoning, response = self.parser.extract_reasoning_content(text, self.request)
+ self.assertEqual(reasoning, "abc\n")
+ self.assertEqual(response, "hello\nworld")
+
+ def test_batch_reasoning_and_tool_call(self):
+ text = "abccall_here"
+ reasoning, response = self.parser.extract_reasoning_content(text, self.request)
+ self.assertEqual(reasoning, "abc")
+ self.assertEqual(response, "")
+
+ def test_batch_no_thinking_tag(self):
+ text = "no_thinking_here"
+ reasoning, response = self.parser.extract_reasoning_content(text, self.request)
+ self.assertEqual(reasoning, "no_thinking_here")
+ self.assertEqual(response, "")
+
+ def test_batch_response_without_end_tag(self):
+ text = "abcpartial response"
+ reasoning, response = self.parser.extract_reasoning_content(text, self.request)
+ self.assertEqual(reasoning, "abc")
+ self.assertEqual(response, "partial response")
+
+ def test_batch_preserve_all_newlines(self):
+ text = "abc\n\nline1\nline2\n"
+ reasoning, response = self.parser.extract_reasoning_content(text, self.request)
+ self.assertEqual(reasoning, "abc\n")
+ self.assertEqual(response, "line1\nline2\n")
+
+
+if __name__ == "__main__":
+ unittest.main()