diff --git a/fastdeploy/entrypoints/openai/tool_parsers/__init__.py b/fastdeploy/entrypoints/openai/tool_parsers/__init__.py
index 2078a8c9f..a4df47ac9 100644
--- a/fastdeploy/entrypoints/openai/tool_parsers/__init__.py
+++ b/fastdeploy/entrypoints/openai/tool_parsers/__init__.py
@@ -15,10 +15,7 @@
"""
from .abstract_tool_parser import ToolParser, ToolParserManager
+from .ernie_45_vl_thinking_tool_parser import Ernie45VLThinkingToolParser
from .ernie_x1_tool_parser import ErnieX1ToolParser
-__all__ = [
- "ToolParser",
- "ToolParserManager",
- "ErnieX1ToolParser",
-]
+__all__ = ["ToolParser", "ToolParserManager", "ErnieX1ToolParser", "Ernie45VLThinkingToolParser"]
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py
new file mode 100644
index 000000000..131c17e6a
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_45_vl_thinking_tool_parser.py
@@ -0,0 +1,361 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import json
+import re
+import uuid
+from collections.abc import Sequence
+from typing import Union
+
+import partial_json_parser
+
+
+def random_tool_call_id() -> str:
+    """Build a unique tool-call identifier.
+
+    Returns:
+        The literal prefix ``chatcmpl-tool-`` followed by the hexadecimal
+        form of a freshly generated UUID4.
+    """
+    unique_suffix = uuid.uuid4().hex
+    return "chatcmpl-tool-" + unique_suffix
+
+
+from fastdeploy.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ DeltaFunctionCall,
+ DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall,
+ ToolCall,
+)
+from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser,
+ ToolParserManager,
+)
+from fastdeploy.utils import data_processor_logger
+
+
+@ToolParserManager.register_module("ernie_45-vl-thinking")
+class Ernie45VLThinkingToolParser(ToolParser):
+ """
+ Tool parser for Ernie model version 4.5.1.
+ This parser handles tool calls with newline formats.
+ """
+
+ def __init__(self, tokenizer):
+     """Set up streaming state and resolve the tool-call marker token ids.
+
+     Args:
+         tokenizer: Tokenizer forwarded to the ``ToolParser`` base constructor.
+
+     Raises:
+         ValueError: If ``self.model_tokenizer`` is falsy after the base
+             constructor has run.
+     """
+     super().__init__(tokenizer)
+
+     # Validate the tokenizer before anything reads ``self.vocab`` so that a
+     # missing tokenizer surfaces as this ValueError instead of as a failure
+     # inside the vocab lookup (presumably tokenizer-backed; confirm in the base class).
+     if not self.model_tokenizer:
+         raise ValueError(
+             "The model tokenizer must be passed to the ToolCallParser constructor during construction."
+         )
+
+     self.prev_tool_call_arr: list[dict] = []
+     self.current_tool_id: int = -1  # -1 means no tool call has started yet
+     self.current_tool_name_sent: bool = False
+     self.streamed_args_for_tool: list[str] = []  # map what has been streamed for each tool so far to a list
+     self.buffer: str = ""  # buffer for accumulating unprocessed streaming content
+     self.bracket_counts: dict = {"total_l": 0, "total_r": 0}  # track bracket counts in streamed deltas
+     # NOTE(review): both marker literals appear empty in this view; presumably they
+     # are the model's tool-call start/end tags. Confirm against the real file.
+     self.tool_call_start_token: str = ""
+     self.tool_call_end_token: str = ""
+     # Tri-state: None until the first tool call is seen, then True/False.
+     self.valid = None
+
+     self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
+     self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+     # Fall back to -1 when the start marker is not in the vocabulary; the end
+     # marker id is left as None in that case.
+     if self.tool_call_start_token_id is None:
+         self.tool_call_start_token_id = -1
+
+ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+ """
+ Extract tool calls from a complete (non-streaming) model response.
+
+ Each tool-call block found in ``model_output`` is parsed with ``json.loads``
+ first; when that raises ``json.JSONDecodeError`` the "name" and "arguments"
+ fields are recovered with regular expressions and ``partial_json_parser``,
+ and the call is flagged as partial.
+
+ Truncated payloads the fallback is meant to cope with:
+ 1. Only name and partial arguments: {"name": "get_weather", "arguments": {"location": "北京"
+ 2. Only partial name: {"name": "get_we
+ 3. Only name and arguments field without content: {"name": "get_weather", "argume
+
+ Args:
+ model_output: Full text generated by the model.
+ request: The originating chat completion request (not read by this method).
+
+ Returns:
+ ExtractedToolCallInformation. The final return carries the parsed
+ ``tool_calls`` and an empty ``content``, with ``tools_called`` True only
+ when every call was parsed from complete JSON. When text other than
+ newlines sits between the two markers located at the top of the method,
+ when no call could be parsed, or when an unexpected exception occurs,
+ ``tools_called`` is False and ``content`` is the untouched ``model_output``.
+ """
+
+ try:
+ # NOTE(review): this first binding is never read; tool_calls is rebound before use further down.
+ tool_calls = []
+
+ function_call_arr = []
+ remaining_text = model_output
+
+ # NOTE(review): every marker literal passed to find()/len() in this method appears
+ # empty in this view; presumably they are the think-end and tool-call start/end
+ # tags. Confirm against the real file before relying on the notes below.
+ think_end = remaining_text.find("")
+ think_end = think_end + len("") if think_end != -1 else 0
+ tool_begin = remaining_text.find("")
+ if tool_begin != -1:
+ # Anything other than newlines between the two positions means the output
+ # is returned unchanged as plain content.
+ middle_str = remaining_text[think_end:tool_begin]
+ if len(middle_str.strip("\n")) > 0:
+ return ExtractedToolCallInformation(tools_called=False, content=model_output)
+
+ while True:
+ # Find the next
+ tool_call_pos = remaining_text.find("")
+ if tool_call_pos == -1:
+ break
+
+ # Extract content after
+ tool_content_start = tool_call_pos + len("")
+ tool_content_end = remaining_text.find("", tool_content_start)
+
+ tool_json = ""
+ if tool_content_end == -1:
+ # Processing unclosed tool_call block (truncated case)
+ tool_json = remaining_text[tool_content_start:].strip()
+ remaining_text = "" # No more content to process
+ else:
+ # Processing closed block
+ tool_json = remaining_text[tool_content_start:tool_content_end].strip()
+ remaining_text = remaining_text[tool_content_end + len("") :]
+
+ if not tool_json:
+ continue
+
+ # Process tool_json
+ # Add a missing outer brace on either side so truncated payloads still look like an object.
+ tool_json = tool_json.strip()
+ if not tool_json.startswith("{"):
+ tool_json = "{" + tool_json
+ if not tool_json.endswith("}"):
+ tool_json = tool_json + "}"
+
+ try:
+ # Parsing strategy: First try standard json.loads
+ try:
+ tool_data = json.loads(tool_json)
+
+ if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
+ function_call_arr.append(
+ {
+ "name": tool_data["name"],
+ "arguments": tool_data["arguments"],
+ "_is_complete": True, # Mark as complete
+ }
+ )
+ continue
+ except json.JSONDecodeError:
+ pass
+
+ # Try partial_json_parser when standard parsing fails
+ from partial_json_parser.core.options import Allow
+
+ try:
+ tool_data = {}
+ # Allow every kind of partial value except partial strings.
+ flags = Allow.ALL & ~Allow.STR
+
+ # Parse the name field
+ name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
+ if name_match:
+ tool_data["name"] = name_match.group(1)
+
+ # Parse the arguments field
+ args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
+ if args_match:
+ try:
+ tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
+ # NOTE(review): bare except also traps KeyboardInterrupt/SystemExit; consider `except Exception:`.
+ except:
+ tool_data["arguments"] = None
+
+ if isinstance(tool_data, dict):
+ function_call_arr.append(
+ {
+ "name": tool_data.get("name", ""),
+ "arguments": tool_data.get("arguments", {}),
+ "_is_partial": True, # Mark as partial
+ }
+ )
+ except Exception as e:
+ data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
+ continue
+ except Exception as e:
+ data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
+ continue
+
+ if not function_call_arr:
+ data_processor_logger.error("No valid tool calls found")
+ return ExtractedToolCallInformation(tools_called=False, content=model_output)
+
+ tool_calls = []
+ all_complete = True # Initialize as all complete
+
+ for tool_call in function_call_arr:
+ # Set flags
+ is_complete = tool_call.get("_is_complete", False)
+ is_partial = tool_call.get("_is_partial", False)
+
+ # If any tool call is incomplete or partial, mark all_complete as False
+ if not is_complete or is_partial:
+ all_complete = False
+
+ # Process arguments
+ # Non-dict arguments (including the None stored by the fallback) are replaced by {}.
+ tool_args = tool_call.get("arguments", {})
+ if not isinstance(tool_args, dict):
+ tool_args = {}
+
+ try:
+ args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
+ # NOTE(review): bare except also traps KeyboardInterrupt/SystemExit; consider `except Exception:`.
+ except:
+ args_str = "{}"
+
+ tool_calls.append(
+ ToolCall(
+ type="function",
+ id=random_tool_call_id(),
+ function=FunctionCall(
+ name=tool_call.get("name", ""),
+ arguments=args_str,
+ ),
+ )
+ )
+
+ # Only return tools_called=True if all tool calls are complete
+ return ExtractedToolCallInformation(
+ tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
+ )
+
+ except Exception as e:
+ data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
+ return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: dict,
+ ) -> Union[DeltaMessage, None]:
+ """
+ Incrementally extract tool-call deltas from a streamed response.
+
+ ``delta_text`` is appended to ``self.buffer`` and the buffer is searched for a
+ "name" field and then an "arguments" field; a DeltaMessage carrying the tool
+ name or a slice of argument text is returned as soon as one is found.
+
+ Args:
+ previous_text: Text before this delta (not read here).
+ current_text: Text including this delta; used to inspect what precedes the tool call.
+ delta_text: Newly generated text.
+ previous_token_ids: Token ids before this delta (not read here).
+ current_token_ids: Token ids including this delta; checked for the tool-call start id.
+ delta_token_ids: Token ids of this delta (not read here).
+ request: The originating request (not read here).
+
+ Returns:
+ ``DeltaMessage(content=delta_text)`` when the tool-call start id has not been seen
+ or the call was judged invalid; a DeltaMessage with ``tool_calls`` when a name or
+ argument fragment is emitted; otherwise None (whitespace-only delta, nothing to
+ emit yet, or an exception, which is logged).
+ """
+ # NOTE(review): `request` is annotated as dict here but as ChatCompletionRequest in
+ # extract_tool_calls; neither method reads it.
+
+ if self.tool_call_start_token_id not in current_token_ids:
+ return DeltaMessage(content=delta_text)
+
+ # self.valid is None until the first tool call is validated; False means pass text through.
+ if self.valid is not None and not self.valid:
+ return DeltaMessage(content=delta_text)
+
+ # Skip empty chunks
+ if len(delta_text.strip()) == 0:
+ return None
+
+ try:
+ delta = None
+ # Use buffer to accumulate delta_text content
+ self.buffer += delta_text
+
+ # Process the buffer content
+ # NOTE(review): the marker literals tested in this method appear empty in this view;
+ # presumably they are the tool-call start/end tags and the think-end tag. Confirm.
+ if "" in delta_text:
+ if self.valid is None:
+ # First tool call: only newlines, or text ending with the checked marker,
+ # may precede it; otherwise the stream is treated as plain content.
+ tool_call_begin = current_text.find(self.tool_call_start_token)
+ prefix = current_text[:tool_call_begin]
+ prefix = prefix.strip("\n")
+ if len(prefix) > 0 and not prefix.endswith(""):
+ self.valid = False
+ return DeltaMessage(content=delta_text)
+ self.valid = True
+ # First call moves the id from -1 to 0; later calls increment it.
+ self.current_tool_id = (
+ max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
+ )
+ self.current_tool_name_sent = False
+ if len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+ data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
+
+ # 1. Try to parse the name field
+ if not self.current_tool_name_sent and '"name"' in self.buffer:
+ name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
+ if name_match:
+ name = name_match.group(1)
+ if name:
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ type="function",
+ id=random_tool_call_id(),
+ function=DeltaFunctionCall(name=name).model_dump(exclude_none=True),
+ )
+ ]
+ )
+ # Delete the processed name part from the buffer
+ self.buffer = self.buffer[name_match.end() :]
+ self.current_tool_name_sent = True
+ return delta
+ # 2. Processing arguments field
+ if '"arguments"' in self.buffer:
+ args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
+ if args_match:
+ args_content = args_match.group(1)
+ try:
+ # Check if arguments field is complete by bracket matching
+ if "}}" in args_content:
+ matched_pos = -1
+ for i, ch in enumerate(delta_text):
+ if ch == "{":
+ self.bracket_counts["total_l"] += 1
+ elif ch == "}":
+ self.bracket_counts["total_r"] += 1
+
+ # NOTE(review): nesting is not recoverable in this view. If this equality check
+ # runs for every character, a delta whose first character is not a brace matches
+ # at index 0 whenever both counters are already equal (e.g. still 0) - verify.
+ if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:
+ matched_pos = i
+ break
+
+ if matched_pos >= 0:
+ # Clean up bracket counts for next tool call
+ # NOTE(review): no reset of self.bracket_counts is visible here despite the comment above.
+ truncate_text = delta_text[: matched_pos + 1]
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(arguments=truncate_text).model_dump(
+ exclude_none=True
+ ),
+ )
+ ]
+ )
+ self.buffer = self.buffer[args_match.end() :]
+ return delta
+ else:
+ # No complete match yet
+ return None
+ else:
+ # Return partial arguments
+ # The whole delta is forwarded as argument text while the counters are kept current.
+ for ch in delta_text:
+ if ch == "{":
+ self.bracket_counts["total_l"] += 1
+ elif ch == "}":
+ self.bracket_counts["total_r"] += 1
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
+ )
+ ]
+ )
+ return delta
+ except Exception as e:
+ data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
+ return None
+ # Drop everything up to and including the checked marker from the buffer.
+ if "" in self.buffer:
+ end_pos = self.buffer.find("")
+ self.buffer = self.buffer[end_pos + len("") :]
+
+ self.streamed_args_for_tool.append("")
+
+ return delta
+
+ except Exception as e:
+ data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
+ return None
diff --git a/fastdeploy/reasoning/__init__.py b/fastdeploy/reasoning/__init__.py
index 49c627895..c384e6f37 100644
--- a/fastdeploy/reasoning/__init__.py
+++ b/fastdeploy/reasoning/__init__.py
@@ -17,6 +17,7 @@
from fastdeploy.plugins import load_reasoning_parser_plugins
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
+from .ernie_45_vl_thinking_reasoning_parser import Ernie45VLThinkingReasoningParser
from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser
from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
from .qwen3_reasoning_parsers import Qwen3ReasoningParser
@@ -27,6 +28,7 @@ __all__ = [
"ErnieVLReasoningParser",
"Qwen3ReasoningParser",
"ErnieX1ReasoningParser",
+ "Ernie45VLThinkingReasoningParser",
]
load_reasoning_parser_plugins()
diff --git a/fastdeploy/reasoning/ernie_45_vl_thinking_reasoning_parser.py b/fastdeploy/reasoning/ernie_45_vl_thinking_reasoning_parser.py
new file mode 100644
index 000000000..72b045a3d
--- /dev/null
+++ b/fastdeploy/reasoning/ernie_45_vl_thinking_reasoning_parser.py
@@ -0,0 +1,138 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from collections.abc import Sequence
+from typing import Optional, Union
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
+
+
+@ReasoningParserManager.register_module("erine-45-vl-thinking")
+class Ernie45VLThinkingReasoningParser(ReasoningParser):
+ """
+ Reasoning parser for the ERNIE 4.5 VL thinking model.
+
+ The model marks the end of its reasoning text with a think-end token
+ (see ``self.think_end_token``). The model provides a strict switch to disable
+ reasoning output via the 'enable_thinking=False' parameter. This parser separates
+ the reasoning content that precedes the think-end token from the content that
+ follows it in the model's output.
+ """
+
+ def __init__(self, tokenizer):
+     """Resolve the marker tokens this parser relies on.
+
+     Args:
+         tokenizer: Tokenizer forwarded to the ``ReasoningParser`` base constructor.
+
+     Raises:
+         ValueError: If ``self.model_tokenizer`` is falsy after the base
+             constructor has run.
+         RuntimeError: If the think-end token has no entry in ``self.vocab``.
+     """
+     super().__init__(tokenizer)
+     # NOTE(review): both marker literals appear empty in this view; presumably they
+     # are the think-end tag and the tool-call start tag. Confirm against the real file.
+     self.think_end_token = ""
+     self.tool_begin_token = ""
+
+     if not self.model_tokenizer:
+         raise ValueError(
+             "The model tokenizer must be passed to the ReasoningParser constructor during construction."
+         )
+
+     self.think_end_token_id = self.vocab.get(self.think_end_token)
+     self.tool_begin_token_id = self.vocab.get(self.tool_begin_token)
+     # The tool-call marker is optional: fall back to -1 when it is not in the vocabulary.
+     if self.tool_begin_token_id is None:
+         self.tool_begin_token_id = -1
+
+     # The think-end marker is mandatory; name this parser in the error so the
+     # failure can be traced to the right class.
+     if self.think_end_token_id is None:
+         raise RuntimeError(
+             "Ernie45VLThinkingReasoningParser could not locate think end tokens in the tokenizer!"
+         )
+
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
+     """Report whether the think-end token id occurs anywhere in ``input_ids``."""
+     reached_end = self.think_end_token_id in input_ids
+     return reached_end
+
+ def extract_reasoning_content_streaming(
+     self,
+     previous_text: str,
+     current_text: str,
+     delta_text: str,
+     previous_token_ids: Sequence[int],
+     current_token_ids: Sequence[int],
+     delta_token_ids: Sequence[int],
+ ) -> Union[DeltaMessage, None]:
+     """
+     Split one streamed delta into reasoning text and answer text.
+
+     Text generated before the think-end token is reasoning; text after it is
+     content. When a tool call directly follows the think-end token (only
+     newlines in between), nothing after the token is emitted as content.
+
+     Args:
+         previous_text: Text before this delta (not read here).
+         current_text: Text including this delta.
+         delta_text: Newly generated text.
+         previous_token_ids: Token ids before this delta (not read here).
+         current_token_ids: Token ids including this delta.
+         delta_token_ids: Token ids of this delta.
+
+     Returns:
+         A DeltaMessage with ``reasoning_content`` and/or ``content`` set, or
+         None when this delta carries nothing to emit.
+     """
+     end_marker = self.think_end_token
+
+     # The think-end token has not appeared yet: the whole delta is reasoning.
+     if end_marker not in current_text:
+         return DeltaMessage(reasoning_content=delta_text)
+
+     # A delta that is exactly the think-end token carries nothing to emit.
+     if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
+         return None
+
+     if self._is_with_tool(current_text=current_text, current_token_ids=current_token_ids):
+         # Tool-call path: only the reasoning tail in front of the marker is emitted.
+         if end_marker not in delta_text:
+             return None
+         marker_at = delta_text.find(end_marker)
+         return DeltaMessage(reasoning_content=delta_text[:marker_at])
+
+     if end_marker in delta_text:
+         before, _, after = delta_text.partition(end_marker)
+         if after.strip("\n"):
+             if before:
+                 return DeltaMessage(reasoning_content=before, content=after)
+             return DeltaMessage(content=after)
+         # Only newlines follow the marker so far: hold them back.
+         return DeltaMessage(reasoning_content=before) if before else None
+
+     # The marker arrived in an earlier delta; suppress output until something
+     # other than newlines has followed it.
+     tail_start = current_text.find(end_marker) + len(end_marker)
+     if not current_text[tail_start:].strip("\n"):
+         return None
+     return DeltaMessage(content=delta_text)
+
+ def extract_reasoning_content(
+     self, model_output: str, request: ChatCompletionRequest
+ ) -> tuple[Optional[str], Optional[str]]:
+     """
+     Split a complete model output into reasoning text and answer text.
+
+     Text before the first think-end token is the reasoning content and text
+     after it is the content. Without a think-end token the whole output is
+     treated as reasoning. When a tool-call marker follows the think-end token
+     with nothing but leading newlines in between, the content is emptied.
+
+     Args:
+         model_output: Full text generated by the model.
+         request: The originating chat completion request (not read here).
+
+     Returns:
+         tuple[Optional[str], Optional[str]]: reasoning content and content
+     """
+     end_marker = self.think_end_token
+     if end_marker not in model_output:
+         return model_output, ""
+
+     reasoning_part, _, answer_part = model_output.partition(end_marker)
+     if self.tool_begin_token not in answer_part:
+         return reasoning_part, answer_part
+
+     # Text in front of the tool-call marker decides whether the answer is kept.
+     leading = answer_part.partition(self.tool_begin_token)[0]
+     if leading.lstrip("\n"):
+         return reasoning_part, answer_part
+     return reasoning_part, ""
+
+ def _is_with_tool(self, current_text: str, current_token_ids: Sequence[int]) -> bool:
+     """
+     Tell whether a tool call directly follows the think-end token.
+
+     Returns True only when the tool-call start id is among ``current_token_ids``
+     and nothing but newlines separates the think-end token from the tool-call
+     marker in ``current_text``.
+     """
+     tail_start = current_text.find(self.think_end_token) + len(self.think_end_token)
+     after_think = current_text[tail_start:]
+     if self.tool_begin_token_id not in current_token_ids:
+         return False
+     leading, _, _ = after_think.partition(self.tool_begin_token)
+     return len(leading.strip("\n")) == 0
diff --git a/tests/entrypoints/openai/tool_parsers/test_ernie_45_vl_thinking_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_45_vl_thinking_tool_parser.py
new file mode 100644
index 000000000..c5676ce66
--- /dev/null
+++ b/tests/entrypoints/openai/tool_parsers/test_ernie_45_vl_thinking_tool_parser.py
@@ -0,0 +1,193 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import unittest
+from unittest.mock import patch
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
+from fastdeploy.entrypoints.openai.tool_parsers.ernie_45_vl_thinking_tool_parser import (
+ Ernie45VLThinkingToolParser,
+)
+
+
+class DummyTokenizer:
+ """Dummy tokenizer with minimal vocab for testing"""
+ # NOTE(review): setUp() in the test class below declares its own local DummyTokenizer
+ # (with get_vocab) and instantiates that one, so this module-level class looks unused.
+
+ def __init__(self):
+ # NOTE(review): both keys appear as empty strings in this view, which would collapse
+ # into a single entry; presumably they are the tool-call start/end tags mapped to
+ # ids 1 and 2. Confirm against the real file.
+ self.vocab = {"": 1, "": 2}
+
+
+class TestErnie45VLThinkingToolParser(unittest.TestCase):
+ # Unit tests for Ernie45VLThinkingToolParser covering batch extraction
+ # (extract_tool_calls) and streaming extraction (extract_tool_calls_streaming).
+ # NOTE(review): many string literals in these tests appear to have lost their marker
+ # tags in this view (empty strings, duplicated inputs); confirm against the real file.
+ def setUp(self):
+ # A local tokenizer stub exposing get_vocab(); it shadows the module-level DummyTokenizer.
+ class DummyTokenizer:
+ def __init__(self):
+ self.vocab = {"": 1, "": 2}
+
+ def get_vocab(self):
+ return self.vocab
+
+ self.tokenizer = DummyTokenizer()
+ self.parser = Ernie45VLThinkingToolParser(tokenizer=self.tokenizer)
+ self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}])
+
+ # ---------------- Batch extraction tests ----------------
+
+ def test_extract_tool_calls_complete(self):
+ """Test normal extraction of complete tool_call JSON"""
+ output = '{"name": "get_weather", "arguments": {"location": "北京"}}'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertTrue(result.tools_called)
+ self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+
+ def test_extract_tool_calls_partial_arguments(self):
+ """Test partial extraction when arguments incomplete"""
+ output = '{"name": "get_weather", "arguments": {"location": "北"'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+ self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+
+ def test_extract_tool_calls_no_toolcall(self):
+ """Test when no tool_call tags are present"""
+ output = "no tool call here"
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+
+ def test_extract_tool_calls_invalid_json(self):
+ """Test tool_call with badly formatted JSON triggers fallback parser"""
+ output = '"name": "get_weather", "arguments": {'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+ self.assertEqual(result.tool_calls[0].function.name, "get_weather")
+
+ def test_extract_tool_calls_exception(self):
+ """Force exception to cover error branch"""
+ # NOTE(review): the patch target names ernie_x1_tool_parser, not the module under test.
+ # It presumably still takes effect because both modules reference the same json module,
+ # but pointing it at ernie_45_vl_thinking_tool_parser would state the intent.
+ with patch(
+ "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.json.loads", side_effect=Exception("boom")
+ ):
+ output = '{"name": "get_weather", "arguments": {}}'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+
+ def test_extract_tool_calls_illegal(self):
+ # Text in front of the tool call: the output must come back unchanged as content.
+ # NOTE(review): the two inputs below look identical in this view; presumably they
+ # differed only in marker tags.
+ output = 'abc{"name": "get_weather", "arguments": {"location": "北京"}}'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+ self.assertEqual(
+ result.content,
+ 'abc{"name": "get_weather", "arguments": {"location": "北京"}}',
+ )
+ output = 'abc{"name": "get_weather", "arguments": {"location": "北京"}}'
+ result = self.parser.extract_tool_calls(output, self.dummy_request)
+ self.assertFalse(result.tools_called)
+ self.assertEqual(
+ result.content, 'abc{"name": "get_weather", "arguments": {"location": "北京"}}'
+ )
+
+ # ---------------- Streaming extraction tests ----------------
+
+ def test_streaming_no_toolcall(self):
+ """Streaming extraction returns normal DeltaMessage when no """
+ result = self.parser.extract_tool_calls_streaming(
+ "", "abc", "abc", [], [], [], self.dummy_request.model_dump()
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertIsNone(result.tool_calls)
+ self.assertEqual(result.content, "abc")
+
+ def test_streaming_skip_empty_chunk(self):
+ """Streaming extraction skips empty chunks"""
+ result = self.parser.extract_tool_calls_streaming(
+ "", "", " ", [], [1], [1], self.dummy_request.model_dump()
+ )
+ self.assertIsNone(result)
+
+ def test_streaming_new_toolcall_and_name(self):
+ """Streaming extraction detects new toolcall and extracts name"""
+ delta = self.parser.extract_tool_calls_streaming(
+ "", "", '{"name": "get_weather"', [], [1], [1], self.dummy_request.model_dump()
+ )
+ self.assertIsNotNone(delta)
+ self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
+
+ def test_streaming_partial_arguments(self):
+ """Streaming extraction yields partial arguments deltas"""
+ text = '"arguments": {"location":'
+ delta = self.parser.extract_tool_calls_streaming(
+ "", "" + text, text, [], [1], [1], self.dummy_request.model_dump()
+ )
+ self.assertIsInstance(delta, DeltaMessage)
+ self.assertIn("arguments", delta.tool_calls[0].function.arguments)
+
+ def test_streaming_complete_arguments_and_end(self):
+ """Streaming extraction completes arguments with brackets matched and closes tool_call"""
+ text = '"arguments": {"location": "北京"}}'
+ delta = self.parser.extract_tool_calls_streaming(
+ "", "" + text, "" + text, [], [1], [1], self.dummy_request.model_dump()
+ )
+ # Only the type is checked here, not the argument text that was emitted.
+ self.assertIsInstance(delta, DeltaMessage)
+ # Also simulate closing tag
+ end_delta = self.parser.extract_tool_calls_streaming(
+ "" + text,
+ "" + text + "",
+ "",
+ [1],
+ [1, 2],
+ [2],
+ self.dummy_request.model_dump(),
+ )
+ self.assertIsNone(end_delta)
+
+ def test_streaming_no_tool_illegal(self):
+ # NOTE(review): both calls below look identical in this view; presumably they
+ # differed only in marker tags.
+ result = self.parser.extract_tool_calls_streaming(
+ "", "abc", "abc", [], [], [], self.dummy_request.model_dump()
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertIsNone(result.tool_calls)
+ self.assertEqual(result.content, "abc")
+ result = self.parser.extract_tool_calls_streaming(
+ "", "abc", "abc", [], [], [], self.dummy_request.model_dump()
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertIsNone(result.tool_calls)
+ self.assertEqual(result.content, "abc")
+
+ def test_streaming_tool_with_reasoning(self):
+ # The tool name must be extracted both with and without newlines ahead of the call.
+ delta = self.parser.extract_tool_calls_streaming(
+ "",
+ '{"name": "get_weather"',
+ '{"name": "get_weather"',
+ [],
+ [1],
+ [1],
+ self.dummy_request.model_dump(),
+ )
+ self.assertIsNotNone(delta)
+ self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
+ delta = self.parser.extract_tool_calls_streaming(
+ "",
+ '\n\n{"name": "get_weather"',
+ '\n\n{"name": "get_weather"',
+ [],
+ [1],
+ [1],
+ self.dummy_request.model_dump(),
+ )
+ self.assertIsNotNone(delta)
+ self.assertEqual(delta.tool_calls[0].function.name, "get_weather")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/reasoning/test_reasoning_parser.py b/tests/reasoning/test_reasoning_parser.py
index 90a48c899..9e06523b0 100644
--- a/tests/reasoning/test_reasoning_parser.py
+++ b/tests/reasoning/test_reasoning_parser.py
@@ -18,6 +18,9 @@ import unittest
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
+from fastdeploy.reasoning.ernie_45_vl_thinking_reasoning_parser import (
+ Ernie45VLThinkingReasoningParser,
+)
from fastdeploy.reasoning.ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
@@ -261,5 +264,166 @@ class TestErnieX1ReasoningParser(unittest.TestCase):
self.assertEqual(response, "line1\nline2\n")
+class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
+ # Unit tests for Ernie45VLThinkingReasoningParser covering the streaming and batch
+ # entry points.
+ # NOTE(review): the marker tags inside the text literals appear stripped in this view
+ # (e.g. an empty delta_text); the token ids suggest 100 is the think-end id and 101 the
+ # tool-call start id in the DummyTokenizer defined elsewhere in this file. Confirm.
+ def setUp(self):
+ self.tokenizer = DummyTokenizer()
+ self.parser = Ernie45VLThinkingReasoningParser(tokenizer=self.tokenizer)
+ self.test_request = ChatCompletionRequest(
+ model="ernie-test", messages=[{"role": "user", "content": "test prompt"}]
+ )
+
+ def test_streaming_non_reasoning(self):
+ # No think-end token seen yet: the delta is reported as reasoning content.
+ result = self.parser.extract_reasoning_content_streaming(
+ previous_text="",
+ current_text="a",
+ delta_text="a",
+ previous_token_ids=[],
+ current_token_ids=[200],
+ delta_token_ids=[200],
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertEqual(result.reasoning_content, "a")
+ self.assertIsNone(result.content)
+
+ def test_streaming_with_reasoning(self):
+ # A delta consisting of a single token with id 100 yields no message.
+ result = self.parser.extract_reasoning_content_streaming(
+ previous_text="ab",
+ current_text="ab",
+ delta_text="",
+ previous_token_ids=[200, 201],
+ current_token_ids=[200, 201, 100],
+ delta_token_ids=[100],
+ )
+ self.assertIsNone(result)
+
+ def test_streaming_with_reasoning_and_content(self):
+ result = self.parser.extract_reasoning_content_streaming(
+ previous_text="ab",
+ current_text="ab\n\ncd",
+ delta_text="\n\ncd",
+ previous_token_ids=[200, 201],
+ current_token_ids=[200, 201, 100, 300, 400],
+ delta_token_ids=[100, 300, 400],
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertIsNone(result.reasoning_content)
+ self.assertEqual(result.content, "\n\ncd")
+
+ def test_streaming_with_reasoning_new_line(self):
+ result = self.parser.extract_reasoning_content_streaming(
+ previous_text="abc",
+ current_text="abc\n\n",
+ delta_text="\n\n",
+ previous_token_ids=[200, 201, 202],
+ current_token_ids=[200, 201, 202, 100],
+ delta_token_ids=[100],
+ )
+ self.assertIsNone(result)
+
+ def test_streaming_with_reasoning_and_tool(self):
+ # Tool call directly after the reasoning: an empty reasoning delta is expected.
+ result = self.parser.extract_reasoning_content_streaming(
+ previous_text="abc",
+ current_text="abc\n\n",
+ delta_text="\n\n",
+ previous_token_ids=[200, 201, 202],
+ current_token_ids=[200, 201, 202, 100, 200, 101],
+ delta_token_ids=[100, 200, 101],
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertEqual(result.reasoning_content, "")
+
+ def test_streaming_with_reasoning_and_illegal_tool(self):
+ # NOTE(review): delta_token_ids starts with 109 while current_token_ids carries 100
+ # at the same position; looks like a typo, although the assertion does not depend on it.
+ result = self.parser.extract_reasoning_content_streaming(
+ previous_text="abc",
+ current_text="abc\n\nhello",
+ delta_text="\n\nhello",
+ previous_token_ids=[200, 201, 202],
+ current_token_ids=[200, 201, 202, 100, 200, 101],
+ delta_token_ids=[109, 200, 101],
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertEqual(result.content, "\n\nhello")
+
+ def test_streaming_with_reasoning_no_tool(self):
+ # One delta carrying both the reasoning tail and the first piece of content.
+ result = self.parser.extract_reasoning_content_streaming(
+ previous_text="abc",
+ current_text="abchello\nworld",
+ delta_text="hello\nworld",
+ previous_token_ids=[200, 201, 202],
+ current_token_ids=[200, 201, 202, 100, 200, 110],
+ delta_token_ids=[100, 200, 110],
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertEqual(result.reasoning_content, "hello")
+ self.assertEqual(result.content, "\nworld")
+
+ def test_streaming_reasoning_previous_no_tool(self):
+ result = self.parser.extract_reasoning_content_streaming(
+ previous_text="",
+ current_text="\nhello",
+ delta_text="\nhello",
+ previous_token_ids=[100],
+ current_token_ids=[100, 110, 111],
+ delta_token_ids=[110, 111],
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertIsNone(result.reasoning_content)
+ self.assertEqual(result.content, "\nhello")
+
+ def test_streaming_no_reasoning_previous_tool(self):
+ result = self.parser.extract_reasoning_content_streaming(
+ previous_text="",
+ current_text="hello",
+ delta_text="hello",
+ previous_token_ids=[101],
+ current_token_ids=[101, 110],
+ delta_token_ids=[110],
+ )
+ self.assertIsInstance(result, DeltaMessage)
+ self.assertEqual(result.reasoning_content, "hello")
+
+ def test_batch_no_think_end(self):
+ # Without a think-end token the whole output counts as reasoning.
+ reasoning, content = self.parser.extract_reasoning_content(
+ model_output="direct response", request=self.test_request
+ )
+ self.assertEqual(reasoning, "direct response")
+ self.assertEqual(content, "")
+
+ def test_batch_no_think_end_with_tool(self):
+ reasoning, content = self.parser.extract_reasoning_content(
+ model_output="direct responseabc", request=self.test_request
+ )
+ self.assertEqual(reasoning, "direct responseabc")
+ self.assertEqual(content, "")
+
+ def test_batch_think_end_normal_content(self):
+ reasoning, content = self.parser.extract_reasoning_content(
+ model_output="reasoning\nresponse", request=self.test_request
+ )
+ self.assertEqual(reasoning, "reasoning")
+ self.assertEqual(content, "\nresponse")
+
+ def test_batch_think_end_with_tool(self):
+ # A tool call right after the reasoning leaves the content empty.
+ reasoning, content = self.parser.extract_reasoning_content(
+ model_output="reasoning\ntool params", request=self.test_request
+ )
+ self.assertEqual(reasoning, "reasoning")
+ self.assertEqual(content, "")
+
+ def test_batch_think_end_with_illegal_tool(self):
+ # Text in front of the tool call keeps the whole remainder as content.
+ reasoning, content = self.parser.extract_reasoning_content(
+ model_output="reasoning\nABC\ntool params", request=self.test_request
+ )
+ self.assertEqual(reasoning, "reasoning")
+ self.assertEqual(content, "\nABC\ntool params")
+
+ def test_batch_think_end_content_with_newline(self):
+ reasoning, content = self.parser.extract_reasoning_content(
+ model_output="reasoning\n\n actual response", request=self.test_request
+ )
+ self.assertEqual(reasoning, "reasoning")
+ self.assertEqual(content, "\n\n actual response")
+
+
if __name__ == "__main__":
unittest.main()