Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-27 04:46:16 +08:00)

[Fix] X1 reasoning parser, skip parsing of \n around special tokens (#4241)
fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py

@@ -1,3 +1,4 @@
+"""
 # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"
@@ -11,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""

 import json
 import re
@@ -97,29 +99,29 @@ class ErnieX1ToolParser(ToolParser):
         remaining_text = model_output

         while True:
-            # 查找下一个tool_call块
+            # Find the next <tool_call>
             tool_call_pos = remaining_text.find("<tool_call>")
             if tool_call_pos == -1:
                 break

-            # 提取tool_call开始位置后的内容
+            # Extract content after <tool_call>
             tool_content_start = tool_call_pos + len("<tool_call>")
             tool_content_end = remaining_text.find("</tool_call>", tool_content_start)

             tool_json = ""
             if tool_content_end == -1:
-                # 处理未闭合的tool_call块(截断情况)
+                # Processing unclosed tool_call block (truncated case)
                 tool_json = remaining_text[tool_content_start:].strip()
-                remaining_text = ""  # 没有更多内容需要处理
+                remaining_text = ""  # No more content to process
             else:
-                # 处理完整的tool_call块
+                # Processing closed </tool_call> block
                 tool_json = remaining_text[tool_content_start:tool_content_end].strip()
                 remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]

             if not tool_json:
                 continue

-            # 处理JSON内容
+            # Process tool_json
             tool_json = tool_json.strip()
             if not tool_json.startswith("{"):
                 tool_json = "{" + tool_json
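The hunk above shows the block scanner working with plain find/slice calls rather than a regex. A minimal standalone sketch of just that scanning step, under the assumption of the same tag strings; the helper name split_tool_call_payloads and the example strings are illustrative, not part of the parser:

def split_tool_call_payloads(model_output: str) -> list:
    """Collect the raw text between <tool_call> ... </tool_call> pairs,
    including a final truncated block that has no closing tag."""
    payloads = []
    remaining_text = model_output
    while True:
        start = remaining_text.find("<tool_call>")
        if start == -1:
            break
        content_start = start + len("<tool_call>")
        content_end = remaining_text.find("</tool_call>", content_start)
        if content_end == -1:
            # Truncated block: keep whatever followed the opening tag.
            payloads.append(remaining_text[content_start:].strip())
            break
        payloads.append(remaining_text[content_start:content_end].strip())
        remaining_text = remaining_text[content_end + len("</tool_call>") :]
    return payloads


# One closed block and one truncated block yield two payloads.
print(split_tool_call_payloads(
    '<tool_call>{"name": "a", "arguments": {}}</tool_call><tool_call>{"name": "b", "arguments": {'
))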
@@ -127,7 +129,7 @@ class ErnieX1ToolParser(ToolParser):
                 tool_json = tool_json + "}"

             try:
-                # 首先尝试标准JSON解析
+                # Parsing strategy: First try standard json.loads
                 try:
                     tool_data = json.loads(tool_json)

@@ -136,26 +138,26 @@ class ErnieX1ToolParser(ToolParser):
                         {
                             "name": tool_data["name"],
                             "arguments": tool_data["arguments"],
-                            "_is_complete": True,  # 明确标记为完整解析
+                            "_is_complete": True,  # Mark as complete
                         }
                     )
                     continue
                 except json.JSONDecodeError:
                     pass

-                # 标准解析失败时尝试partial_json_parser
+                # Try partial_json_parser when standard parsing fails
                 from partial_json_parser.core.options import Allow

                 try:
                     tool_data = {}
                     flags = Allow.ALL & ~Allow.STR

-                    # 解析name字段
+                    # Parse the name field
                     name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
                     if name_match:
                         tool_data["name"] = name_match.group(1)

-                    # 解析arguments字段
+                    # Parse the arguments field
                     args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
                     if args_match:
                         try:
@@ -168,7 +170,7 @@ class ErnieX1ToolParser(ToolParser):
                         {
                             "name": tool_data.get("name", ""),
                             "arguments": tool_data.get("arguments", {}),
-                            "_is_partial": True,  # 标记为部分解析
+                            "_is_partial": True,  # Mark as partial
                         }
                     )
                 except Exception as e:
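Read together, these hunks describe a two-stage strategy: strict json.loads first, then a best-effort recovery pass over truncated payloads. A minimal sketch of that idea, assuming only the regex patterns visible in the diff; the real parser additionally runs partial_json_parser over the arguments, which is omitted here, and parse_tool_json / arguments_fragment are illustrative names:

import json
import re


def parse_tool_json(tool_json: str) -> dict:
    # Stage 1: the payload is complete, well-formed JSON.
    try:
        data = json.loads(tool_json)
        return {"name": data["name"], "arguments": data["arguments"], "_is_complete": True}
    except json.JSONDecodeError:
        pass

    # Stage 2: recover what we can from a truncated payload.
    recovered = {"_is_partial": True}
    name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
    if name_match:
        recovered["name"] = name_match.group(1)
    args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
    if args_match:
        recovered["arguments_fragment"] = args_match.group(1)  # raw, possibly unterminated JSON
    return recovered


print(parse_tool_json('{"name": "get_weather", "arguments": {"location": "北'))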
@@ -183,18 +185,18 @@ class ErnieX1ToolParser(ToolParser):
             return ExtractedToolCallInformation(tools_called=False, content=model_output)

         tool_calls = []
-        all_complete = True  # 初始设为True,只要有一个不完整就变为False
+        all_complete = True  # Initialize as all complete

         for tool_call in function_call_arr:
-            # 记录工具调用解析状态
+            # Set flags
             is_complete = tool_call.get("_is_complete", False)
             is_partial = tool_call.get("_is_partial", False)

-            # 只要有一个不完整就认为整体不完整
+            # If any tool call is incomplete or partial, mark all_complete as False
             if not is_complete or is_partial:
                 all_complete = False

-            # 处理参数序列化
+            # Process arguments
             tool_args = tool_call.get("arguments", {})
             if not isinstance(tool_args, dict):
                 tool_args = {}
@@ -215,7 +217,7 @@ class ErnieX1ToolParser(ToolParser):
                 )
             )

-        # 只有当所有工具调用都明确标记为complete时才返回tools_called=True
+        # Only return tools_called=True if all tool calls are complete
         return ExtractedToolCallInformation(
             tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
         )
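The aggregation above boils down to: tools_called is True only when every parsed call carries _is_complete and none carries _is_partial. A compact equivalent of that check (the helper name is illustrative):

def all_calls_complete(function_call_arr: list) -> bool:
    return all(
        call.get("_is_complete", False) and not call.get("_is_partial", False)
        for call in function_call_arr
    )


print(all_calls_complete([{"_is_complete": True}]))                         # True
print(all_calls_complete([{"_is_complete": True}, {"_is_partial": True}]))  # False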
@@ -237,16 +239,16 @@ class ErnieX1ToolParser(ToolParser):

         if self.tool_call_start_token_id not in current_token_ids:
             return DeltaMessage(content=delta_text)
-        # 忽略空chunk
+        # Skip empty chunks
         if len(delta_text.strip()) == 0:
             return None

         try:
             delta = None
-            # 使用buffer累积delta_text内容
+            # Use buffer to accumulate delta_text content
            self.buffer += delta_text

-            # 处理增量中的新tool_call开始
+            # Process the buffer content
             if "<tool_call>" in delta_text:
                 self.current_tool_id = (
                     max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
@@ -256,7 +258,7 @@ class ErnieX1ToolParser(ToolParser):
                 self.streamed_args_for_tool.append("")
                 data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")

-            # 1. 尝试解析name字段
+            # 1. Try to parse the name field
             if not self.current_tool_name_sent and '"name"' in self.buffer:
                 name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
                 if name_match:
@@ -272,19 +274,18 @@ class ErnieX1ToolParser(ToolParser):
                             )
                         ]
                     )
-                    # 删除已处理的name部分
+                    # Delete the processed name part from the buffer
                     self.buffer = self.buffer[name_match.end() :]
                     self.current_tool_name_sent = True
                     return delta
-            # 2. 尝试解析arguments字段
+            # 2. Processing arguments field
             if '"arguments"' in self.buffer:
                 args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
                 if args_match:
                     args_content = args_match.group(1)
                     try:
-                        # 检查是否到达arguments结尾(括号完全匹配)
+                        # Check if arguments field is complete by bracket matching
                         if "}}" in args_content:
-                            # 逐个字符检查括号匹配状态
                             matched_pos = -1
                             for i, ch in enumerate(delta_text):
                                 if ch == "{":
@@ -292,12 +293,12 @@ class ErnieX1ToolParser(ToolParser):
                                 elif ch == "}":
                                     self.bracket_counts["total_r"] += 1

-                                if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:  # 括号完全匹配
+                                if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:
                                     matched_pos = i
                                     break

                             if matched_pos >= 0:
-                                # 找到匹配点,清理buffer并返回
+                                # Clean up bracket counts for next tool call
                                 truncate_text = delta_text[: matched_pos + 1]
                                 delta = DeltaMessage(
                                     tool_calls=[
@@ -312,10 +313,10 @@ class ErnieX1ToolParser(ToolParser):
                                 self.buffer = self.buffer[args_match.end() :]
                                 return delta
                             else:
-                                # 没有完全匹配,继续累积
+                                # No complete match yet
                                 return None
                         else:
-                            # 增量返回当前可解析的部分
+                            # Return partial arguments
                             for ch in delta_text:
                                 if ch == "{":
                                     self.bracket_counts["total_l"] += 1
@@ -337,7 +338,6 @@ class ErnieX1ToolParser(ToolParser):
                 end_pos = self.buffer.find("</tool_call>")
                 self.buffer = self.buffer[end_pos + len("</tool_call>") :]

-                # 完成当前工具调用处理
                 self.streamed_args_for_tool.append("")

             return delta
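The streaming path above decides when the arguments object is finished by counting braces across chunks (bracket_counts["total_l"] vs bracket_counts["total_r"]). A standalone sketch of that idea only; the parser's actual control flow around the buffer and </tool_call> handling is more involved, and find_balance_pos is an illustrative helper name:

def find_balance_pos(chunk: str, counts: dict) -> int:
    """Update running brace counts with one streamed chunk and return the
    index where "{" and "}" first balance, or -1 if they are still open."""
    for i, ch in enumerate(chunk):
        if ch == "{":
            counts["total_l"] += 1
        elif ch == "}":
            counts["total_r"] += 1
            if counts["total_l"] == counts["total_r"]:
                return i
    return -1


counts = {"total_l": 0, "total_r": 0}
print(find_balance_pos('{"location": "北京"', counts))  # -1: object still open
print(find_balance_pos("}}", counts))                   # 0: braces balance on the first "}"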
fastdeploy/reasoning/ernie_x1_reasoning_parsers.py

@@ -1,36 +1,19 @@
-# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-#
-#
 from collections.abc import Sequence
 from typing import Tuple, Union

 from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
 from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager

-#
-#
-# Licensed under the Apache License, Version 2.0 (the "License"
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-

 @ReasoningParserManager.register_module("ernie_x1")
 class ErnieX1ReasoningParser(ReasoningParser):
     """
     Reasoning parser for ernie_x1 model with stricter boundary checking.

-    This implementation follows the user's proposed approach:
-    1. For thinking content: waits for \n then checks for </think> tag
-    2. For response content: checks for <response> tag first, then waits for \n
-    3. Handles newlines in content more precisely
+    Unified rules:
+    - Do not strip newline before </think>
+    - Do not strip newline after <response>
+    - Do not strip newline before </response>
     """

     def __init__(self, tokenizer):
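What the new rules mean in practice, condensed from test_batch_preserve_all_newlines in the tests/reasoning/test_reasoning_parser.py file added later in this commit; the DummyTokenizer here mirrors the one defined in that test file:

from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
from fastdeploy.reasoning.ernie_x1_reasoning_parsers import ErnieX1ReasoningParser


class DummyTokenizer:
    def __init__(self):
        self.vocab = {"</think>": 100, "<tool_call>": 101, "</tool_call>": 102, "<response>": 103, "</response>": 104}

    def get_vocab(self):
        return self.vocab


parser = ErnieX1ReasoningParser(DummyTokenizer())
request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])

reasoning, response = parser.extract_reasoning_content(
    "abc\n</think>\n<response>line1\nline2\n</response>", request
)
assert reasoning == "abc\n"           # the newline before </think> is kept
assert response == "line1\nline2\n"   # newlines inside and before </response> are kept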
@@ -49,9 +32,6 @@ class ErnieX1ReasoningParser(ReasoningParser):
             raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
         self.tool_call_start_token_id = self.vocab.get("<tool_call>")

-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        return self.tool_call_start_token_id in input_ids
-
     def extract_reasoning_content_streaming(
         self,
         previous_text: str,
@@ -61,102 +41,68 @@ class ErnieX1ReasoningParser(ReasoningParser):
         current_token_ids: Sequence[int],
         delta_token_ids: Sequence[int],
     ) -> Union[DeltaMessage, None]:
-        """
-        根据用户需求实现的流式解析方法:
-        1. 初始内容都视为思考内容,返回delta_text,""
-        2. 当遇到\n时检查后续是否是</think>
-        3. 如果直接遇到</think>也结束思考
-        4. 思考结束后检查是<response>还是<tool_call>
-        5. 对于<response>内容,处理各种边界条件
-        """
+        # Ignore the single </think> token
         if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
             return None
-        # 思考阶段处理
+        # --- Thinking stage handling ---
         if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
-            # 如果遇到\n,暂时不返回,等待下一个delta_text
+            # If delta is </think>, stop thinking, do not return
-            if delta_text == "\n":
+            if delta_text.startswith(self.think_end_token):
                 return None
-            # 如果前一个是\n且当前是</think>,结束思考
+            # Otherwise, return thinking content (keep \n as-is)
-            elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
-                return None
-            # 如果直接遇到</think>也结束思考
-            elif delta_text.startswith(self.think_end_token):
-                return None
-            # 否则继续返回思考内容
             return DeltaMessage(reasoning_content=delta_text)

-        # 思考结束后检查是tool_call还是response
+        # --- After thinking ends, check tool_call or response ---
         remaining_text = previous_text + delta_text
         after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
-        after_think = after_think.lstrip("\n")  # 跳过think后的换行
+        after_think = after_think.lstrip("\n")

-        # 处理tool_call情况
+        # Handle tool_call case: skip it
         if after_think.startswith(self.tool_call_start_token):
             return None

-        # 处理response情况
+        # Handle response case
         if after_think.startswith(self.response_start_token):
-            # 遇到<response>标签时不立即返回
+            # Do not return when <response> tag itself appears
             if delta_text == self.response_start_token:
                 return None
-            # 遇到<response>后的换行符也不立即返回
+            # Do not return </response> itself
-            elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
-                return None
-            # 处理回复内容中的换行符
-            if delta_text == "\n":
-                return None
-            # 如果前一个是\n且当前是</response>,结束回复
-            elif previous_text.endswith("\n") and delta_text == self.response_end_token:
-                return None
-            # 如果直接遇到</response>也结束回复
             elif delta_text == self.response_end_token:
                 return None
-            # 其他情况返回实际内容
+            # Otherwise, return response content (keep \n as-is)
             else:
                 return DeltaMessage(content=delta_text)

-        # 默认情况不返回内容
+        # Default case: return nothing
         return None

     def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
-        """
-        Batch version of the enhanced parser.
-        Modified to preserve newlines in both reasoning and response content,
-        only removing the single newline before closing tags.
-        """
         reasoning_content = ""
         response_content = ""

         think_end_pos = model_output.find(self.think_end_token)
         if think_end_pos != -1:
-            # Extract thinking content - only remove the last newline before </think>
             reasoning_content = model_output[:think_end_pos]
-            if think_end_pos > 0 and reasoning_content[-1] == "\n":
-                reasoning_content = reasoning_content[:-1]

             remaining = model_output[think_end_pos + len(self.think_end_token) :]

-            # Skip newlines after </think>
+            # find <response> or <tool>
-            remaining = remaining.lstrip("\n")
+            response_pos = remaining.find(self.response_start_token)
+            tool_pos = remaining.find(self.tool_call_start_token)

-            # Check for response or tool_call
+            # <response> first
-            if remaining.startswith(self.response_start_token):
+            if response_pos != -1 and (tool_pos == -1 or response_pos < tool_pos):
-                response_pos = len(self.response_start_token)
+                # The content after the response_start position
-                remaining = remaining[response_pos:].lstrip("\n")
+                remaining_response = remaining[response_pos + len(self.response_start_token) :]
-                response_end_pos = remaining.find(self.response_end_token)
+                response_end_pos = remaining_response.find(self.response_end_token)
                 if response_end_pos != -1:
-                    # Only strip the last newline before </response>, not all
+                    response_content = remaining_response[:response_end_pos]
-                    if response_end_pos > 0 and remaining[response_end_pos - 1] == "\n":
-                        response_content = remaining[: response_end_pos - 1]
-                    else:
-                        response_content = remaining[:response_end_pos]
                 else:
-                    # If no </response> found, return the rest as response content
+                    response_content = remaining_response
-                    response_content = remaining
+                    # The content after the response_start position is tool_call
-            elif remaining.startswith(self.tool_call_start_token):
-                pass  # No response content
         else:
-            # No thinking content found, return the whole input as reasoning
             reasoning_content = model_output
             response_content = ""

         return reasoning_content, response_content
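A standalone sketch of the selection logic introduced above: after </think>, whichever of <response> / <tool_call> occurs first decides how the tail is treated. classify_after_think is an illustrative helper, not part of the parser; the real code only extracts response content and leaves it empty for the tool_call case:

def classify_after_think(remaining: str) -> str:
    response_pos = remaining.find("<response>")
    tool_pos = remaining.find("<tool_call>")
    if response_pos != -1 and (tool_pos == -1 or response_pos < tool_pos):
        return "response"
    if tool_pos != -1:
        return "tool_call"
    return "none"


print(classify_after_think("\n<response>hello</response>"))          # response
print(classify_after_think('<tool_call>{"name": "f"}</tool_call>'))  # tool_call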
New file: unit tests for ErnieX1ToolParser

@@ -0,0 +1,141 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import unittest
from unittest.mock import patch

from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser import (
    ErnieX1ToolParser,
)


class DummyTokenizer:
    """Dummy tokenizer with minimal vocab for testing"""

    def __init__(self):
        self.vocab = {"<tool_call>": 1, "</tool_call>": 2}


class TestErnieX1ToolParser(unittest.TestCase):
    def setUp(self):
        class DummyTokenizer:
            def __init__(self):
                self.vocab = {"<tool_call>": 1, "</tool_call>": 2}

            def get_vocab(self):
                return self.vocab

        self.tokenizer = DummyTokenizer()
        self.parser = ErnieX1ToolParser(tokenizer=self.tokenizer)
        self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])

    # ---------------- Batch extraction tests ----------------

    def test_extract_tool_calls_complete(self):
        """Test normal extraction of complete tool_call JSON"""
        output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
        result = self.parser.extract_tool_calls(output, self.dummy_request)
        self.assertTrue(result.tools_called)
        self.assertEqual(result.tool_calls[0].function.name, "get_weather")

    def test_extract_tool_calls_partial_arguments(self):
        """Test partial extraction when arguments incomplete"""
        output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北"</tool_call>'
        result = self.parser.extract_tool_calls(output, self.dummy_request)
        self.assertFalse(result.tools_called)
        self.assertEqual(result.tool_calls[0].function.name, "get_weather")

    def test_extract_tool_calls_invalid_response_before_toolcall(self):
        """Test case where <response> before <tool_call> is invalid"""
        output = '<response>hello</response><tool_call>{"name": "get_weather", "arguments": {}}</tool_call>'
        result = self.parser.extract_tool_calls(output, self.dummy_request)
        self.assertFalse(result.tools_called)
        self.assertIn("<response>", result.content)

    def test_extract_tool_calls_no_toolcall(self):
        """Test when no tool_call tags are present"""
        output = "no tool call here"
        result = self.parser.extract_tool_calls(output, self.dummy_request)
        self.assertFalse(result.tools_called)

    def test_extract_tool_calls_invalid_json(self):
        """Test tool_call with badly formatted JSON triggers fallback parser"""
        output = '<tool_call>"name": "get_weather", "arguments": {</tool_call>'
        result = self.parser.extract_tool_calls(output, self.dummy_request)
        self.assertFalse(result.tools_called)
        self.assertEqual(result.tool_calls[0].function.name, "get_weather")

    def test_extract_tool_calls_exception(self):
        """Force exception to cover error branch"""
        with patch(
            "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.json.loads", side_effect=Exception("boom")
        ):
            output = '<tool_call>{"name": "get_weather", "arguments": {}}</tool_call>'
            result = self.parser.extract_tool_calls(output, self.dummy_request)
            self.assertFalse(result.tools_called)

    # ---------------- Streaming extraction tests ----------------

    def test_streaming_no_toolcall(self):
        """Streaming extraction returns normal DeltaMessage when no <tool_call>"""
        result = self.parser.extract_tool_calls_streaming(
            "", "abc", "abc", [], [], [], self.dummy_request.model_dump()
        )
        self.assertIsInstance(result, DeltaMessage)
        self.assertEqual(result.content, "abc")

    def test_streaming_skip_empty_chunk(self):
        """Streaming extraction skips empty chunks"""
        result = self.parser.extract_tool_calls_streaming(
            "", "<tool_call>", " ", [], [1], [1], self.dummy_request.model_dump()
        )
        self.assertIsNone(result)

    def test_streaming_new_toolcall_and_name(self):
        """Streaming extraction detects new toolcall and extracts name"""
        delta = self.parser.extract_tool_calls_streaming(
            "", "<tool_call>", '<tool_call>{"name": "get_weather"', [], [1], [1], self.dummy_request.model_dump()
        )
        self.assertIsNotNone(delta)
        self.assertEqual(delta.tool_calls[0].function.name, "get_weather")

    def test_streaming_partial_arguments(self):
        """Streaming extraction yields partial arguments deltas"""
        text = '"arguments": {"location":'
        delta = self.parser.extract_tool_calls_streaming(
            "", "<tool_call>" + text, text, [], [1], [1], self.dummy_request.model_dump()
        )
        self.assertIsInstance(delta, DeltaMessage)
        self.assertIn("arguments", delta.tool_calls[0].function.arguments)

    def test_streaming_complete_arguments_and_end(self):
        """Streaming extraction completes arguments with brackets matched and closes tool_call"""
        text = '"arguments": {"location": "北京"}}'
        delta = self.parser.extract_tool_calls_streaming(
            "", "<tool_call>" + text, text, [], [1], [1], self.dummy_request.model_dump()
        )
        self.assertIsInstance(delta, DeltaMessage)
        # Also simulate closing tag
        end_delta = self.parser.extract_tool_calls_streaming(
            "", "</tool_call>", "</tool_call>", [], [2], [2], self.dummy_request.model_dump()
        )
        self.assertIsNotNone(end_delta)
        self.assertEqual(end_delta.content, "</tool_call>")


if __name__ == "__main__":
    unittest.main()
New file: tests/reasoning/test_reasoning_parser.py (265 lines)
@@ -0,0 +1,265 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import unittest

from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
from fastdeploy.reasoning.ernie_x1_reasoning_parsers import ErnieX1ReasoningParser


class DummyTokenizer:
    """Minimal tokenizer with vocab for testing."""

    def __init__(self):
        self.vocab = {
            "</think>": 100,
            "<tool_call>": 101,
            "</tool_call>": 102,
            "<response>": 103,
            "</response>": 104,
        }

    def get_vocab(self):
        """Return vocab dict for testing."""
        return self.vocab


class TestReasoningParser(ReasoningParser):
    def is_reasoning_end(self, input_ids):
        """
        Return True to simulate end of reasoning content.
        """
        return True

    def extract_content_ids(self, input_ids):
        """
        Return input_ids directly for testing.
        """
        return input_ids

    def extract_reasoning_content(self, model_output, request):
        """
        Used for testing non-streaming extraction.
        """
        return model_output, model_output

    def extract_reasoning_content_streaming(
        self, previous_text, current_text, delta_text, previous_token_ids, current_token_ids, delta_token_ids
    ):
        """
        Return None for streaming extraction; minimal implementation for testing.
        """
        return None


class TestReasoningParserManager(unittest.TestCase):
    """
    Unit tests for ReasoningParserManager functionality.
    """

    def setUp(self):
        """
        Save original registry to restore after each test.
        """
        self.original_parsers = ReasoningParserManager.reasoning_parsers.copy()

    def tearDown(self):
        """
        Restore original registry to avoid test pollution.
        """
        ReasoningParserManager.reasoning_parsers = self.original_parsers.copy()

    def test_register_and_get_parser(self):
        """
        Test that a parser can be registered and retrieved successfully.
        Verifies normal registration and retrieval functionality.
        """
        ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser", force=True)
        parser_cls = ReasoningParserManager.get_reasoning_parser("test_parser")
        self.assertIs(parser_cls, TestReasoningParser)

    def test_register_duplicate_without_force_raises(self):
        """
        Test that registering a parser with an existing name without force raises KeyError.
        Ensures duplicate registrations are handled correctly.
        """
        ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=True)
        with self.assertRaises(KeyError):
            ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=False)

    def test_register_non_subclass_raises(self):
        """
        Test that registering a class not inheriting from ReasoningParser raises TypeError.
        Ensures type safety for registered modules.
        """

        class NotParser:
            pass

        with self.assertRaises(TypeError):
            ReasoningParserManager.register_module(module=NotParser, name="not_parser")

    def test_get_unregistered_parser_raises(self):
        """
        Test that retrieving a parser that was not registered raises KeyError.
        Ensures get_reasoning_parser handles unknown names correctly.
        """
        with self.assertRaises(KeyError):
            ReasoningParserManager.get_reasoning_parser("nonexistent_parser")


class TestErnieX1ReasoningParser(unittest.TestCase):
    def setUp(self):
        self.parser = ErnieX1ReasoningParser(DummyTokenizer())
        self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])
        self.tokenizer = DummyTokenizer()

    # ---- Streaming parsing ----
    def test_streaming_thinking_content(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="a",
            delta_text="a",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[200],
        )
        self.assertEqual(msg.reasoning_content, "a")

    def test_streaming_thinking_newline_preserved(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="abc",
            current_text="abc\n",
            delta_text="\n",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[201],
        )
        self.assertEqual(msg.reasoning_content, "\n")

    def test_streaming_thinking_end_tag(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="abc",
            current_text="abc</think>",
            delta_text="</think>",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[self.parser.think_end_token_id],
        )
        self.assertIsNone(msg)

    def test_streaming_response_content(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="</think><response>",
            current_text="</think><response>h",
            delta_text="h",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[202],
        )
        self.assertEqual(msg.content, "h")

    def test_streaming_response_newline_preserved(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="</think><response>hi",
            current_text="</think><response>hi\n",
            delta_text="\n",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[203],
        )
        self.assertEqual(msg.content, "\n")

    def test_streaming_response_ignore_tags(self):
        self.assertIsNone(
            self.parser.extract_reasoning_content_streaming(
                previous_text="</think>",
                current_text="</think><response>",
                delta_text="<response>",
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[self.parser.vocab["<response>"]],
            )
        )

        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="</think><response>",
            current_text="</think><response>\n",
            delta_text="\n",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[204],
        )
        self.assertIsInstance(msg, DeltaMessage)
        self.assertEqual(msg.content, "\n")

        self.assertIsNone(
            self.parser.extract_reasoning_content_streaming(
                previous_text="</think><response>\n",
                current_text="</think><response>\n</response>",
                delta_text="</response>",
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[self.parser.vocab["</response>"]],
            )
        )

    def test_streaming_tool_call(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="</think>",
            current_text="</think><tool_call>",
            delta_text="<tool_call>",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[self.parser.vocab["<tool_call>"]],
        )
        self.assertIsNone(msg)

    # ---- Batch parsing ----
    def test_batch_reasoning_and_response(self):
        text = "abc\n</think>\n<response>hello\nworld</response>"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request)
        self.assertEqual(reasoning, "abc\n")
        self.assertEqual(response, "hello\nworld")

    def test_batch_reasoning_and_tool_call(self):
        text = "abc</think><tool_call>call_here"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request)
        self.assertEqual(reasoning, "abc")
        self.assertEqual(response, "")

    def test_batch_no_thinking_tag(self):
        text = "no_thinking_here"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request)
        self.assertEqual(reasoning, "no_thinking_here")
        self.assertEqual(response, "")

    def test_batch_response_without_end_tag(self):
        text = "abc</think><response>partial response"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request)
        self.assertEqual(reasoning, "abc")
        self.assertEqual(response, "partial response")

    def test_batch_preserve_all_newlines(self):
        text = "abc\n</think>\n<response>line1\nline2\n</response>"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request)
        self.assertEqual(reasoning, "abc\n")
        self.assertEqual(response, "line1\nline2\n")


if __name__ == "__main__":
    unittest.main()