diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py index e5df1a2e1..9b0c7b9cb 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py @@ -14,10 +14,18 @@ import json import re +import uuid from collections.abc import Sequence from typing import Union -from fastdeploy.entrypoints.chat_utils import random_tool_call_id +import partial_json_parser + + +def random_tool_call_id() -> str: + """Generate a random tool call ID""" + return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" + + from fastdeploy.entrypoints.openai.protocol import ( ChatCompletionRequest, DeltaFunctionCall, @@ -53,8 +61,6 @@ class ErnieX1ToolParser(ToolParser): self.tool_call_start_token: str = "" self.tool_call_end_token: str = "" - self.tool_call_regex = re.compile(r"(.*?)|(.*)", re.DOTALL) - self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token) self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None: @@ -67,9 +73,7 @@ class ErnieX1ToolParser(ToolParser): "The model tokenizer must be passed to the ToolCallParser constructor during construction." ) - def extract_tool_calls( - self, model_output: str, request: ChatCompletionRequest, model_status: str - ) -> ExtractedToolCallInformation: + def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. Supports XML-style formats with newlines: @@ -81,31 +85,144 @@ class ErnieX1ToolParser(ToolParser): 3. Only name and arguments field without content: {"name": "get_weather", "argume """ - extract_content = model_output - if model_status == "tool_call_start": - extract_content = "" + model_output try: - if self.tool_call_start_token not in extract_content: - return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) - function_call_tuples = self.tool_call_regex.findall(extract_content) + tool_calls = [] - raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples] + # Check for invalid tags before tool calls + if re.search(r"[\s\S]*?\s*(?=)", model_output): + data_processor_logger.error("Invalid format: tags found before ") + return ExtractedToolCallInformation(tools_called=False, content=model_output) - tool_calls = [ - ToolCall( - type="function", - function=FunctionCall( - name=function_call["name"], - # function call args are JSON but as a string - arguments=json.dumps(function_call["arguments"], ensure_ascii=False), - ), + function_call_arr = [] + remaining_text = model_output + + while True: + # 查找下一个tool_call块 + tool_call_pos = remaining_text.find("") + if tool_call_pos == -1: + break + + # 提取tool_call开始位置后的内容 + tool_content_start = tool_call_pos + len("") + tool_content_end = remaining_text.find("", tool_content_start) + + tool_json = "" + if tool_content_end == -1: + # 处理未闭合的tool_call块(截断情况) + tool_json = remaining_text[tool_content_start:].strip() + remaining_text = "" # 没有更多内容需要处理 + else: + # 处理完整的tool_call块 + tool_json = remaining_text[tool_content_start:tool_content_end].strip() + remaining_text = remaining_text[tool_content_end + len("") :] + + if not tool_json: + continue + + # 处理JSON内容 + tool_json = tool_json.strip() + if not tool_json.startswith("{"): + tool_json = "{" + tool_json + if not tool_json.endswith("}"): + tool_json = tool_json + "}" + + try: + # 首先尝试标准JSON解析 + try: + tool_data = json.loads(tool_json) + + if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data: + function_call_arr.append( + { + "name": tool_data["name"], + "arguments": tool_data["arguments"], + "_is_complete": True, # 明确标记为完整解析 + } + ) + continue + except json.JSONDecodeError: + pass + + # 标准解析失败时尝试partial_json_parser + from partial_json_parser.core.options import Allow + + try: + tool_data = {} + flags = Allow.ALL & ~Allow.STR + + # 解析name字段 + name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json) + if name_match: + tool_data["name"] = name_match.group(1) + + # 解析arguments字段 + args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json) + if args_match: + try: + tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags) + except: + tool_data["arguments"] = None + + if isinstance(tool_data, dict): + function_call_arr.append( + { + "name": tool_data.get("name", ""), + "arguments": tool_data.get("arguments", {}), + "_is_partial": True, # 标记为部分解析 + } + ) + except Exception as e: + data_processor_logger.debug(f"Failed to parse tool call: {str(e)}") + continue + except Exception as e: + data_processor_logger.debug(f"Failed to parse tool call: {str(e)}") + continue + + if not function_call_arr: + data_processor_logger.error("No valid tool calls found") + return ExtractedToolCallInformation(tools_called=False, content=model_output) + + tool_calls = [] + all_complete = True # 初始设为True,只要有一个不完整就变为False + + for tool_call in function_call_arr: + # 记录工具调用解析状态 + is_complete = tool_call.get("_is_complete", False) + is_partial = tool_call.get("_is_partial", False) + + # 只要有一个不完整就认为整体不完整 + if not is_complete or is_partial: + all_complete = False + + # 处理参数序列化 + tool_args = tool_call.get("arguments", {}) + if not isinstance(tool_args, dict): + tool_args = {} + + try: + args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}" + except: + args_str = "{}" + + tool_calls.append( + ToolCall( + type="function", + id=random_tool_call_id(), + function=FunctionCall( + name=tool_call.get("name", ""), + arguments=args_str, + ), + ) ) - for function_call in raw_function_calls - ] - return ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content="") - except Exception: - data_processor_logger.error("Error in extracting tool call from response.") - return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) + + # 只有当所有工具调用都明确标记为complete时才返回tools_called=True + return ExtractedToolCallInformation( + tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content="" + ) + + except Exception as e: + data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}") + return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output) def extract_tool_calls_streaming( self, @@ -116,7 +233,6 @@ class ErnieX1ToolParser(ToolParser): current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: dict, - model_status: str, ) -> Union[DeltaMessage, None]: if self.tool_call_start_token_id not in current_token_ids: