diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 835d3eb4d..054077c13 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -95,6 +95,14 @@ class EngineArgs:
     """
     specifies the reasoning parser to use for extracting reasoning content from the model output
     """
+    tool_call_parser: str = None
+    """
+    specifies the tool call parser to use for extracting tool calls from the model output
+    """
+    tool_parser_plugin: str = None
+    """
+    tool parser plugin used to register user-defined tool parsers
+    """
     enable_mm: bool = False
     """
     Flags to enable multi-modal model
@@ -423,6 +431,18 @@ class EngineArgs:
             help="Flag specifies the reasoning parser to use for extracting "
             "reasoning content from the model output",
         )
+        model_group.add_argument(
+            "--tool-call-parser",
+            type=str,
+            default=EngineArgs.tool_call_parser,
+            help="Flag specifies the tool call parser to use for extracting " "tool calls from the model output",
+        )
+        model_group.add_argument(
+            "--tool-parser-plugin",
+            type=str,
+            default=EngineArgs.tool_parser_plugin,
+            help="tool parser plugin used to register user-defined tool parsers",
+        )
         model_group.add_argument(
             "--speculative-config",
             type=json.loads,
@@ -913,6 +933,7 @@ class EngineArgs:
             mm_processor_kwargs=self.mm_processor_kwargs,
             enable_mm=self.enable_mm,
             reasoning_parser=self.reasoning_parser,
+            tool_parser=self.tool_call_parser,
             splitwise_role=self.splitwise_role,
             innode_prefill_ports=self.innode_prefill_ports,
             max_num_partial_prefills=self.max_num_partial_prefills,
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index f6303d7b3..1a9f2d2a1 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -85,6 +85,7 @@ class Config:
         max_long_partial_prefills: int = 1,
         long_prefill_token_threshold: int = 0,
         reasoning_parser: str = None,
+        tool_parser: str = None,
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
         enable_logprob: bool = False,
@@ -165,6 +166,7 @@ class Config:
         self.max_long_partial_prefills = max_long_partial_prefills
         self.long_prefill_token_threshold = long_prefill_token_threshold
         self.reasoning_parser = reasoning_parser
+        self.tool_parser = tool_parser
         self.graph_optimization_config = graph_optimization_config
         self.early_stop_config = early_stop_config
         self.guided_decoding_backend = guided_decoding_backend
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index fa9fa6175..16a89932d 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -106,6 +106,7 @@ class LLMEngine:
             cfg.limit_mm_per_prompt,
             cfg.mm_processor_kwargs,
             cfg.enable_mm,
+            cfg.tool_parser,
         )
 
         self.start_queue_service()
diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index acf717547..b9fa895e6 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -24,6 +24,7 @@ from typing import Any, Dict, Optional, Union
 import numpy as np
 
 from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.openai.protocol import ToolCall
 from fastdeploy.utils import data_processor_logger
 from fastdeploy.worker.output import LogprobsLists, SampleLogprobs
 
@@ -249,6 +250,7 @@ class CompletionOutput:
     draft_token_ids: list[int] = None
     text: Optional[str] = None
     reasoning_content: Optional[str] = None
+    tool_calls: Optional[ToolCall] = None
 
     def to_dict(self):
         """
diff --git a/fastdeploy/entrypoints/chat_utils.py b/fastdeploy/entrypoints/chat_utils.py
index 4f7357e11..059ecee01 100644
--- a/fastdeploy/entrypoints/chat_utils.py
+++ b/fastdeploy/entrypoints/chat_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import uuid
 from copy import deepcopy
 from typing import List, Literal, Union
 from urllib.parse import urlparse
@@ -156,3 +157,7 @@ def parse_chat_messages(messages):
         conversation.append({"role": role, "content": parsed_content})
 
     return conversation
+
+
+def random_tool_call_id() -> str:
+    return f"chatcmpl-tool-{uuid.uuid4().hex}"
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index 12d14f7e1..e7edacb26 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -45,6 +45,7 @@ class EngineClient:
         data_parallel_size=1,
         enable_logprob=False,
         workers=1,
+        tool_parser=None,
     ):
         input_processor = InputPreprocessor(
             tokenizer,
@@ -52,6 +53,7 @@ class EngineClient:
             limit_mm_per_prompt,
             mm_processor_kwargs,
             enable_mm,
+            tool_parser,
         )
         self.enable_logprob = enable_logprob
         self.enable_mm = enable_mm
diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py
index 3e150abf2..c744921ba 100644
--- a/fastdeploy/entrypoints/llm.py
+++ b/fastdeploy/entrypoints/llm.py
@@ -28,6 +28,7 @@ from tqdm import tqdm
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.engine import LLMEngine
 from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 
 # from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
 from fastdeploy.utils import llm_logger, retrive_model_from_server
@@ -73,6 +74,9 @@ class LLM:
         **kwargs,
     ):
         model = retrive_model_from_server(model, revision)
+        tool_parser_plugin = kwargs.get("tool_parser_plugin")
+        if tool_parser_plugin:
+            ToolParserManager.import_tool_parser(tool_parser_plugin)
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 2f501b2ef..53168abc0 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -41,6 +41,7 @@ from fastdeploy.entrypoints.openai.protocol import (
 )
 from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
 from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.metrics.metrics import (
     EXCLUDE_LABELS,
     cleanup_prometheus_files,
@@ -73,7 +74,8 @@ parser.add_argument("--max-concurrency", default=512, type=int, help="max concur
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
-
+if args.tool_parser_plugin:
+    ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 llm_engine = None
 
@@ -126,6 +128,7 @@ async def lifespan(app: FastAPI):
             args.data_parallel_size,
             args.enable_logprob,
             args.workers,
+            args.tool_call_parser,
         )
         app.state.dynamic_load_weight = args.dynamic_load_weight
         chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 678ae8dd0..2049fb971 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -72,7 +72,6 @@ class ToolCall(BaseModel):
     id: str = None
     type: Literal["function"] = "function"
     function: FunctionCall
-    index: int
 
 
 class DeltaFunctionCall(BaseModel):
@@ -96,6 +95,18 @@ class DeltaToolCall(BaseModel):
     function: Optional[DeltaFunctionCall] = None
 
 
+class ExtractedToolCallInformation(BaseModel):
+    # indicate if tools were called
+    tools_called: bool
+
+    # extracted tool calls
+    tool_calls: Optional[list[ToolCall]] = None
+
+    # content: per the OpenAI spec, content AND tool calls are rarely returned
+    # together, but some models do this intentionally
+    content: Optional[str] = None
+
+
 class FunctionDefinition(BaseModel):
     """
     Function definition.
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index 536cd7d80..1005bae0e 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -141,6 +141,7 @@ class OpenAIServingChat:
         previous_num_tokens = 0
         num_prompt_tokens = 0
         num_choices = 1
+        tool_called = False
         max_streaming_response_tokens = (
             request.max_streaming_response_tokens
             if request.max_streaming_response_tokens is not None
@@ -245,20 +246,28 @@ class OpenAIServingChat:
                     output = res["outputs"]
                     delta_text = output["text"]
                     output_top_logprobs = output["top_logprobs"]
+                    previous_num_tokens += len(output["token_ids"])
                     logprobs_res: Optional[LogProbs] = None
                     if request.logprobs and output_top_logprobs is not None:
                         logprobs_res = self._create_chat_logprobs(
                             output_top_logprobs, request.logprobs, request.top_logprobs
                         )
-
-                    previous_num_tokens += len(output["token_ids"])
-                    delta_message = DeltaMessage(
-                        content=delta_text,
-                        reasoning_content=output.get("reasoning_content"),
-                        prompt_token_ids=None,
-                        completion_token_ids=None,
-                        tool_calls=output.get("tool_call_content", []),
-                    )
+                    if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
+                        tool_delta_message = output["tool_delta_message"]
+                        if tool_delta_message is None:
+                            continue
+                        delta_message = tool_delta_message
+                        delta_message.reasoning_content = output.get("reasoning_content")
+                        if delta_message.tool_calls:
+                            tool_called = True
+                    else:
+                        delta_message = DeltaMessage(
+                            content=delta_text,
+                            reasoning_content=output.get("reasoning_content"),
+                            prompt_token_ids=None,
+                            completion_token_ids=None,
+                            tool_calls=None,
+                        )
 
                     choice = ChatCompletionResponseStreamChoice(
                         index=0,
@@ -276,10 +285,7 @@ class OpenAIServingChat:
                             max_tokens = request.max_completion_tokens or request.max_tokens
                             if has_no_token_limit or previous_num_tokens != max_tokens:
                                 choice.finish_reason = "stop"
-                                if (
-                                    self.engine_client.reasoning_parser == "ernie_x1"
-                                    and output.get("finish_reason", "") == "tool_calls"
-                                ):
+                                if tool_called:
                                     choice.finish_reason = "tool_calls"
                             else:
                                 choice.finish_reason = "length"
@@ -419,7 +425,7 @@ class OpenAIServingChat:
             role="assistant",
             content=output["text"],
             reasoning_content=output.get("reasoning_content"),
-            tool_calls=output.get("tool_call_content"),
+            tool_calls=output.get("tool_call"),
             prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
             completion_token_ids=completion_token_ids if request.return_token_ids else None,
             text_after_process=text_after_process if request.return_token_ids else None,
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index cec597f78..1e8ad0f86 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -240,9 +240,9 @@ class OpenAIServingCompletion:
                 dealer.close()
                 self.engine_client.semaphore.release()
 
-    def calc_finish_reason(self, max_tokens, token_num, output):
+    def calc_finish_reason(self, max_tokens, token_num, output, tool_called):
         if max_tokens is None or token_num != max_tokens:
-            if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
+            if tool_called or output.get("tool_call"):
                 return "tool_calls"
             else:
                 return "stop"
@@ -271,6 +271,7 @@ class OpenAIServingCompletion:
         output_tokens = [0] * num_choices
         inference_start_time = [0] * num_choices
         first_iteration = [True] * num_choices
+        tool_called = False
         max_streaming_response_tokens = (
             request.max_streaming_response_tokens
             if request.max_streaming_response_tokens is not None
@@ -342,24 +343,41 @@ class OpenAIServingCompletion:
                     if request.logprobs and output_top_logprobs is not None:
                         logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
 
-                    choices.append(
-                        CompletionResponseStreamChoice(
+                    output_tokens[idx] += 1
+                    if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
+                        tool_delta_message = output["tool_delta_message"]
+                        if tool_delta_message is None:
+                            continue
+                        delta_message = CompletionResponseStreamChoice(
                             index=idx,
                             text=output["text"],
-                            prompt_token_ids=None,
                             completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
-                            raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
-                            tool_calls=output.get("tool_call_content"),
+                            tool_calls=tool_delta_message.tool_calls,
                             reasoning_content=output.get("reasoning_content"),
                             arrival_time=arrival_time,
                             logprobs=logprobs_res,
                         )
-                    )
+                        if tool_delta_message.tool_calls:
+                            tool_called = True
+                    else:
+                        delta_message = CompletionResponseStreamChoice(
+                            index=idx,
+                            text=output["text"],
+                            prompt_token_ids=None,
+                            completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
+                            tool_calls=None,
+                            raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
+                            reasoning_content=output.get("reasoning_content"),
+                            arrival_time=arrival_time,
+                            logprobs=logprobs_res,
+                        )
+
+                    choices.append(delta_message)
 
-                    output_tokens[idx] += 1
                     if res["finished"]:
                         choices[-1].finish_reason = self.calc_finish_reason(
-                            request.max_tokens, output_tokens[idx], output
+                            request.max_tokens, output_tokens[idx], output, tool_called
                         )
                         send_idx = output.get("send_idx")
                         # 只有当 send_idx 明确为 0 时才记录日志
@@ -458,7 +476,7 @@ class OpenAIServingCompletion:
                 token_ids = output["token_ids"]
                 output_text = output["text"]
 
-                finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output)
+                finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output, False)
 
                 choice_data = CompletionResponseChoice(
                     token_ids=token_ids,
@@ -469,7 +487,7 @@ class OpenAIServingCompletion:
                     raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
                     text_after_process=text_after_process_list[idx] if request.return_token_ids else None,
                     reasoning_content=output.get("reasoning_content"),
-                    tool_calls=output.get("tool_call_content"),
+                    tool_calls=output.get("tool_call"),
                     logprobs=aggregated_logprobs,
                     finish_reason=finish_reason,
                 )
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/__init__.py b/fastdeploy/entrypoints/openai/tool_parsers/__init__.py
new file mode 100644
index 000000000..2078a8c9f
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/tool_parsers/__init__.py
@@ -0,0 +1,24 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from .abstract_tool_parser import ToolParser, ToolParserManager
+from .ernie_x1_tool_parser import ErnieX1ToolParser
+
+__all__ = [
+    "ToolParser",
+    "ToolParserManager",
+    "ErnieX1ToolParser",
+]
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py
new file mode 100644
index 000000000..d6ac8f81a
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py
@@ -0,0 +1,159 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+from collections.abc import Sequence
+from functools import cached_property
+from typing import Callable, Optional, Union
+
+from fastdeploy.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    DeltaMessage,
+    ExtractedToolCallInformation,
+)
+from fastdeploy.utils import data_processor_logger, import_from_path, is_list_of
+
+
+class ToolParser:
+    """
+    Abstract ToolParser class that should not be used directly.
+    Provided properties and methods should be used in derived classes.
+    """
+
+    def __init__(self, tokenizer):
+        self.prev_tool_call_arr: list[dict] = []
+        # the index of the tool call that is currently being parsed
+        self.current_tool_id: int = -1
+        self.current_tool_name_sent: bool = False
+        self.streamed_args_for_tool: list[str] = []
+
+        self.model_tokenizer = tokenizer
+
+    @cached_property
+    def vocab(self) -> dict[str, int]:
+        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab,
+        # whereas all tokenizers have .get_vocab()
+        return self.model_tokenizer.get_vocab()
+
+    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        """
+        Method used to adjust the request parameters.
+        """
+        return request
+
+    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+        """
+        Method that should be implemented for extracting tool calls from
+        a complete model-generated string.
+        Used for non-streaming responses where we have the entire model response
+        available before sending to the client.
+        Effectively stateless.
+ """ + raise NotImplementedError("AbstractToolParser.extract_tool_calls has not been implemented!") + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError("AbstractToolParser.extract_tool_calls_streaming has not been " "implemented!") + + +class ToolParserManager: + tool_parsers: dict[str, type] = {} + + @classmethod + def get_tool_parser(cls, name) -> type: + """ + Get tool parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. + """ + if name in cls.tool_parsers: + return cls.tool_parsers[name] + + raise KeyError(f"tool helper: '{name}' not found in tool_parsers") + + @classmethod + def _register_module( + cls, module: type, module_name: Optional[Union[str, list[str]]] = None, force: bool = True + ) -> None: + if not issubclass(module, ToolParser): + raise TypeError(f"module must be subclass of ToolParser, but got {type(module)}") + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.tool_parsers: + existed_module = cls.tool_parsers[name] + raise KeyError(f"{name} is already registered " f"at {existed_module.__module__}") + cls.tool_parsers[name] = module + + @classmethod + def register_module( + cls, name: Optional[Union[str, list[str]]] = None, force: bool = True, module: Union[type, None] = None + ) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + + # raise the error ahead of time + if not (name is None or isinstance(name, str) or is_list_of(name, str)): + raise TypeError("name must be None, an instance of str, or a sequence of str, " f"but got {type(name)}") + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_tool_parser(cls, plugin_path: str) -> None: + """ + Import a user-defined tool parser by the path of the tool parser define + file. 
+ """ + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + + try: + import_from_path(module_name, plugin_path) + except Exception: + data_processor_logger.exception("Failed to load module '%s' from %s.", module_name, plugin_path) + return diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py new file mode 100644 index 000000000..cec1f6840 --- /dev/null +++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py @@ -0,0 +1,320 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +import uuid +from collections.abc import Sequence +from typing import Union + +import partial_json_parser + + +def random_tool_call_id() -> str: + """Generate a random tool call ID""" + return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" + + +from fastdeploy.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, + ToolParserManager, +) +from fastdeploy.utils import data_processor_logger + + +@ToolParserManager.register_module("ernie_x1") +class ErnieX1ToolParser(ToolParser): + """ + Tool parser for Ernie model version 4.5.1. + This parser handles tool calls with newline formats. + """ + + def __init__(self, tokenizer): + super().__init__(tokenizer) + + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list + self.buffer: str = "" # buffer for accumulating unprocessed streaming content + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolCallParser constructor during construction." + ) + + def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + Supports XML-style formats with newlines: + - XML format: \n...\n\n\n\n\n{...}\n\n... + + Handles boundary cases: + 1. Only name and partial arguments: {"name": "get_weather", "arguments": {"location": "北京" + 2. Only partial name: {"name": "get_we + 3. 
           Only name and arguments field without content: {"name": "get_weather", "argume
+        """
+
+        try:
+            tool_calls = []
+
+            # Check for invalid <response> tags before tool calls
+            if re.search(r"<response>[\s\S]*?</response>\s*(?=<tool_call>)", model_output):
+                data_processor_logger.error("Invalid format: <response> tags found before <tool_call>")
+                return ExtractedToolCallInformation(tools_called=False, content=model_output)
+
+            function_call_arr = []
+            remaining_text = model_output
+
+            while True:
+                # Find the next tool_call block
+                tool_call_pos = remaining_text.find("<tool_call>")
+                if tool_call_pos == -1:
+                    break
+
+                # Extract the content after the <tool_call> start tag
+                tool_content_start = tool_call_pos + len("<tool_call>")
+                tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
+
+                tool_json = ""
+                if tool_content_end == -1:
+                    # Handle an unclosed tool_call block (truncated output)
+                    tool_json = remaining_text[tool_content_start:].strip()
+                    remaining_text = ""  # nothing left to process
+                else:
+                    # Handle a complete tool_call block
+                    tool_json = remaining_text[tool_content_start:tool_content_end].strip()
+                    remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
+
+                if not tool_json:
+                    continue
+
+                # Normalize the JSON content
+                tool_json = tool_json.strip()
+                if not tool_json.startswith("{"):
+                    tool_json = "{" + tool_json
+                if not tool_json.endswith("}"):
+                    tool_json = tool_json + "}"
+
+                try:
+                    # Try standard JSON parsing first
+                    try:
+                        tool_data = json.loads(tool_json)
+
+                        if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
+                            function_call_arr.append(
+                                {
+                                    "name": tool_data["name"],
+                                    "arguments": tool_data["arguments"],
+                                    "_is_complete": True,  # explicitly marked as fully parsed
+                                }
+                            )
+                            continue
+                    except json.JSONDecodeError:
+                        pass
+
+                    # Fall back to partial_json_parser when standard parsing fails
+                    from partial_json_parser.core.options import Allow
+
+                    try:
+                        tool_data = {}
+                        flags = Allow.ALL & ~Allow.STR
+
+                        # Parse the name field
+                        name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
+                        if name_match:
+                            tool_data["name"] = name_match.group(1)
+
+                        # Parse the arguments field
+                        args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
+                        if args_match:
+                            try:
+                                tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
+                            except Exception:
+                                tool_data["arguments"] = None
+
+                        if isinstance(tool_data, dict):
+                            function_call_arr.append(
+                                {
+                                    "name": tool_data.get("name", ""),
+                                    "arguments": tool_data.get("arguments", {}),
+                                    "_is_partial": True,  # marked as partially parsed
+                                }
+                            )
+                    except Exception as e:
+                        data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
+                        continue
+                except Exception as e:
+                    data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
+                    continue
+
+            if not function_call_arr:
+                data_processor_logger.error("No valid tool calls found")
+                return ExtractedToolCallInformation(tools_called=False, content=model_output)
+
+            tool_calls = []
+            all_complete = True  # starts True; a single incomplete call flips it to False
+
+            for tool_call in function_call_arr:
+                # Record the parse state of this tool call
+                is_complete = tool_call.get("_is_complete", False)
+                is_partial = tool_call.get("_is_partial", False)
+
+                # One incomplete call makes the whole result incomplete
+                if not is_complete or is_partial:
+                    all_complete = False
+
+                # Serialize the arguments
+                tool_args = tool_call.get("arguments", {})
+                if not isinstance(tool_args, dict):
+                    tool_args = {}
+
+                try:
+                    args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
+                except Exception:
+                    args_str = "{}"
+
+                tool_calls.append(
+                    ToolCall(
+                        type="function",
+                        id=random_tool_call_id(),
+                        function=FunctionCall(
+                            name=tool_call.get("name", ""),
+                            arguments=args_str,
+                        ),
+                    )
+                )
+
+            # Only report tools_called=True when every tool call is explicitly complete
+            return ExtractedToolCallInformation(
+                tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
+            )
+
+        except Exception as e:
+            data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
+            return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: dict,
+    ) -> Union[DeltaMessage, None]:
+        # Ignore empty chunks
+        if len(delta_text.strip()) == 0:
+            return None
+
+        try:
+            delta = None
+            # Accumulate delta_text into the buffer
+            self.buffer += delta_text
+
+            # Handle the start of a new tool_call in this delta
+            if "<tool_call>" in delta_text and "<tool_call>" not in previous_text:
+                self.current_tool_id = (
+                    max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
+                )
+                self.current_tool_name_sent = False
+                if len(self.streamed_args_for_tool) <= self.current_tool_id:
+                    self.streamed_args_for_tool.append("")
+                data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
+
+            # Incremental parsing logic
+
+            # 1. Try to parse the name field
+            if not self.current_tool_name_sent and '"name"' in self.buffer:
+                name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
+                if name_match:
+                    name = name_match.group(1)
+                    if name:
+                        delta = DeltaMessage(
+                            tool_calls=[
+                                DeltaToolCall(
+                                    index=self.current_tool_id,
+                                    type="function",
+                                    id=random_tool_call_id(),
+                                    function=DeltaFunctionCall(name=name).model_dump(exclude_none=True),
+                                )
+                            ]
+                        )
+                        data_processor_logger.debug(f"delta name: {delta}")
+                        # Drop the processed name portion from the buffer
+                        self.buffer = self.buffer[name_match.end() :]
+                        self.current_tool_name_sent = True
+                        return delta
+            # 2. Try to parse the arguments field
+            if '"arguments"' in self.buffer:
+                args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
+                if args_match:
+                    args_content = args_match.group(1)
+                    # Trim a superfluous closing brace
+                    open_braces = args_content.count("{")
+                    close_braces = args_content.count("}")
+                    if close_braces > open_braces:
+                        args_content = args_content[: args_content.rfind("}")]
+                    try:
+                        # Incrementally parse the arguments
+                        parsed_args = json.loads(args_content)
+                        if isinstance(parsed_args, dict):
+                            args_json = json.dumps(parsed_args, ensure_ascii=False)
+                            if len(args_json) > len(self.streamed_args_for_tool[self.current_tool_id]):
+                                argument_diff = args_json[len(self.streamed_args_for_tool[self.current_tool_id]) :]
+                                delta = DeltaMessage(
+                                    tool_calls=[
+                                        DeltaToolCall(
+                                            index=self.current_tool_id,
+                                            function=DeltaFunctionCall(arguments=argument_diff).model_dump(
+                                                exclude_none=True
+                                            ),
+                                        )
+                                    ]
+                                )
+                                data_processor_logger.debug(f"delta argument: {delta}")
+                                # Drop the processed portion from the buffer
+                                processed_pos = args_match.start() + len('"arguments":')
+                                self.buffer = (
+                                    self.buffer[:processed_pos] + self.buffer[processed_pos + len(args_json) :]
+                                )
+                                self.streamed_args_for_tool[self.current_tool_id] = args_json
+                                return delta
+                    except Exception as e:
+                        data_processor_logger.debug(f"Partial arguments parsing: {str(e)}")
+
+            if "</tool_call>" in self.buffer:
+                end_pos = self.buffer.find("</tool_call>")
+                self.buffer = self.buffer[end_pos + len("</tool_call>") :]
+
+                # Finish processing the current tool call
+                self.current_tool_id += 1
+                self.current_tool_name_sent = False
+                self.streamed_args_for_tool.append("")
+
+            return delta
+
+        except Exception as e:
+            data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
+            return None
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/utils.py b/fastdeploy/entrypoints/openai/tool_parsers/utils.py
new file mode 100644
index 000000000..b7dff3c58
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/tool_parsers/utils.py
@@ -0,0 +1,137 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import json
+from json import JSONDecodeError, JSONDecoder
+from typing import Any
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+
+def find_common_prefix(s1: str, s2: str) -> str:
+    """
+    Finds a common prefix that is shared between two strings, if there is one.
+    Order of arguments is NOT important.
+
+    This function is provided as a UTILITY for extracting information from JSON
+    generated by partial_json_parser, to help in ensuring that the right tokens
+    are returned in streaming, so that close-quotes, close-brackets and
+    close-braces are not returned prematurely.
+
+    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
+    '{"fruit": "ap'
+    """
+    prefix = ""
+    min_length = min(len(s1), len(s2))
+    for i in range(0, min_length):
+        if s1[i] == s2[i]:
+            prefix += s1[i]
+        else:
+            break
+    return prefix
+
+
+def find_common_suffix(s1: str, s2: str) -> str:
+    """
+    Finds a common suffix shared between two strings, if there is one. Order of
+    arguments is NOT important.
+    Stops when the suffix ends OR it hits an alphanumeric character
+
+    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
+    """
+    suffix = ""
+    min_length = min(len(s1), len(s2))
+    for i in range(1, min_length + 1):
+        if s1[-i] == s2[-i] and not s1[-i].isalnum():
+            suffix = s1[-i] + suffix
+        else:
+            break
+    return suffix
+
+
+def extract_intermediate_diff(curr: str, old: str) -> str:
+    """
+    Given two strings, extract the difference in the middle between two strings
+    that are known to have a common prefix and/or suffix.
+
+    This function is provided as a UTILITY for extracting information from JSON
+    generated by partial_json_parser, to help in ensuring that the right tokens
+    are returned in streaming, so that close-quotes, close-brackets and
+    close-braces are not returned prematurely. The order of arguments IS
+    important - the new version of the partially-parsed JSON must be the first
+    argument, and the second argument must be from the previous generation.
+
+    What it returns is the text that should be streamed to the client.
+
+    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
+    -> 'ple'
+
+    """
+    suffix = find_common_suffix(curr, old)
+
+    old = old[::-1].replace(suffix[::-1], "", 1)[::-1]
+    prefix = find_common_prefix(curr, old)
+    diff = curr
+    if len(suffix):
+        diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1]
+
+    if len(prefix):
+        # replace the prefix only once in case it's mirrored
+        diff = diff.replace(prefix, "", 1)
+
+    return diff
+
+
+def find_all_indices(string: str, substring: str) -> list[int]:
+    """
+    Find all (starting) indices of a substring in a given string. Useful for
+    tool call extraction
+    """
+    indices = []
+    index = -1
+    while True:
+        index = string.find(substring, index + 1)
+        if index == -1:
+            break
+        indices.append(index)
+    return indices
+
+
+# partial_json_parser doesn't support extra data and
+# JSONDecoder.raw_decode doesn't support partial JSON
+def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
+    try:
+        return (partial_json_parser.loads(input_str, flags), len(input_str))
+    except JSONDecodeError as e:
+        if "Extra data" in e.msg:
+            dec = JSONDecoder()
+            return dec.raw_decode(input_str)
+        raise
+
+
+def is_complete_json(input_str: str) -> bool:
+    try:
+        json.loads(input_str)
+        return True
+    except JSONDecodeError:
+        return False
+
+
+def consume_space(i: int, s: str) -> int:
+    while i < len(s) and s[i].isspace():
+        i += 1
+    return i
diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py
index 7cbb847f7..f91293bc5 100644
--- a/fastdeploy/input/ernie_processor.py
+++ b/fastdeploy/input/ernie_processor.py
@@ -43,13 +43,14 @@ class ErnieProcessor(BaseDataProcessor):
         pad_token_id (int): 存储填充符号的token ID。
     """
 
-    def __init__(self, model_name_or_path, reasoning_parser_obj=None):
+    def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
         self.model_name_or_path = model_name_or_path
         data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
 
         self._init_config()
 
         self.decode_status = dict()
+        self.tool_parsers = dict()
         self.thinking_parser_dict = dict()
         self._load_tokenizer()
         data_processor_logger.info(
@@ -63,6 +64,7 @@ class ErnieProcessor(BaseDataProcessor):
         self.reasoning_parser = None
         if reasoning_parser_obj:
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
+        self.tool_parser_obj = tool_parser_obj
 
     def _init_config(self):
         self.use_hf_tokenizer = int(envs.FD_USE_HF_TOKENIZER) == 1
@@ -204,6 +206,12 @@ class ErnieProcessor(BaseDataProcessor):
             response_dict.outputs.reasoning_content = reasoning_content
         else:
             response_dict.outputs.text = full_text
+        if self.tool_parser_obj:
+            tool_parser = self.tool_parser_obj(self.tokenizer)
+            tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+            if tool_call_info.tools_called:
+                response_dict.outputs.tool_calls = tool_call_info.tool_calls
+                response_dict.outputs.text = tool_call_info.content
         data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
         if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
             return None
@@ -244,12 +252,20 @@ class ErnieProcessor(BaseDataProcessor):
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
         if is_end:
             full_text = previous_texts + delta_text
-            if enable_thinking and self.reasoning_parser:
+            if self.reasoning_parser and (
+                enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
+            ):
                 reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
                 response_dict["outputs"]["text"] = text
                 response_dict["outputs"]["reasoning_content"] = reasoning_content
             else:
                 response_dict["outputs"]["text"] = full_text
+            if self.tool_parser_obj:
+                tool_parser = self.tool_parser_obj(self.tokenizer)
+                tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+                if tool_call_info.tools_called:
+                    response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
+                    response_dict["outputs"]["text"] = tool_call_info.content
             response_dict["outputs"]["raw_prediction"] = full_text
             data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
             del self.decode_status[req_id]
@@ -275,7 +291,9 @@ class ErnieProcessor(BaseDataProcessor):
             token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
         response_dict["outputs"]["raw_prediction"] = delta_text
-        if enable_thinking and self.reasoning_parser:
+        if self.reasoning_parser and (
+            enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
+        ):
             reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
                 previous_texts,
                 previous_texts + delta_text,
@@ -288,10 +306,25 @@ class ErnieProcessor(BaseDataProcessor):
             response_dict["outputs"]["reasoning_content"] = reasoning_content
         else:
             response_dict["outputs"]["text"] = delta_text
-            response_dict["outputs"]["raw_prediction"] = delta_text
+        if self.tool_parser_obj:
+            if req_id not in self.tool_parsers:
+                self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
+            tool_parser = self.tool_parsers[req_id]
+            tool_call = tool_parser.extract_tool_calls_streaming(
+                previous_texts,
+                previous_texts + delta_text,
+                delta_text,
+                previous_token_ids,
+                previous_token_ids + token_ids,
+                token_ids,
+                response_dict,
+            )
+            response_dict["outputs"]["tool_delta_message"] = tool_call
         if is_end:
             data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
             del self.decode_status[req_id]
+            if req_id in self.tool_parsers:
+                del self.tool_parsers[req_id]
         return response_dict
 
     def messages2ids(self, request_or_messages):
diff --git a/fastdeploy/input/ernie_vl_processor.py b/fastdeploy/input/ernie_vl_processor.py
index d2975c697..21a96e92b 100644
--- a/fastdeploy/input/ernie_vl_processor.py
+++ b/fastdeploy/input/ernie_vl_processor.py
@@ -34,6 +34,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
         limit_mm_per_prompt=None,
         mm_processor_kwargs=None,
         reasoning_parser_obj=None,
+        tool_parser_obj=None,
     ):
         self.use_hf_tokenizer = False
 
@@ -53,6 +54,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
         self.image_patch_id = self.ernie_processor.image_patch_id
         self.spatial_conv_size = self.ernie_processor.spatial_conv_size
 
+        self.tool_parsers = dict()
         self.decode_status = dict()
         self._load_tokenizer()
         self.eos_token_ids = [self.tokenizer.eos_token_id]
@@ -62,6 +64,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
         self.reasoning_parser = None
         if reasoning_parser_obj:
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
+        self.tool_parser_obj = tool_parser_obj
 
         # Generation config
         try:
diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py
index 8edd4eb4b..5c1e2e802 100644
--- a/fastdeploy/input/preprocess.py
+++ b/fastdeploy/input/preprocess.py
@@ -18,6 +18,7 @@ from typing import Any, Dict, Optional
 
 from fastdeploy.config import ErnieArchitectures
 from fastdeploy.engine.config import ModelConfig
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
 from fastdeploy.reasoning import ReasoningParserManager
 
 
@@ -48,6 +49,7 @@ class InputPreprocessor:
         limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         enable_mm: bool = False,
+        tool_parser: str = None,
     ) -> None:
 
         self.model_name_or_path = model_name_or_path
@@ -55,6 +57,7 @@ class InputPreprocessor:
         self.enable_mm = enable_mm
         self.limit_mm_per_prompt = limit_mm_per_prompt
         self.mm_processor_kwargs = mm_processor_kwargs
+        self.tool_parser = tool_parser
 
     def create_processor(self):
         """
@@ -68,8 +71,11 @@ class InputPreprocessor:
             DataProcessor or MultiModalRegistry.Processor (Union[DataProcessor, MultiModalRegistry.Processor]): 数据处理器。
         """
         reasoning_parser_obj = None
+        tool_parser_obj = None
         if self.reasoning_parser:
             reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
+        if self.tool_parser:
+            tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser)
         architectures = ModelConfig({"model": self.model_name_or_path}).architectures[0]
         if not self.enable_mm:
             if not ErnieArchitectures.contains_ernie_arch(architectures):
@@ -78,6 +84,7 @@ class InputPreprocessor:
                 self.processor = DataProcessor(
                     model_name_or_path=self.model_name_or_path,
                     reasoning_parser_obj=reasoning_parser_obj,
+                    tool_parser_obj=tool_parser_obj,
                 )
             else:
                 from fastdeploy.input.ernie_processor import ErnieProcessor
@@ -85,6 +92,7 @@ class InputPreprocessor:
                 self.processor = ErnieProcessor(
                     model_name_or_path=self.model_name_or_path,
                     reasoning_parser_obj=reasoning_parser_obj,
+                    tool_parser_obj=tool_parser_obj,
                 )
         else:
             if not architectures.startswith("Ernie4_5_VLMoeForConditionalGeneration"):
@@ -97,5 +105,6 @@ class InputPreprocessor:
                 limit_mm_per_prompt=self.limit_mm_per_prompt,
                 mm_processor_kwargs=self.mm_processor_kwargs,
                 reasoning_parser_obj=reasoning_parser_obj,
+                tool_parser_obj=tool_parser_obj,
             )
         return self.processor
diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py
index eec346341..4bffee280 100644
--- a/fastdeploy/input/text_processor.py
+++ b/fastdeploy/input/text_processor.py
@@ -148,7 +148,7 @@ class BaseDataProcessor(ABC):
 
 
 class DataProcessor(BaseDataProcessor):
-    def __init__(self, model_name_or_path, reasoning_parser_obj=None):
+    def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
         """
         Initializes the DecodeStatus object.
 
@@ -168,6 +168,7 @@ class DataProcessor(BaseDataProcessor):
         self._init_config()
 
         self.decode_status = dict()
+        self.tool_parsers = dict()
         self.tokenizer = self._load_tokenizer()
         data_processor_logger.info(
             f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
@@ -180,6 +181,7 @@ class DataProcessor(BaseDataProcessor):
         self.eos_token_id_len = len(self.eos_token_ids)
         self.pad_token_id = self.get_pad_id()
         self.reasoning_parser = None
+        self.tool_parser_obj = tool_parser_obj
         if reasoning_parser_obj:
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
         self.tokenizer.pad_token_id = self.pad_token_id
@@ -329,6 +331,12 @@ class DataProcessor(BaseDataProcessor):
             else:
                 # 模型不支持思考,并且没单独设置enable_thinking为false
                 response_dict.outputs.text = full_text
+        if self.tool_parser_obj:
+            tool_parser = self.tool_parser_obj(self.tokenizer)
+            tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+            if tool_call_info.tools_called:
+                response_dict.outputs.tool_calls = tool_call_info.tool_calls
+                response_dict.outputs.text = tool_call_info.content
         data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
 
         return response_dict
@@ -360,6 +368,12 @@ class DataProcessor(BaseDataProcessor):
                 response_dict["outputs"]["reasoning_content"] = reasoning_content
             else:
                 response_dict["outputs"]["text"] = full_text
+            if self.tool_parser_obj:
+                tool_parser = self.tool_parser_obj(self.tokenizer)
+                tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+                if tool_call_info.tools_called:
+                    response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
+                    response_dict["outputs"]["text"] = tool_call_info.content
             data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
             del self.decode_status[req_id]
             return response_dict
@@ -397,9 +411,25 @@ class DataProcessor(BaseDataProcessor):
                 response_dict["outputs"]["reasoning_content"] = reasoning_content
         else:
             response_dict["outputs"]["text"] = delta_text
+        if self.tool_parser_obj and not is_end:
+            if req_id not in self.tool_parsers:
+                self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
+            tool_parser = self.tool_parsers[req_id]
+            tool_call = tool_parser.extract_tool_calls_streaming(
+                previous_texts,
+                previous_texts + delta_text,
+                delta_text,
+                previous_token_ids,
+                previous_token_ids + token_ids,
+                token_ids,
+                response_dict,
+            )
+            response_dict["outputs"]["tool_delta_message"] = tool_call
         if is_end:
             data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
             del self.decode_status[req_id]
+            if req_id in self.tool_parsers:
+                del self.tool_parsers[req_id]
         return response_dict
 
     def process_response_dict(self, response_dict, **kwargs):
diff --git a/fastdeploy/reasoning/__init__.py b/fastdeploy/reasoning/__init__.py
index aa7d65e50..51f59776e 100644
--- a/fastdeploy/reasoning/__init__.py
+++ b/fastdeploy/reasoning/__init__.py
@@ -16,6 +16,7 @@
 
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser
+from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
 from .qwen3_reasoning_parsers import Qwen3ReasoningParser
 
 __all__ = [
@@ -23,4 +24,5 @@ __all__ = [
     "ReasoningParserManager",
     "ErnieVLReasoningParser",
    "Qwen3ReasoningParser",
+    "ErnieX1ReasoningParser",
 ]
diff --git a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
new file mode 100644
index 000000000..458505252
--- /dev/null
+++ b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Sequence
+from typing import Tuple
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
+from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
+
+
+@ReasoningParserManager.register_module("ernie_x1")
+class ErnieX1ReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for the ernie_x1 model with stricter boundary checking.
+
+    This implementation follows the proposed approach:
+    1. For thinking content: wait for \n, then check for the </think> tag
+    2. For response content: check for the <response> tag first, then wait for \n
+    3. Handle newlines in the content more precisely
+    """
+
+    def __init__(self, tokenizer):
+        super().__init__(tokenizer)
+        self.think_end_token = "</think>"
+        self.response_start_token = "<response>"
+        self.response_end_token = "</response>"
+        self.tool_call_start_token = "<tool_call>"
+        self.tool_call_end_token = "</tool_call>"
+
+        if not self.model_tokenizer:
+            raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")
+
+        self.think_end_token_id = self.vocab.get("</think>")
+        if self.think_end_token_id is None:
+            raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> tuple[str, str]:
+        """
+        Streaming parse method:
+        1. All initial content is treated as thinking content
+        2. On a \n, check whether </think> follows
+        3. After thinking ends, check whether a <tool_call> or <response> follows
+        4. For <response> content, handle newlines and the closing tag
+        """
+        # Still in the thinking phase
+        if not previous_text.endswith(self.think_end_token):
+            # Thinking ends when </think> arrives, after a \n or directly
+            if (previous_text.endswith("\n") and delta_text == self.think_end_token) or (
+                not previous_text.endswith("\n") and delta_text == self.think_end_token
+            ):
+                return "", ""
+            # Otherwise keep returning thinking content
+            return delta_text, ""
+
+        # After thinking ends, check for <tool_call> vs <response>
+        remaining_text = previous_text + delta_text
+        after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
+
+        # Skip newlines after </think>
+        after_think = after_think.lstrip("\n")
+
+        # Handle the tool_call case
+        if after_think.startswith(self.tool_call_start_token):
+            return "", ""
+
+        # Handle the response case
+        if after_think.startswith(self.response_start_token):
+            response_content = after_think[len(self.response_start_token) :]
+            # Skip newlines after <response>
+            response_content = response_content.lstrip("\n")
+
+            # Check whether the response has ended
+            if response_content.endswith(self.response_end_token):
+                return "", ""
+
+            # Return response content (delta_text keeps the output streaming)
+            return "", delta_text
+
+        # By default, return nothing
+        return "", ""
+
+    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
+        """
+        Batch version of the enhanced parser.
+        Modified to preserve newlines in both reasoning and response content,
+        only removing the single newline before closing tags.
+ """ + reasoning_content = "" + response_content = "" + + think_end_pos = model_output.find(self.think_end_token) + if think_end_pos != -1: + # Extract thinking content - only remove the last newline before + reasoning_content = model_output[:think_end_pos] + if think_end_pos > 0 and reasoning_content[-1] == "\n": + reasoning_content = reasoning_content[:-1] + + remaining = model_output[think_end_pos + len(self.think_end_token) :] + + # Skip newlines after + remaining = remaining.lstrip("\n") + + # Check for response or tool_call + if remaining.startswith(self.response_start_token): + response_pos = len(self.response_start_token) + remaining = remaining[response_pos:].lstrip("\n") + response_end_pos = remaining.find(self.response_end_token) + if response_end_pos != -1: + # Only strip the last newline before , not all + if response_end_pos > 0 and remaining[response_end_pos - 1] == "\n": + response_content = remaining[: response_end_pos - 1] + else: + response_content = remaining[:response_end_pos] + else: + # If no found, return the rest as response content + response_content = remaining + elif remaining.startswith(self.tool_call_start_token): + pass # No response content + else: + # No thinking content found, return the whole input as reasoning + reasoning_content = model_output + response_content = "" + return reasoning_content, response_content + + +import unittest +from unittest.mock import MagicMock + + +class TestErnieX1ReasoningParser(unittest.TestCase): + def setUp(self): + self.tokenizer = MagicMock() + self.tokenizer.vocab = { + "\n\n\n": 1001, + "\n": 1002, + "\n\n": 1003, + "\n": 1004, + "\n\n": 1005, + } + self.parser = ErnieX1ReasoningParser(self.tokenizer) + + def test_streaming_with_think_and_response(self): + # 测试标准情况:\n\n\n\ncontent\n\n + prev_text = "thinking" + delta_text = "\n\n\n\nanswer\n\n" + result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], []) + self.assertEqual(result, ("thinking", "answer")) + + def test_streaming_with_think_and_tool_call(self): + # 测试tool_call情况 + prev_text = "thinking" + delta_text = "\n\n\n\ndetails\n\n" + result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], []) + self.assertEqual(result, ("thinking", "")) + + def test_streaming_with_think_no_newline(self): + # 测试没有前置换行的情况 + prev_text = "thinking" + delta_text = "\n\nanswer\n" + result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], []) + self.assertEqual(result, ("thinking", "answer")) + + def test_streaming_response_without_leading_newline(self): + # 测试response内容没有前置换行 + prev_text = "thinking\n\n\n" + delta_text = "answer\n\n" + result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], []) + self.assertEqual(result, ("thinking", "answer")) + + def test_streaming_response_with_middle_newline(self): + # 测试response内容中间的换行符 + prev_text = "thinking\n\n\n\n" + delta_text = "line1\nline2\n\n" + result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], []) + self.assertEqual(result, ("thinking", "line1\nline2")) + + def test_streaming_partial_response(self): + # 测试不完整的response流式输出 + prev_text = "thinking\n\n\n\n" + delta_text = "partial answer" + result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], []) + self.assertEqual(result, ("thinking", "partial answer")) + + +if __name__ == "__main__": + unittest.main() diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py 
index 5d68c7681..70e5df129 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -23,6 +23,7 @@ import os
 import random
 import re
 import socket
+import sys
 import tarfile
 import time
 from datetime import datetime
@@ -591,6 +592,22 @@ def is_list_of(
     assert_never(check)
 
 
+def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
+    """
+    Import a Python file according to its file path.
+    """
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    if spec is None:
+        raise ModuleNotFoundError(f"No module named '{module_name}'")
+
+    assert spec.loader is not None
+
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module)
+    return module
+
+
 def version():
     """
     Prints the contents of the version.txt file located in the parent directory of this script.
diff --git a/requirements.txt b/requirements.txt
index 55489db3a..0e0d5ca6f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -37,3 +37,4 @@ opentelemetry-instrumentation-mysql
 opentelemetry-distro
 opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
+partial_json_parser
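
A minimal sketch of the plugin mechanism this diff enables, for orientation only. It assumes a hypothetical file `my_tool_parser.py` - the file name, the `my_json` parser id, and the single-JSON-object output format are illustrative and not part of the diff. The plugin would be loaded with `--tool-parser-plugin ./my_tool_parser.py` and selected with `--tool-call-parser my_json`:

```python
# my_tool_parser.py - hypothetical plugin for --tool-parser-plugin (sketch only).
# Assumes the model emits one bare JSON object like {"name": ..., "arguments": {...}}.
import json

from fastdeploy.entrypoints.openai.protocol import (
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from fastdeploy.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager


@ToolParserManager.register_module("my_json")
class MyJsonToolParser(ToolParser):
    """Treats the whole model output as a single JSON tool call (illustrative only)."""

    def extract_tool_calls(self, model_output, request):
        try:
            data = json.loads(model_output)
            call = ToolCall(
                function=FunctionCall(
                    name=data["name"],
                    arguments=json.dumps(data["arguments"], ensure_ascii=False),
                )
            )
            return ExtractedToolCallInformation(tools_called=True, tool_calls=[call], content="")
        except (json.JSONDecodeError, KeyError, TypeError):
            # Not a tool call: pass the text through unchanged.
            return ExtractedToolCallInformation(tools_called=False, content=model_output)
```

`import_tool_parser` imports this file by path, and the `register_module` decorator adds the class to `ToolParserManager.tool_parsers`, so `InputPreprocessor` can later look it up by name.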
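And a hedged sketch of how the two new engine arguments would be used together from the offline `LLM` entry point; the model paths are placeholders, and it assumes `tool_call_parser` / `tool_parser_plugin` flow through `**kwargs` into `EngineArgs` as wired up in `llm.py` above:

```python
from fastdeploy.entrypoints.llm import LLM

# Built-in parser registered by this diff:
llm = LLM(
    model="/path/to/ERNIE-X1",  # placeholder model path
    reasoning_parser="ernie_x1",
    tool_call_parser="ernie_x1",
)

# User-defined parser from a plugin file (hypothetical names from the sketch above):
llm_custom = LLM(
    model="/path/to/model",  # placeholder model path
    tool_parser_plugin="./my_tool_parser.py",  # imported before EngineArgs is built
    tool_call_parser="my_json",
)
```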