mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] add a new reasoning parser (#4571)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
* add new reasoning_parser initial commit * add parser file content * add register * ernie_test_reasoning_parser * support <tool_call> token and add tool_parser * add and fix unit tests * modify reasoning_parser * modify reasoning parser and tool parser * modify unit tests * modify reasoning_parser and tool_parser * modify unit tests * fix tool_parser * modify the logic of reasoning_parser and tool_parser * add and modify unit tests * standardize code style * simplify reasoning_parser and tool_parser * modify unit test
This commit is contained in:
@@ -15,10 +15,7 @@
|
||||
"""
|
||||
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .ernie_45_vl_thinking_tool_parser import Ernie45VLThinkingToolParser
|
||||
from .ernie_x1_tool_parser import ErnieX1ToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser",
|
||||
"ToolParserManager",
|
||||
"ErnieX1ToolParser",
|
||||
]
|
||||
__all__ = ["ToolParser", "ToolParserManager", "ErnieX1ToolParser", "Ernie45VLThinkingToolParser"]
|
||||
|
||||
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
|
||||
|
||||
def random_tool_call_id() -> str:
|
||||
"""Generate a random tool call ID"""
|
||||
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
|
||||
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
|
||||
@ToolParserManager.register_module("ernie_45-vl-thinking")
|
||||
class Ernie45VLThinkingToolParser(ToolParser):
|
||||
"""
|
||||
Tool parser for Ernie model version 4.5.1.
|
||||
This parser handles tool calls with newline formats.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
|
||||
self.buffer: str = "" # buffer for accumulating unprocessed streaming content
|
||||
self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
self.valid = None
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if self.tool_call_start_token_id is None:
|
||||
self.tool_call_start_token_id = -1
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolCallParser constructor during construction."
|
||||
)
|
||||
|
||||
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
Supports XML-style formats with newlines:
|
||||
- XML format: <think>\n...\n</think>\n\n\n<tool_call>\n{...}\n</tool_call>\n...
|
||||
|
||||
Handles boundary cases:
|
||||
1. Only name and partial arguments: {"name": "get_weather", "arguments": {"location": "北京"
|
||||
2. Only partial name: {"name": "get_we
|
||||
3. Only name and arguments field without content: {"name": "get_weather", "argume
|
||||
"""
|
||||
|
||||
try:
|
||||
tool_calls = []
|
||||
|
||||
function_call_arr = []
|
||||
remaining_text = model_output
|
||||
|
||||
think_end = remaining_text.find("</think>")
|
||||
think_end = think_end + len("</think>") if think_end != -1 else 0
|
||||
tool_begin = remaining_text.find("<tool_call>")
|
||||
if tool_begin != -1:
|
||||
middle_str = remaining_text[think_end:tool_begin]
|
||||
if len(middle_str.strip("\n")) > 0:
|
||||
return ExtractedToolCallInformation(tools_called=False, content=model_output)
|
||||
|
||||
while True:
|
||||
# Find the next <tool_call>
|
||||
tool_call_pos = remaining_text.find("<tool_call>")
|
||||
if tool_call_pos == -1:
|
||||
break
|
||||
|
||||
# Extract content after <tool_call>
|
||||
tool_content_start = tool_call_pos + len("<tool_call>")
|
||||
tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
|
||||
|
||||
tool_json = ""
|
||||
if tool_content_end == -1:
|
||||
# Processing unclosed tool_call block (truncated case)
|
||||
tool_json = remaining_text[tool_content_start:].strip()
|
||||
remaining_text = "" # No more content to process
|
||||
else:
|
||||
# Processing closed </tool_call> block
|
||||
tool_json = remaining_text[tool_content_start:tool_content_end].strip()
|
||||
remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
|
||||
|
||||
if not tool_json:
|
||||
continue
|
||||
|
||||
# Process tool_json
|
||||
tool_json = tool_json.strip()
|
||||
if not tool_json.startswith("{"):
|
||||
tool_json = "{" + tool_json
|
||||
if not tool_json.endswith("}"):
|
||||
tool_json = tool_json + "}"
|
||||
|
||||
try:
|
||||
# Parsing strategy: First try standard json.loads
|
||||
try:
|
||||
tool_data = json.loads(tool_json)
|
||||
|
||||
if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
|
||||
function_call_arr.append(
|
||||
{
|
||||
"name": tool_data["name"],
|
||||
"arguments": tool_data["arguments"],
|
||||
"_is_complete": True, # Mark as complete
|
||||
}
|
||||
)
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try partial_json_parser when standard parsing fails
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
try:
|
||||
tool_data = {}
|
||||
flags = Allow.ALL & ~Allow.STR
|
||||
|
||||
# Parse the name field
|
||||
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
|
||||
if name_match:
|
||||
tool_data["name"] = name_match.group(1)
|
||||
|
||||
# Parse the arguments field
|
||||
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
|
||||
if args_match:
|
||||
try:
|
||||
tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
|
||||
except:
|
||||
tool_data["arguments"] = None
|
||||
|
||||
if isinstance(tool_data, dict):
|
||||
function_call_arr.append(
|
||||
{
|
||||
"name": tool_data.get("name", ""),
|
||||
"arguments": tool_data.get("arguments", {}),
|
||||
"_is_partial": True, # Mark as partial
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
|
||||
continue
|
||||
except Exception as e:
|
||||
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
|
||||
continue
|
||||
|
||||
if not function_call_arr:
|
||||
data_processor_logger.error("No valid tool calls found")
|
||||
return ExtractedToolCallInformation(tools_called=False, content=model_output)
|
||||
|
||||
tool_calls = []
|
||||
all_complete = True # Initialize as all complete
|
||||
|
||||
for tool_call in function_call_arr:
|
||||
# Set flags
|
||||
is_complete = tool_call.get("_is_complete", False)
|
||||
is_partial = tool_call.get("_is_partial", False)
|
||||
|
||||
# If any tool call is incomplete or partial, mark all_complete as False
|
||||
if not is_complete or is_partial:
|
||||
all_complete = False
|
||||
|
||||
# Process arguments
|
||||
tool_args = tool_call.get("arguments", {})
|
||||
if not isinstance(tool_args, dict):
|
||||
tool_args = {}
|
||||
|
||||
try:
|
||||
args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
|
||||
except:
|
||||
args_str = "{}"
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=FunctionCall(
|
||||
name=tool_call.get("name", ""),
|
||||
arguments=args_str,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Only return tools_called=True if all tool calls are complete
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
|
||||
return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: dict,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.valid is not None and not self.valid:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# Skip empty chunks
|
||||
if len(delta_text.strip()) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
delta = None
|
||||
# Use buffer to accumulate delta_text content
|
||||
self.buffer += delta_text
|
||||
|
||||
# Process the buffer content
|
||||
if "<tool_call>" in delta_text:
|
||||
if self.valid is None:
|
||||
tool_call_begin = current_text.find(self.tool_call_start_token)
|
||||
prefix = current_text[:tool_call_begin]
|
||||
prefix = prefix.strip("\n")
|
||||
if len(prefix) > 0 and not prefix.endswith("</think>"):
|
||||
self.valid = False
|
||||
return DeltaMessage(content=delta_text)
|
||||
self.valid = True
|
||||
self.current_tool_id = (
|
||||
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
|
||||
)
|
||||
self.current_tool_name_sent = False
|
||||
if len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||
self.streamed_args_for_tool.append("")
|
||||
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
|
||||
|
||||
# 1. Try to parse the name field
|
||||
if not self.current_tool_name_sent and '"name"' in self.buffer:
|
||||
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
|
||||
if name_match:
|
||||
name = name_match.group(1)
|
||||
if name:
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(name=name).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
# Delete the processed name part from the buffer
|
||||
self.buffer = self.buffer[name_match.end() :]
|
||||
self.current_tool_name_sent = True
|
||||
return delta
|
||||
# 2. Processing arguments field
|
||||
if '"arguments"' in self.buffer:
|
||||
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
|
||||
if args_match:
|
||||
args_content = args_match.group(1)
|
||||
try:
|
||||
# Check if arguments field is complete by bracket matching
|
||||
if "}}" in args_content:
|
||||
matched_pos = -1
|
||||
for i, ch in enumerate(delta_text):
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
|
||||
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:
|
||||
matched_pos = i
|
||||
break
|
||||
|
||||
if matched_pos >= 0:
|
||||
# Clean up bracket counts for next tool call
|
||||
truncate_text = delta_text[: matched_pos + 1]
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=truncate_text).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
self.buffer = self.buffer[args_match.end() :]
|
||||
return delta
|
||||
else:
|
||||
# No complete match yet
|
||||
return None
|
||||
else:
|
||||
# Return partial arguments
|
||||
for ch in delta_text:
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
return delta
|
||||
except Exception as e:
|
||||
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
|
||||
return None
|
||||
if "</tool_call>" in self.buffer:
|
||||
end_pos = self.buffer.find("</tool_call>")
|
||||
self.buffer = self.buffer[end_pos + len("</tool_call>") :]
|
||||
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
|
||||
return None
|
||||
@@ -17,6 +17,7 @@
|
||||
from fastdeploy.plugins import load_reasoning_parser_plugins
|
||||
|
||||
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
||||
from .ernie_45_vl_thinking_reasoning_parser import Ernie45VLThinkingReasoningParser
|
||||
from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser
|
||||
from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
|
||||
from .qwen3_reasoning_parsers import Qwen3ReasoningParser
|
||||
@@ -27,6 +28,7 @@ __all__ = [
|
||||
"ErnieVLReasoningParser",
|
||||
"Qwen3ReasoningParser",
|
||||
"ErnieX1ReasoningParser",
|
||||
"Ernie45VLThinkingReasoningParser",
|
||||
]
|
||||
|
||||
load_reasoning_parser_plugins()
|
||||
|
||||
138
fastdeploy/reasoning/ernie_45_vl_thinking_reasoning_parser.py
Normal file
138
fastdeploy/reasoning/ernie_45_vl_thinking_reasoning_parser.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("erine-45-vl-thinking")
|
||||
class Ernie45VLThinkingReasoningParser(ReasoningParser):
|
||||
"""
|
||||
Reasoning parser for ernir_vl model.
|
||||
|
||||
The ernie_vl model uses ...</think>... tokens to denote reasoning text
|
||||
within its output. The model provides a strict switch to disable reasoning
|
||||
output via the 'enable_thinking=False' parameter. This parser extracts the
|
||||
reasoning content enclosed by <think> and </think> tokens from the model's
|
||||
output.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.think_end_token = "</think>"
|
||||
self.tool_begin_token = "<tool_call>"
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
|
||||
)
|
||||
|
||||
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||
self.tool_begin_token_id = self.vocab.get(self.tool_begin_token)
|
||||
if self.tool_begin_token_id is None:
|
||||
self.tool_begin_token_id = -1
|
||||
|
||||
if self.think_end_token_id is None:
|
||||
raise RuntimeError("Test reasoning parser could not locate think end tokens in the tokenizer!")
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.think_end_token_id in input_ids
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Extract reasoning content from a delta message.
|
||||
Handles streaming output where previous + delta = current.
|
||||
Uses token IDs for faster processing.
|
||||
For text abc</think>xyz:
|
||||
- 'abc' goes to reasoning_content
|
||||
- 'xyz' goes to content
|
||||
"""
|
||||
if self.think_end_token not in current_text:
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
# Skip single special tokens
|
||||
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
||||
return None
|
||||
if self._is_with_tool(current_text=current_text, current_token_ids=current_token_ids):
|
||||
if self.think_end_token in delta_text:
|
||||
think_begin = delta_text.find(self.think_end_token)
|
||||
reasoning_content = delta_text[:think_begin]
|
||||
return DeltaMessage(reasoning_content=reasoning_content)
|
||||
return None
|
||||
if self.think_end_token in delta_text:
|
||||
reasoning_content, _, content = delta_text.partition(self.think_end_token)
|
||||
striped_content = content.strip("\n")
|
||||
if len(striped_content) == 0:
|
||||
return DeltaMessage(reasoning_content=reasoning_content) if reasoning_content else None
|
||||
return (
|
||||
DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||
if reasoning_content
|
||||
else DeltaMessage(content=content)
|
||||
)
|
||||
think_end = current_text.find(self.think_end_token) + len(self.think_end_token)
|
||||
suffix = current_text[think_end:]
|
||||
striped_suffix = suffix.strip("\n")
|
||||
if len(striped_suffix) == 0:
|
||||
return None
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Extract reasoning content from the model output.
|
||||
|
||||
For text abc</think>xyz:
|
||||
- 'abc' goes to reasoning_content
|
||||
- 'xyz' goes to content
|
||||
|
||||
Returns:
|
||||
tuple[Optional[str], Optional[str]]: reasoning content and content
|
||||
"""
|
||||
|
||||
# Check if the model output contains the </think> tokens.
|
||||
if self.think_end_token not in model_output:
|
||||
return model_output, ""
|
||||
reasoning_content, _, content = model_output.partition(self.think_end_token)
|
||||
if self.tool_begin_token in content:
|
||||
prefix, _, _ = content.partition(self.tool_begin_token)
|
||||
prefix_strip = prefix.lstrip("\n")
|
||||
if len(prefix_strip) > 0:
|
||||
return reasoning_content, content
|
||||
return reasoning_content, ""
|
||||
return reasoning_content, content
|
||||
|
||||
def _is_with_tool(self, current_text: str, current_token_ids: Sequence[int]) -> bool:
|
||||
think_end_index = current_text.find(self.think_end_token)
|
||||
think_end = think_end_index + len(self.think_end_token)
|
||||
middle_str = current_text[think_end:]
|
||||
if self.tool_begin_token_id in current_token_ids:
|
||||
prefix, _, _ = middle_str.partition(self.tool_begin_token)
|
||||
striped_prefix = prefix.strip("\n")
|
||||
if len(striped_prefix) > 0:
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
from fastdeploy.entrypoints.openai.tool_parsers.ernie_45_vl_thinking_tool_parser import (
|
||||
Ernie45VLThinkingToolParser,
|
||||
)
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
"""Dummy tokenizer with minimal vocab for testing"""
|
||||
|
||||
def __init__(self):
|
||||
self.vocab = {"<tool_call>": 1, "</tool_call>": 2}
|
||||
|
||||
|
||||
class TestErnie45VLThinkingToolParser(unittest.TestCase):
|
||||
def setUp(self):
|
||||
class DummyTokenizer:
|
||||
def __init__(self):
|
||||
self.vocab = {"<tool_call>": 1, "</tool_call>": 2}
|
||||
|
||||
def get_vocab(self):
|
||||
return self.vocab
|
||||
|
||||
self.tokenizer = DummyTokenizer()
|
||||
self.parser = Ernie45VLThinkingToolParser(tokenizer=self.tokenizer)
|
||||
self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
|
||||
|
||||
# ---------------- Batch extraction tests ----------------
|
||||
|
||||
def test_extract_tool_calls_complete(self):
|
||||
"""Test normal extraction of complete tool_call JSON"""
|
||||
output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertTrue(result.tools_called)
|
||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_extract_tool_calls_partial_arguments(self):
|
||||
"""Test partial extraction when arguments incomplete"""
|
||||
output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北"</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_extract_tool_calls_no_toolcall(self):
|
||||
"""Test when no tool_call tags are present"""
|
||||
output = "no tool call here"
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
|
||||
def test_extract_tool_calls_invalid_json(self):
|
||||
"""Test tool_call with badly formatted JSON triggers fallback parser"""
|
||||
output = '<tool_call>"name": "get_weather", "arguments": {</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_extract_tool_calls_exception(self):
|
||||
"""Force exception to cover error branch"""
|
||||
with patch(
|
||||
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.json.loads", side_effect=Exception("boom")
|
||||
):
|
||||
output = '<tool_call>{"name": "get_weather", "arguments": {}}</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
|
||||
def test_extract_tool_calls_illegal(self):
|
||||
output = '</think>abc<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
self.assertEqual(
|
||||
result.content,
|
||||
'</think>abc<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>',
|
||||
)
|
||||
output = 'abc<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
|
||||
result = self.parser.extract_tool_calls(output, self.dummy_request)
|
||||
self.assertFalse(result.tools_called)
|
||||
self.assertEqual(
|
||||
result.content, 'abc<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
|
||||
)
|
||||
|
||||
# ---------------- Streaming extraction tests ----------------
|
||||
|
||||
def test_streaming_no_toolcall(self):
|
||||
"""Streaming extraction returns normal DeltaMessage when no <tool_call>"""
|
||||
result = self.parser.extract_tool_calls_streaming(
|
||||
"", "abc", "abc", [], [], [], self.dummy_request.model_dump()
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertIsNone(result.tool_calls)
|
||||
self.assertEqual(result.content, "abc")
|
||||
|
||||
def test_streaming_skip_empty_chunk(self):
|
||||
"""Streaming extraction skips empty chunks"""
|
||||
result = self.parser.extract_tool_calls_streaming(
|
||||
"", "<tool_call>", " ", [], [1], [1], self.dummy_request.model_dump()
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_streaming_new_toolcall_and_name(self):
|
||||
"""Streaming extraction detects new toolcall and extracts name"""
|
||||
delta = self.parser.extract_tool_calls_streaming(
|
||||
"", "<tool_call>", '<tool_call>{"name": "get_weather"', [], [1], [1], self.dummy_request.model_dump()
|
||||
)
|
||||
self.assertIsNotNone(delta)
|
||||
self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
|
||||
|
||||
def test_streaming_partial_arguments(self):
|
||||
"""Streaming extraction yields partial arguments deltas"""
|
||||
text = '"arguments": {"location":'
|
||||
delta = self.parser.extract_tool_calls_streaming(
|
||||
"", "<tool_call>" + text, text, [], [1], [1], self.dummy_request.model_dump()
|
||||
)
|
||||
self.assertIsInstance(delta, DeltaMessage)
|
||||
self.assertIn("arguments", delta.tool_calls[0].function.arguments)
|
||||
|
||||
def test_streaming_complete_arguments_and_end(self):
|
||||
"""Streaming extraction completes arguments with brackets matched and closes tool_call"""
|
||||
text = '"arguments": {"location": "北京"}}'
|
||||
delta = self.parser.extract_tool_calls_streaming(
|
||||
"", "<tool_call>" + text, "<tool_call>" + text, [], [1], [1], self.dummy_request.model_dump()
|
||||
)
|
||||
self.assertIsInstance(delta, DeltaMessage)
|
||||
# Also simulate closing tag
|
||||
end_delta = self.parser.extract_tool_calls_streaming(
|
||||
"<tool_call>" + text,
|
||||
"<tool_call>" + text + "</tool_call>",
|
||||
"</tool_call>",
|
||||
[1],
|
||||
[1, 2],
|
||||
[2],
|
||||
self.dummy_request.model_dump(),
|
||||
)
|
||||
self.assertIsNone(end_delta)
|
||||
|
||||
def test_streaming_no_tool_illegal(self):
|
||||
result = self.parser.extract_tool_calls_streaming(
|
||||
"", "abc<tool_call>", "abc<tool_call>", [], [], [], self.dummy_request.model_dump()
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertIsNone(result.tool_calls)
|
||||
self.assertEqual(result.content, "abc<tool_call>")
|
||||
result = self.parser.extract_tool_calls_streaming(
|
||||
"", "</think>abc<tool_call>", "</think>abc<tool_call>", [], [], [], self.dummy_request.model_dump()
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertIsNone(result.tool_calls)
|
||||
self.assertEqual(result.content, "</think>abc<tool_call>")
|
||||
|
||||
def test_streaming_tool_with_reasoning(self):
|
||||
delta = self.parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'</think><tool_call>{"name": "get_weather"',
|
||||
'</think><tool_call>{"name": "get_weather"',
|
||||
[],
|
||||
[1],
|
||||
[1],
|
||||
self.dummy_request.model_dump(),
|
||||
)
|
||||
self.assertIsNotNone(delta)
|
||||
self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
|
||||
delta = self.parser.extract_tool_calls_streaming(
|
||||
"",
|
||||
'</think>\n\n<tool_call>{"name": "get_weather"',
|
||||
'</think>\n\n<tool_call>{"name": "get_weather"',
|
||||
[],
|
||||
[1],
|
||||
[1],
|
||||
self.dummy_request.model_dump(),
|
||||
)
|
||||
self.assertIsNotNone(delta)
|
||||
self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -18,6 +18,9 @@ import unittest
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from fastdeploy.reasoning.ernie_45_vl_thinking_reasoning_parser import (
|
||||
Ernie45VLThinkingReasoningParser,
|
||||
)
|
||||
from fastdeploy.reasoning.ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
|
||||
|
||||
|
||||
@@ -261,5 +264,166 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
self.assertEqual(response, "line1\nline2\n")
|
||||
|
||||
|
||||
class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tokenizer = DummyTokenizer()
|
||||
self.parser = Ernie45VLThinkingReasoningParser(tokenizer=self.tokenizer)
|
||||
self.test_request = ChatCompletionRequest(
|
||||
model="ernie-test", messages=[{"role": "user", "content": "test prompt"}]
|
||||
)
|
||||
|
||||
def test_streaming_non_reasoning(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="",
|
||||
current_text="a",
|
||||
delta_text="a",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[200],
|
||||
delta_token_ids=[200],
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.reasoning_content, "a")
|
||||
self.assertIsNone(result.content)
|
||||
|
||||
def test_streaming_with_reasoning(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="ab",
|
||||
current_text="ab</think>",
|
||||
delta_text="</think>",
|
||||
previous_token_ids=[200, 201],
|
||||
current_token_ids=[200, 201, 100],
|
||||
delta_token_ids=[100],
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_streaming_with_reasoning_and_content(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="ab",
|
||||
current_text="ab</think>\n\ncd",
|
||||
delta_text="</think>\n\ncd",
|
||||
previous_token_ids=[200, 201],
|
||||
current_token_ids=[200, 201, 100, 300, 400],
|
||||
delta_token_ids=[100, 300, 400],
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertIsNone(result.reasoning_content)
|
||||
self.assertEqual(result.content, "\n\ncd")
|
||||
|
||||
def test_streaming_with_reasoning_new_line(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="abc",
|
||||
current_text="abc</think>\n\n",
|
||||
delta_text="</think>\n\n",
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 100],
|
||||
delta_token_ids=[100],
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_streaming_with_reasoning_and_tool(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="abc",
|
||||
current_text="abc</think>\n\n<tool_call>",
|
||||
delta_text="</think>\n\n<tool_call>",
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 100, 200, 101],
|
||||
delta_token_ids=[100, 200, 101],
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.reasoning_content, "")
|
||||
|
||||
def test_streaming_with_reasoning_and_illegal_tool(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="abc</think>",
|
||||
current_text="abc</think>\n\nhello<tool_call>",
|
||||
delta_text="\n\nhello<tool_call>",
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 100, 200, 101],
|
||||
delta_token_ids=[109, 200, 101],
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.content, "\n\nhello<tool_call>")
|
||||
|
||||
def test_streaming_with_reasoning_no_tool(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="abc",
|
||||
current_text="abchello</think>\nworld",
|
||||
delta_text="hello</think>\nworld",
|
||||
previous_token_ids=[200, 201, 202],
|
||||
current_token_ids=[200, 201, 202, 100, 200, 110],
|
||||
delta_token_ids=[100, 200, 110],
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.reasoning_content, "hello")
|
||||
self.assertEqual(result.content, "\nworld")
|
||||
|
||||
def test_streaming_reasoning_previous_no_tool(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="</think>",
|
||||
current_text="</think>\nhello",
|
||||
delta_text="\nhello",
|
||||
previous_token_ids=[100],
|
||||
current_token_ids=[100, 110, 111],
|
||||
delta_token_ids=[110, 111],
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertIsNone(result.reasoning_content)
|
||||
self.assertEqual(result.content, "\nhello")
|
||||
|
||||
def test_streaming_no_reasoning_previous_tool(self):
|
||||
result = self.parser.extract_reasoning_content_streaming(
|
||||
previous_text="<tool_call>",
|
||||
current_text="<tool_call>hello",
|
||||
delta_text="hello",
|
||||
previous_token_ids=[101],
|
||||
current_token_ids=[101, 110],
|
||||
delta_token_ids=[110],
|
||||
)
|
||||
self.assertIsInstance(result, DeltaMessage)
|
||||
self.assertEqual(result.reasoning_content, "hello")
|
||||
|
||||
def test_batch_no_think_end(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="direct response", request=self.test_request
|
||||
)
|
||||
self.assertEqual(reasoning, "direct response")
|
||||
self.assertEqual(content, "")
|
||||
|
||||
def test_batch_no_think_end_with_tool(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="direct response<tool_call>abc", request=self.test_request
|
||||
)
|
||||
self.assertEqual(reasoning, "direct response<tool_call>abc")
|
||||
self.assertEqual(content, "")
|
||||
|
||||
def test_batch_think_end_normal_content(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="reasoning</think>\nresponse", request=self.test_request
|
||||
)
|
||||
self.assertEqual(reasoning, "reasoning")
|
||||
self.assertEqual(content, "\nresponse")
|
||||
|
||||
def test_batch_think_end_with_tool(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="reasoning</think>\n<tool_call>tool params</tool_call>", request=self.test_request
|
||||
)
|
||||
self.assertEqual(reasoning, "reasoning")
|
||||
self.assertEqual(content, "")
|
||||
|
||||
def test_batch_think_end_with_illegal_tool(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="reasoning</think>\nABC\n<tool_call>tool params</tool_call>", request=self.test_request
|
||||
)
|
||||
self.assertEqual(reasoning, "reasoning")
|
||||
self.assertEqual(content, "\nABC\n<tool_call>tool params</tool_call>")
|
||||
|
||||
def test_batch_think_end_content_with_newline(self):
|
||||
reasoning, content = self.parser.extract_reasoning_content(
|
||||
model_output="reasoning</think>\n\n actual response", request=self.test_request
|
||||
)
|
||||
self.assertEqual(reasoning, "reasoning")
|
||||
self.assertEqual(content, "\n\n actual response")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user