add Tool Parser (#3272)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* add tool-parser

* add tool-parser

* add tool parser

* add tool parser

* fix

* add offline

* add offline

* fix

* parsers:tool&reasoning

* 修改tool parser名称·

* update

* fix reasoning-parser

* add requirements

* fix finish reason

* fix

* fix reasoning-parser

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: zhuzixuan <zhuzixuan@baidu.com>
This commit is contained in:
luukunn
2025-08-13 01:06:55 +08:00
committed by GitHub
parent 2d1a4cacdf
commit eda83ca672
23 changed files with 1056 additions and 38 deletions

View File

@@ -0,0 +1,24 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from .abstract_tool_parser import ToolParser, ToolParserManager
from .ernie_x1_tool_parser import ErnieX1ToolParser
__all__ = [
"ToolParser",
"ToolParserManager",
"ErnieX1ToolParser",
]

View File

@@ -0,0 +1,159 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
from collections.abc import Sequence
from functools import cached_property
from typing import Callable, Optional, Union
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
)
from fastdeploy.utils import data_processor_logger, import_from_path, is_list_of
class ToolParser:
"""
Abstract ToolParser class that should not be used directly. Provided
properties and methods should be used in
derived classes.
"""
def __init__(self, tokenizer):
self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = []
self.model_tokenizer = tokenizer
@cached_property
def vocab(self) -> dict[str, int]:
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
"""
Static method that used to adjust the request parameters.
"""
return request
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Static method that should be implemented for extracting tool calls from
a complete model-generated string.
Used for non-streaming responses where we have the entire model response
available before sending to the client.
Static because it's stateless.
"""
raise NotImplementedError("AbstractToolParser.extract_tool_calls has not been implemented!")
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
"""
Instance method that should be implemented for extracting tool calls
from an incomplete response; for use when handling tool calls and
streaming. Has to be an instance method because it requires state -
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise NotImplementedError("AbstractToolParser.extract_tool_calls_streaming has not been " "implemented!")
class ToolParserManager:
tool_parsers: dict[str, type] = {}
@classmethod
def get_tool_parser(cls, name) -> type:
"""
Get tool parser by name which is registered by `register_module`.
Raise a KeyError exception if the name is not registered.
"""
if name in cls.tool_parsers:
return cls.tool_parsers[name]
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
@classmethod
def _register_module(
cls, module: type, module_name: Optional[Union[str, list[str]]] = None, force: bool = True
) -> None:
if not issubclass(module, ToolParser):
raise TypeError(f"module must be subclass of ToolParser, but got {type(module)}")
if module_name is None:
module_name = module.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in cls.tool_parsers:
existed_module = cls.tool_parsers[name]
raise KeyError(f"{name} is already registered " f"at {existed_module.__module__}")
cls.tool_parsers[name] = module
@classmethod
def register_module(
cls, name: Optional[Union[str, list[str]]] = None, force: bool = True, module: Union[type, None] = None
) -> Union[type, Callable]:
"""
Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not
None).
"""
if not isinstance(force, bool):
raise TypeError(f"force must be a boolean, but got {type(force)}")
# raise the error ahead of time
if not (name is None or isinstance(name, str) or is_list_of(name, str)):
raise TypeError("name must be None, an instance of str, or a sequence of str, " f"but got {type(name)}")
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
cls._register_module(module=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(module):
cls._register_module(module=module, module_name=name, force=force)
return module
return _register
@classmethod
def import_tool_parser(cls, plugin_path: str) -> None:
"""
Import a user-defined tool parser by the path of the tool parser define
file.
"""
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
try:
import_from_path(module_name, plugin_path)
except Exception:
data_processor_logger.exception("Failed to load module '%s' from %s.", module_name, plugin_path)
return

View File

@@ -0,0 +1,320 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
import uuid
from collections.abc import Sequence
from typing import Union
import partial_json_parser
def random_tool_call_id() -> str:
"""Generate a random tool call ID"""
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
from fastdeploy.utils import data_processor_logger
@ToolParserManager.register_module("ernie_x1")
class ErnieX1ToolParser(ToolParser):
"""
Tool parser for Ernie model version 4.5.1.
This parser handles tool calls with newline formats.
"""
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
self.buffer: str = "" # buffer for accumulating unprocessed streaming content
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolCallParser constructor during construction."
)
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
Supports XML-style formats with newlines:
- XML format: <think>\n...\n</think>\n\n\n<tool_call>\n{...}\n</tool_call>\n...
Handles boundary cases:
1. Only name and partial arguments: {"name": "get_weather", "arguments": {"location": "北京"
2. Only partial name: {"name": "get_we
3. Only name and arguments field without content: {"name": "get_weather", "argume
"""
try:
tool_calls = []
# Check for invalid <response> tags before tool calls
if re.search(r"<response>[\s\S]*?</response>\s*(?=<tool_call>)", model_output):
data_processor_logger.error("Invalid format: <response> tags found before <tool_call>")
return ExtractedToolCallInformation(tools_called=False, content=model_output)
function_call_arr = []
remaining_text = model_output
while True:
# 查找下一个tool_call块
tool_call_pos = remaining_text.find("<tool_call>")
if tool_call_pos == -1:
break
# 提取tool_call开始位置后的内容
tool_content_start = tool_call_pos + len("<tool_call>")
tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
tool_json = ""
if tool_content_end == -1:
# 处理未闭合的tool_call块截断情况
tool_json = remaining_text[tool_content_start:].strip()
remaining_text = "" # 没有更多内容需要处理
else:
# 处理完整的tool_call块
tool_json = remaining_text[tool_content_start:tool_content_end].strip()
remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
if not tool_json:
continue
# 处理JSON内容
tool_json = tool_json.strip()
if not tool_json.startswith("{"):
tool_json = "{" + tool_json
if not tool_json.endswith("}"):
tool_json = tool_json + "}"
try:
# 首先尝试标准JSON解析
try:
tool_data = json.loads(tool_json)
if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
function_call_arr.append(
{
"name": tool_data["name"],
"arguments": tool_data["arguments"],
"_is_complete": True, # 明确标记为完整解析
}
)
continue
except json.JSONDecodeError:
pass
# 标准解析失败时尝试partial_json_parser
from partial_json_parser.core.options import Allow
try:
tool_data = {}
flags = Allow.ALL & ~Allow.STR
# 解析name字段
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
if name_match:
tool_data["name"] = name_match.group(1)
# 解析arguments字段
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
if args_match:
try:
tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
except:
tool_data["arguments"] = None
if isinstance(tool_data, dict):
function_call_arr.append(
{
"name": tool_data.get("name", ""),
"arguments": tool_data.get("arguments", {}),
"_is_partial": True, # 标记为部分解析
}
)
except Exception as e:
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
continue
except Exception as e:
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
continue
if not function_call_arr:
data_processor_logger.error("No valid tool calls found")
return ExtractedToolCallInformation(tools_called=False, content=model_output)
tool_calls = []
all_complete = True # 初始设为True只要有一个不完整就变为False
for tool_call in function_call_arr:
# 记录工具调用解析状态
is_complete = tool_call.get("_is_complete", False)
is_partial = tool_call.get("_is_partial", False)
# 只要有一个不完整就认为整体不完整
if not is_complete or is_partial:
all_complete = False
# 处理参数序列化
tool_args = tool_call.get("arguments", {})
if not isinstance(tool_args, dict):
tool_args = {}
try:
args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
except:
args_str = "{}"
tool_calls.append(
ToolCall(
type="function",
id=random_tool_call_id(),
function=FunctionCall(
name=tool_call.get("name", ""),
arguments=args_str,
),
)
)
# 只有当所有工具调用都明确标记为complete时才返回tools_called=True
return ExtractedToolCallInformation(
tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
)
except Exception as e:
data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: dict,
) -> Union[DeltaMessage, None]:
# 忽略空chunk
if len(delta_text.strip()) == 0:
return None
try:
delta = None
# 使用buffer累积delta_text内容
self.buffer += delta_text
# 处理增量中的新tool_call开始
if "<tool_call>" in delta_text and "<tool_call>" not in previous_text:
self.current_tool_id = (
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
)
self.current_tool_name_sent = False
if len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
# 增量解析逻辑
# 1. 尝试解析name字段
if not self.current_tool_name_sent and '"name"' in self.buffer:
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
if name_match:
name = name_match.group(1)
if name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
function=DeltaFunctionCall(name=name).model_dump(exclude_none=True),
)
]
)
print("delta name:", delta)
# 删除已处理的name部分
self.buffer = self.buffer[name_match.end() :]
self.current_tool_name_sent = True
return delta
# 2. 尝试解析arguments字段
if '"arguments"' in self.buffer:
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
if args_match:
args_content = args_match.group(1)
# 处理多余的大括号
open_braces = args_content.count("{")
close_braces = args_content.count("}")
if close_braces > open_braces:
args_content = args_content[: args_content.rfind("}")]
try:
# 增量解析arguments
parsed_args = json.loads(args_content)
if isinstance(parsed_args, dict):
args_json = json.dumps(parsed_args, ensure_ascii=False)
if len(args_json) > len(self.streamed_args_for_tool[self.current_tool_id]):
argument_diff = args_json[len(self.streamed_args_for_tool[self.current_tool_id]) :]
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=argument_diff).model_dump(
exclude_none=True
),
)
]
)
print("delta argument:", delta)
# 删除已处理部分
processed_pos = args_match.start() + len('"arguments":')
self.buffer = (
self.buffer[:processed_pos] + self.buffer[processed_pos + len(args_json) :]
)
self.streamed_args_for_tool[self.current_tool_id] = args_json
return delta
except Exception as e:
data_processor_logger.debug(f"Partial arguments parsing: {str(e)}")
if "</tool_call>" in self.buffer:
end_pos = self.buffer.find("</tool_call>")
self.buffer = self.buffer[end_pos + len("</tool_call>") :]
# 完成当前工具调用处理
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
return delta
except Exception as e:
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
return None

View File

@@ -0,0 +1,137 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
from json import JSONDecodeError, JSONDecoder
from typing import Any
import partial_json_parser
from partial_json_parser.core.options import Allow
def find_common_prefix(s1: str, s2: str) -> str:
"""
Finds a common prefix that is shared between two strings, if there is one.
Order of arguments is NOT important.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely.
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
'{"fruit": "ap'
"""
prefix = ""
min_length = min(len(s1), len(s2))
for i in range(0, min_length):
if s1[i] == s2[i]:
prefix += s1[i]
else:
break
return prefix
def find_common_suffix(s1: str, s2: str) -> str:
"""
Finds a common suffix shared between two strings, if there is one. Order of
arguments is NOT important.
Stops when the suffix ends OR it hits an alphanumeric character
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
"""
suffix = ""
min_length = min(len(s1), len(s2))
for i in range(1, min_length + 1):
if s1[-i] == s2[-i] and not s1[-i].isalnum():
suffix = s1[-i] + suffix
else:
break
return suffix
def extract_intermediate_diff(curr: str, old: str) -> str:
"""
Given two strings, extract the difference in the middle between two strings
that are known to have a common prefix and/or suffix.
This function is provided as a UTILITY for extracting information from JSON
generated by partial_json_parser, to help in ensuring that the right tokens
are returned in streaming, so that close-quotes, close-brackets and
close-braces are not returned prematurely. The order of arguments IS
important - the new version of the partially-parsed JSON must be the first
argument, and the secnod argument must be from the previous generation.
What it returns, is tokens that should be streamed to the client.
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
-> 'ple'
"""
suffix = find_common_suffix(curr, old)
old = old[::-1].replace(suffix[::-1], "", 1)[::-1]
prefix = find_common_prefix(curr, old)
diff = curr
if len(suffix):
diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1]
if len(prefix):
# replace the prefix only once in case it's mirrored
diff = diff.replace(prefix, "", 1)
return diff
def find_all_indices(string: str, substring: str) -> list[int]:
"""
Find all (starting) indices of a substring in a given string. Useful for
tool call extraction
"""
indices = []
index = -1
while True:
index = string.find(substring, index + 1)
if index == -1:
break
indices.append(index)
return indices
# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
try:
return (partial_json_parser.loads(input_str, flags), len(input_str))
except JSONDecodeError as e:
if "Extra data" in e.msg:
dec = JSONDecoder()
return dec.raw_decode(input_str)
raise
def is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
return True
except JSONDecodeError:
return False
def consume_space(i: int, s: str) -> int:
while i < len(s) and s[i].isspace():
i += 1
return i