[Feature] add a new reasoning parser (#4571)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* add new reasoning_parser initial commit

* add parser file content

* add register

* ernie_test_reasoning_parser

* support <tool_call> token and add tool_parser

* add and fix unit tests

* modify reasoning_parser

* modify reasoning parser and tool parser

* modify unit tests

* modify reasoning_parser and tool_parser

* modify unit tests

* fix tool_parser

* modify the logic of reasoning_parser and tool_parser

* add and modify unit tests

* standardize code style

* simplify reasoning_parser and tool_parser

* modify unit test
This commit is contained in:
kxz2002
2025-10-29 18:16:50 +08:00
committed by GitHub
parent 19df1aec2b
commit c30bfb294f
6 changed files with 860 additions and 5 deletions

View File

@@ -15,10 +15,7 @@
"""
from .abstract_tool_parser import ToolParser, ToolParserManager
from .ernie_45_vl_thinking_tool_parser import Ernie45VLThinkingToolParser
from .ernie_x1_tool_parser import ErnieX1ToolParser
__all__ = [
"ToolParser",
"ToolParserManager",
"ErnieX1ToolParser",
]
__all__ = ["ToolParser", "ToolParserManager", "ErnieX1ToolParser", "Ernie45VLThinkingToolParser"]

View File

@@ -0,0 +1,361 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import re
import uuid
from collections.abc import Sequence
from typing import Union
import partial_json_parser
def random_tool_call_id() -> str:
"""Generate a random tool call ID"""
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from fastdeploy.utils import data_processor_logger
@ToolParserManager.register_module("ernie_45-vl-thinking")
class Ernie45VLThinkingToolParser(ToolParser):
"""
Tool parser for Ernie model version 4.5.1.
This parser handles tool calls with newline formats.
"""
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
self.buffer: str = "" # buffer for accumulating unprocessed streaming content
self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.valid = None
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None:
self.tool_call_start_token_id = -1
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolCallParser constructor during construction."
)
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
Supports XML-style formats with newlines:
- XML format: <think>\n...\n</think>\n\n\n<tool_call>\n{...}\n</tool_call>\n...
Handles boundary cases:
1. Only name and partial arguments: {"name": "get_weather", "arguments": {"location": "北京"
2. Only partial name: {"name": "get_we
3. Only name and arguments field without content: {"name": "get_weather", "argume
"""
try:
tool_calls = []
function_call_arr = []
remaining_text = model_output
think_end = remaining_text.find("</think>")
think_end = think_end + len("</think>") if think_end != -1 else 0
tool_begin = remaining_text.find("<tool_call>")
if tool_begin != -1:
middle_str = remaining_text[think_end:tool_begin]
if len(middle_str.strip("\n")) > 0:
return ExtractedToolCallInformation(tools_called=False, content=model_output)
while True:
# Find the next <tool_call>
tool_call_pos = remaining_text.find("<tool_call>")
if tool_call_pos == -1:
break
# Extract content after <tool_call>
tool_content_start = tool_call_pos + len("<tool_call>")
tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
tool_json = ""
if tool_content_end == -1:
# Processing unclosed tool_call block (truncated case)
tool_json = remaining_text[tool_content_start:].strip()
remaining_text = "" # No more content to process
else:
# Processing closed </tool_call> block
tool_json = remaining_text[tool_content_start:tool_content_end].strip()
remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
if not tool_json:
continue
# Process tool_json
tool_json = tool_json.strip()
if not tool_json.startswith("{"):
tool_json = "{" + tool_json
if not tool_json.endswith("}"):
tool_json = tool_json + "}"
try:
# Parsing strategy: First try standard json.loads
try:
tool_data = json.loads(tool_json)
if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
function_call_arr.append(
{
"name": tool_data["name"],
"arguments": tool_data["arguments"],
"_is_complete": True, # Mark as complete
}
)
continue
except json.JSONDecodeError:
pass
# Try partial_json_parser when standard parsing fails
from partial_json_parser.core.options import Allow
try:
tool_data = {}
flags = Allow.ALL & ~Allow.STR
# Parse the name field
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
if name_match:
tool_data["name"] = name_match.group(1)
# Parse the arguments field
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
if args_match:
try:
tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
except:
tool_data["arguments"] = None
if isinstance(tool_data, dict):
function_call_arr.append(
{
"name": tool_data.get("name", ""),
"arguments": tool_data.get("arguments", {}),
"_is_partial": True, # Mark as partial
}
)
except Exception as e:
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
continue
except Exception as e:
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
continue
if not function_call_arr:
data_processor_logger.error("No valid tool calls found")
return ExtractedToolCallInformation(tools_called=False, content=model_output)
tool_calls = []
all_complete = True # Initialize as all complete
for tool_call in function_call_arr:
# Set flags
is_complete = tool_call.get("_is_complete", False)
is_partial = tool_call.get("_is_partial", False)
# If any tool call is incomplete or partial, mark all_complete as False
if not is_complete or is_partial:
all_complete = False
# Process arguments
tool_args = tool_call.get("arguments", {})
if not isinstance(tool_args, dict):
tool_args = {}
try:
args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
except:
args_str = "{}"
tool_calls.append(
ToolCall(
type="function",
id=random_tool_call_id(),
function=FunctionCall(
name=tool_call.get("name", ""),
arguments=args_str,
),
)
)
# Only return tools_called=True if all tool calls are complete
return ExtractedToolCallInformation(
tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
)
except Exception as e:
data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: dict,
) -> Union[DeltaMessage, None]:
if self.tool_call_start_token_id not in current_token_ids:
return DeltaMessage(content=delta_text)
if self.valid is not None and not self.valid:
return DeltaMessage(content=delta_text)
# Skip empty chunks
if len(delta_text.strip()) == 0:
return None
try:
delta = None
# Use buffer to accumulate delta_text content
self.buffer += delta_text
# Process the buffer content
if "<tool_call>" in delta_text:
if self.valid is None:
tool_call_begin = current_text.find(self.tool_call_start_token)
prefix = current_text[:tool_call_begin]
prefix = prefix.strip("\n")
if len(prefix) > 0 and not prefix.endswith("</think>"):
self.valid = False
return DeltaMessage(content=delta_text)
self.valid = True
self.current_tool_id = (
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
)
self.current_tool_name_sent = False
if len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
# 1. Try to parse the name field
if not self.current_tool_name_sent and '"name"' in self.buffer:
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
if name_match:
name = name_match.group(1)
if name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
function=DeltaFunctionCall(name=name).model_dump(exclude_none=True),
)
]
)
# Delete the processed name part from the buffer
self.buffer = self.buffer[name_match.end() :]
self.current_tool_name_sent = True
return delta
# 2. Processing arguments field
if '"arguments"' in self.buffer:
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
if args_match:
args_content = args_match.group(1)
try:
# Check if arguments field is complete by bracket matching
if "}}" in args_content:
matched_pos = -1
for i, ch in enumerate(delta_text):
if ch == "{":
self.bracket_counts["total_l"] += 1
elif ch == "}":
self.bracket_counts["total_r"] += 1
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:
matched_pos = i
break
if matched_pos >= 0:
# Clean up bracket counts for next tool call
truncate_text = delta_text[: matched_pos + 1]
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=truncate_text).model_dump(
exclude_none=True
),
)
]
)
self.buffer = self.buffer[args_match.end() :]
return delta
else:
# No complete match yet
return None
else:
# Return partial arguments
for ch in delta_text:
if ch == "{":
self.bracket_counts["total_l"] += 1
elif ch == "}":
self.bracket_counts["total_r"] += 1
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
)
]
)
return delta
except Exception as e:
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
return None
if "</tool_call>" in self.buffer:
end_pos = self.buffer.find("</tool_call>")
self.buffer = self.buffer[end_pos + len("</tool_call>") :]
self.streamed_args_for_tool.append("")
return delta
except Exception as e:
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
return None

View File

@@ -17,6 +17,7 @@
from fastdeploy.plugins import load_reasoning_parser_plugins
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .ernie_45_vl_thinking_reasoning_parser import Ernie45VLThinkingReasoningParser
from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser
from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
from .qwen3_reasoning_parsers import Qwen3ReasoningParser
@@ -27,6 +28,7 @@ __all__ = [
"ErnieVLReasoningParser",
"Qwen3ReasoningParser",
"ErnieX1ReasoningParser",
"Ernie45VLThinkingReasoningParser",
]
load_reasoning_parser_plugins()

View File

@@ -0,0 +1,138 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from collections.abc import Sequence
from typing import Optional, Union
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
@ReasoningParserManager.register_module("erine-45-vl-thinking")
class Ernie45VLThinkingReasoningParser(ReasoningParser):
"""
Reasoning parser for ernir_vl model.
The ernie_vl model uses ...</think>... tokens to denote reasoning text
within its output. The model provides a strict switch to disable reasoning
output via the 'enable_thinking=False' parameter. This parser extracts the
reasoning content enclosed by <think> and </think> tokens from the model's
output.
"""
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.think_end_token = "</think>"
self.tool_begin_token = "<tool_call>"
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser " "constructor during construction."
)
self.think_end_token_id = self.vocab.get(self.think_end_token)
self.tool_begin_token_id = self.vocab.get(self.tool_begin_token)
if self.tool_begin_token_id is None:
self.tool_begin_token_id = -1
if self.think_end_token_id is None:
raise RuntimeError("Test reasoning parser could not locate think end tokens in the tokenizer!")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""
Extract reasoning content from a delta message.
Handles streaming output where previous + delta = current.
Uses token IDs for faster processing.
For text abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content
"""
if self.think_end_token not in current_text:
return DeltaMessage(reasoning_content=delta_text)
# Skip single special tokens
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
return None
if self._is_with_tool(current_text=current_text, current_token_ids=current_token_ids):
if self.think_end_token in delta_text:
think_begin = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:think_begin]
return DeltaMessage(reasoning_content=reasoning_content)
return None
if self.think_end_token in delta_text:
reasoning_content, _, content = delta_text.partition(self.think_end_token)
striped_content = content.strip("\n")
if len(striped_content) == 0:
return DeltaMessage(reasoning_content=reasoning_content) if reasoning_content else None
return (
DeltaMessage(reasoning_content=reasoning_content, content=content)
if reasoning_content
else DeltaMessage(content=content)
)
think_end = current_text.find(self.think_end_token) + len(self.think_end_token)
suffix = current_text[think_end:]
striped_suffix = suffix.strip("\n")
if len(striped_suffix) == 0:
return None
return DeltaMessage(content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from the model output.
For text abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content
Returns:
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
# Check if the model output contains the </think> tokens.
if self.think_end_token not in model_output:
return model_output, ""
reasoning_content, _, content = model_output.partition(self.think_end_token)
if self.tool_begin_token in content:
prefix, _, _ = content.partition(self.tool_begin_token)
prefix_strip = prefix.lstrip("\n")
if len(prefix_strip) > 0:
return reasoning_content, content
return reasoning_content, ""
return reasoning_content, content
def _is_with_tool(self, current_text: str, current_token_ids: Sequence[int]) -> bool:
think_end_index = current_text.find(self.think_end_token)
think_end = think_end_index + len(self.think_end_token)
middle_str = current_text[think_end:]
if self.tool_begin_token_id in current_token_ids:
prefix, _, _ = middle_str.partition(self.tool_begin_token)
striped_prefix = prefix.strip("\n")
if len(striped_prefix) > 0:
return False
return True
return False

View File

@@ -0,0 +1,193 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from unittest.mock import patch
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.entrypoints.openai.tool_parsers.ernie_45_vl_thinking_tool_parser import (
Ernie45VLThinkingToolParser,
)
class DummyTokenizer:
"""Dummy tokenizer with minimal vocab for testing"""
def __init__(self):
self.vocab = {"<tool_call>": 1, "</tool_call>": 2}
class TestErnie45VLThinkingToolParser(unittest.TestCase):
def setUp(self):
class DummyTokenizer:
def __init__(self):
self.vocab = {"<tool_call>": 1, "</tool_call>": 2}
def get_vocab(self):
return self.vocab
self.tokenizer = DummyTokenizer()
self.parser = Ernie45VLThinkingToolParser(tokenizer=self.tokenizer)
self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
# ---------------- Batch extraction tests ----------------
def test_extract_tool_calls_complete(self):
"""Test normal extraction of complete tool_call JSON"""
output = '<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertTrue(result.tools_called)
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
def test_extract_tool_calls_partial_arguments(self):
"""Test partial extraction when arguments incomplete"""
output = '<tool_call>{"name": "get_weather", "arguments": {"location": ""</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
def test_extract_tool_calls_no_toolcall(self):
"""Test when no tool_call tags are present"""
output = "no tool call here"
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
def test_extract_tool_calls_invalid_json(self):
"""Test tool_call with badly formatted JSON triggers fallback parser"""
output = '<tool_call>"name": "get_weather", "arguments": {</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
self.assertEqual(result.tool_calls[0].function.name, "get_weather")
def test_extract_tool_calls_exception(self):
"""Force exception to cover error branch"""
with patch(
"fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.json.loads", side_effect=Exception("boom")
):
output = '<tool_call>{"name": "get_weather", "arguments": {}}</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
def test_extract_tool_calls_illegal(self):
output = '</think>abc<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
self.assertEqual(
result.content,
'</think>abc<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>',
)
output = 'abc<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
result = self.parser.extract_tool_calls(output, self.dummy_request)
self.assertFalse(result.tools_called)
self.assertEqual(
result.content, 'abc<tool_call>{"name": "get_weather", "arguments": {"location": "北京"}}</tool_call>'
)
# ---------------- Streaming extraction tests ----------------
def test_streaming_no_toolcall(self):
"""Streaming extraction returns normal DeltaMessage when no <tool_call>"""
result = self.parser.extract_tool_calls_streaming(
"", "abc", "abc", [], [], [], self.dummy_request.model_dump()
)
self.assertIsInstance(result, DeltaMessage)
self.assertIsNone(result.tool_calls)
self.assertEqual(result.content, "abc")
def test_streaming_skip_empty_chunk(self):
"""Streaming extraction skips empty chunks"""
result = self.parser.extract_tool_calls_streaming(
"", "<tool_call>", " ", [], [1], [1], self.dummy_request.model_dump()
)
self.assertIsNone(result)
def test_streaming_new_toolcall_and_name(self):
"""Streaming extraction detects new toolcall and extracts name"""
delta = self.parser.extract_tool_calls_streaming(
"", "<tool_call>", '<tool_call>{"name": "get_weather"', [], [1], [1], self.dummy_request.model_dump()
)
self.assertIsNotNone(delta)
self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
def test_streaming_partial_arguments(self):
"""Streaming extraction yields partial arguments deltas"""
text = '"arguments": {"location":'
delta = self.parser.extract_tool_calls_streaming(
"", "<tool_call>" + text, text, [], [1], [1], self.dummy_request.model_dump()
)
self.assertIsInstance(delta, DeltaMessage)
self.assertIn("arguments", delta.tool_calls[0].function.arguments)
def test_streaming_complete_arguments_and_end(self):
"""Streaming extraction completes arguments with brackets matched and closes tool_call"""
text = '"arguments": {"location": "北京"}}'
delta = self.parser.extract_tool_calls_streaming(
"", "<tool_call>" + text, "<tool_call>" + text, [], [1], [1], self.dummy_request.model_dump()
)
self.assertIsInstance(delta, DeltaMessage)
# Also simulate closing tag
end_delta = self.parser.extract_tool_calls_streaming(
"<tool_call>" + text,
"<tool_call>" + text + "</tool_call>",
"</tool_call>",
[1],
[1, 2],
[2],
self.dummy_request.model_dump(),
)
self.assertIsNone(end_delta)
def test_streaming_no_tool_illegal(self):
result = self.parser.extract_tool_calls_streaming(
"", "abc<tool_call>", "abc<tool_call>", [], [], [], self.dummy_request.model_dump()
)
self.assertIsInstance(result, DeltaMessage)
self.assertIsNone(result.tool_calls)
self.assertEqual(result.content, "abc<tool_call>")
result = self.parser.extract_tool_calls_streaming(
"", "</think>abc<tool_call>", "</think>abc<tool_call>", [], [], [], self.dummy_request.model_dump()
)
self.assertIsInstance(result, DeltaMessage)
self.assertIsNone(result.tool_calls)
self.assertEqual(result.content, "</think>abc<tool_call>")
def test_streaming_tool_with_reasoning(self):
delta = self.parser.extract_tool_calls_streaming(
"",
'</think><tool_call>{"name": "get_weather"',
'</think><tool_call>{"name": "get_weather"',
[],
[1],
[1],
self.dummy_request.model_dump(),
)
self.assertIsNotNone(delta)
self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
delta = self.parser.extract_tool_calls_streaming(
"",
'</think>\n\n<tool_call>{"name": "get_weather"',
'</think>\n\n<tool_call>{"name": "get_weather"',
[],
[1],
[1],
self.dummy_request.model_dump(),
)
self.assertIsNotNone(delta)
self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
if __name__ == "__main__":
unittest.main()

View File

@@ -18,6 +18,9 @@ import unittest
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
from fastdeploy.reasoning.ernie_45_vl_thinking_reasoning_parser import (
Ernie45VLThinkingReasoningParser,
)
from fastdeploy.reasoning.ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
@@ -261,5 +264,166 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
self.assertEqual(response, "line1\nline2\n")
class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
def setUp(self):
self.tokenizer = DummyTokenizer()
self.parser = Ernie45VLThinkingReasoningParser(tokenizer=self.tokenizer)
self.test_request = ChatCompletionRequest(
model="ernie-test", messages=[{"role": "user", "content": "test prompt"}]
)
def test_streaming_non_reasoning(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="",
current_text="a",
delta_text="a",
previous_token_ids=[],
current_token_ids=[200],
delta_token_ids=[200],
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.reasoning_content, "a")
self.assertIsNone(result.content)
def test_streaming_with_reasoning(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="ab",
current_text="ab</think>",
delta_text="</think>",
previous_token_ids=[200, 201],
current_token_ids=[200, 201, 100],
delta_token_ids=[100],
)
self.assertIsNone(result)
def test_streaming_with_reasoning_and_content(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="ab",
current_text="ab</think>\n\ncd",
delta_text="</think>\n\ncd",
previous_token_ids=[200, 201],
current_token_ids=[200, 201, 100, 300, 400],
delta_token_ids=[100, 300, 400],
)
self.assertIsInstance(result, DeltaMessage)
self.assertIsNone(result.reasoning_content)
self.assertEqual(result.content, "\n\ncd")
def test_streaming_with_reasoning_new_line(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="abc",
current_text="abc</think>\n\n",
delta_text="</think>\n\n",
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 100],
delta_token_ids=[100],
)
self.assertIsNone(result)
def test_streaming_with_reasoning_and_tool(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="abc",
current_text="abc</think>\n\n<tool_call>",
delta_text="</think>\n\n<tool_call>",
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 100, 200, 101],
delta_token_ids=[100, 200, 101],
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.reasoning_content, "")
def test_streaming_with_reasoning_and_illegal_tool(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="abc</think>",
current_text="abc</think>\n\nhello<tool_call>",
delta_text="\n\nhello<tool_call>",
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 100, 200, 101],
delta_token_ids=[109, 200, 101],
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.content, "\n\nhello<tool_call>")
def test_streaming_with_reasoning_no_tool(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="abc",
current_text="abchello</think>\nworld",
delta_text="hello</think>\nworld",
previous_token_ids=[200, 201, 202],
current_token_ids=[200, 201, 202, 100, 200, 110],
delta_token_ids=[100, 200, 110],
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.reasoning_content, "hello")
self.assertEqual(result.content, "\nworld")
def test_streaming_reasoning_previous_no_tool(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="</think>",
current_text="</think>\nhello",
delta_text="\nhello",
previous_token_ids=[100],
current_token_ids=[100, 110, 111],
delta_token_ids=[110, 111],
)
self.assertIsInstance(result, DeltaMessage)
self.assertIsNone(result.reasoning_content)
self.assertEqual(result.content, "\nhello")
def test_streaming_no_reasoning_previous_tool(self):
result = self.parser.extract_reasoning_content_streaming(
previous_text="<tool_call>",
current_text="<tool_call>hello",
delta_text="hello",
previous_token_ids=[101],
current_token_ids=[101, 110],
delta_token_ids=[110],
)
self.assertIsInstance(result, DeltaMessage)
self.assertEqual(result.reasoning_content, "hello")
def test_batch_no_think_end(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="direct response", request=self.test_request
)
self.assertEqual(reasoning, "direct response")
self.assertEqual(content, "")
def test_batch_no_think_end_with_tool(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="direct response<tool_call>abc", request=self.test_request
)
self.assertEqual(reasoning, "direct response<tool_call>abc")
self.assertEqual(content, "")
def test_batch_think_end_normal_content(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="reasoning</think>\nresponse", request=self.test_request
)
self.assertEqual(reasoning, "reasoning")
self.assertEqual(content, "\nresponse")
def test_batch_think_end_with_tool(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="reasoning</think>\n<tool_call>tool params</tool_call>", request=self.test_request
)
self.assertEqual(reasoning, "reasoning")
self.assertEqual(content, "")
def test_batch_think_end_with_illegal_tool(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="reasoning</think>\nABC\n<tool_call>tool params</tool_call>", request=self.test_request
)
self.assertEqual(reasoning, "reasoning")
self.assertEqual(content, "\nABC\n<tool_call>tool params</tool_call>")
def test_batch_think_end_content_with_newline(self):
reasoning, content = self.parser.extract_reasoning_content(
model_output="reasoning</think>\n\n actual response", request=self.test_request
)
self.assertEqual(reasoning, "reasoning")
self.assertEqual(content, "\n\n actual response")
if __name__ == "__main__":
unittest.main()