diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
index e5df1a2e1..9b0c7b9cb 100644
--- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
+++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
@@ -14,10 +14,18 @@
import json
import re
+import uuid
from collections.abc import Sequence
from typing import Union
-from fastdeploy.entrypoints.chat_utils import random_tool_call_id
+import partial_json_parser
+
+
+def random_tool_call_id() -> str:
+ """Return a new tool call ID: the prefix "chatcmpl-tool-" followed by the hex form of a random UUID4."""
+ return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
+
+
from fastdeploy.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
@@ -53,8 +61,6 @@ class ErnieX1ToolParser(ToolParser):
self.tool_call_start_token: str = ""
self.tool_call_end_token: str = ""
- self.tool_call_regex = re.compile(r"(.*?)|(.*)", re.DOTALL)
-
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
@@ -67,9 +73,7 @@ class ErnieX1ToolParser(ToolParser):
"The model tokenizer must be passed to the ToolCallParser constructor during construction."
)
- def extract_tool_calls(
- self, model_output: str, request: ChatCompletionRequest, model_status: str
- ) -> ExtractedToolCallInformation:
+ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
Supports XML-style formats with newlines:
@@ -81,31 +85,144 @@ class ErnieX1ToolParser(ToolParser):
3. Only name and arguments field without content: {"name": "get_weather", "argume
"""
- extract_content = model_output
- if model_status == "tool_call_start":
- extract_content = "" + model_output
try:
- if self.tool_call_start_token not in extract_content:
- return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
- function_call_tuples = self.tool_call_regex.findall(extract_content)
+ tool_calls = []
- raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples]
+ # Check for invalid tags before tool calls
+ if re.search(r"[\s\S]*?\s*(?=)", model_output):
+ data_processor_logger.error("Invalid format: tags found before ")
+ return ExtractedToolCallInformation(tools_called=False, content=model_output)
- tool_calls = [
- ToolCall(
- type="function",
- function=FunctionCall(
- name=function_call["name"],
- # function call args are JSON but as a string
- arguments=json.dumps(function_call["arguments"], ensure_ascii=False),
- ),
+ function_call_arr = []
+ remaining_text = model_output
+
+ while True:
+ # 查找下一个tool_call块
+ tool_call_pos = remaining_text.find("")
+ if tool_call_pos == -1:
+ break
+
+ # 提取tool_call开始位置后的内容
+ tool_content_start = tool_call_pos + len("")
+ tool_content_end = remaining_text.find("", tool_content_start)
+
+ tool_json = ""
+ if tool_content_end == -1:
+ # 处理未闭合的tool_call块(截断情况)
+ tool_json = remaining_text[tool_content_start:].strip()
+ remaining_text = "" # 没有更多内容需要处理
+ else:
+ # 处理完整的tool_call块
+ tool_json = remaining_text[tool_content_start:tool_content_end].strip()
+ remaining_text = remaining_text[tool_content_end + len("") :]
+
+ if not tool_json:
+ continue
+
+ # 处理JSON内容
+ tool_json = tool_json.strip()
+ if not tool_json.startswith("{"):
+ tool_json = "{" + tool_json
+ if not tool_json.endswith("}"):
+ tool_json = tool_json + "}"
+
+ try:
+ # 首先尝试标准JSON解析
+ try:
+ tool_data = json.loads(tool_json)
+
+ if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
+ function_call_arr.append(
+ {
+ "name": tool_data["name"],
+ "arguments": tool_data["arguments"],
+ "_is_complete": True, # 明确标记为完整解析
+ }
+ )
+ continue
+ except json.JSONDecodeError:
+ pass
+
+ # 标准解析失败时尝试partial_json_parser
+ from partial_json_parser.core.options import Allow
+
+ try:
+ tool_data = {}
+ flags = Allow.ALL & ~Allow.STR
+
+ # 解析name字段
+ name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
+ if name_match:
+ tool_data["name"] = name_match.group(1)
+
+ # 解析arguments字段
+ args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
+ if args_match:
+ try:
+ tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
+ except:
+ tool_data["arguments"] = None
+
+ if isinstance(tool_data, dict):
+ function_call_arr.append(
+ {
+ "name": tool_data.get("name", ""),
+ "arguments": tool_data.get("arguments", {}),
+ "_is_partial": True, # 标记为部分解析
+ }
+ )
+ except Exception as e:
+ data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
+ continue
+ except Exception as e:
+ data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
+ continue
+
+ if not function_call_arr:
+ data_processor_logger.error("No valid tool calls found")
+ return ExtractedToolCallInformation(tools_called=False, content=model_output)
+
+ tool_calls = []
+ all_complete = True # 初始设为True,只要有一个不完整就变为False
+
+ for tool_call in function_call_arr:
+ # 记录工具调用解析状态
+ is_complete = tool_call.get("_is_complete", False)
+ is_partial = tool_call.get("_is_partial", False)
+
+ # 只要有一个不完整就认为整体不完整
+ if not is_complete or is_partial:
+ all_complete = False
+
+ # 处理参数序列化
+ tool_args = tool_call.get("arguments", {})
+ if not isinstance(tool_args, dict):
+ tool_args = {}
+
+ try:
+ args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
+ except:
+ args_str = "{}"
+
+ tool_calls.append(
+ ToolCall(
+ type="function",
+ id=random_tool_call_id(),
+ function=FunctionCall(
+ name=tool_call.get("name", ""),
+ arguments=args_str,
+ ),
+ )
)
- for function_call in raw_function_calls
- ]
- return ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content="")
- except Exception:
- data_processor_logger.error("Error in extracting tool call from response.")
- return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
+
+ # 只有当所有工具调用都明确标记为complete时才返回tools_called=True
+ return ExtractedToolCallInformation(
+ tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
+ )
+
+ except Exception as e:
+ data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
+ return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
def extract_tool_calls_streaming(
self,
@@ -116,7 +233,6 @@ class ErnieX1ToolParser(ToolParser):
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: dict,
- model_status: str,
) -> Union[DeltaMessage, None]:
if self.tool_call_start_token_id not in current_token_ids: