diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index c254aaa1a..834c8096b 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -15,10 +15,10 @@
"""
import json
+import os
from dataclasses import asdict, dataclass
from dataclasses import fields as dataclass_fields
from typing import Any, Dict, List, Optional
-import os
from fastdeploy.config import (
CacheConfig,
@@ -93,6 +93,14 @@ class EngineArgs:
"""
specifies the reasoning parser to use for extracting reasoning content from the model output
"""
+ tool_call_parser: str = None
+ """
specifies the tool call parser to use for extracting tool calls from the model output
+ """
+ tool_parser_plugin: str = None
+ """
tool parser plugin used to register user-defined tool parsers
+ """
enable_mm: bool = False
"""
Flags to enable multi-modal model
@@ -421,6 +429,18 @@ class EngineArgs:
help="Flag specifies the reasoning parser to use for extracting "
"reasoning content from the model output",
)
+ model_group.add_argument(
+ "--tool-call-parser",
+ type=str,
+ default=EngineArgs.tool_call_parser,
+ help="Flag specifies the tool call parser to use for extracting" "tool call from the model output",
+ )
+ model_group.add_argument(
+ "--tool-parser-plugin",
+ type=str,
+ default=EngineArgs.tool_parser_plugin,
+ help="tool parser plugin used to register user defined tool parsers",
+ )
model_group.add_argument(
"--speculative-config",
type=json.loads,
@@ -866,10 +886,10 @@ class EngineArgs:
if self.enable_chunked_prefill:
self.max_num_batched_tokens = 2048
else:
- if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+ if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
self.max_num_batched_tokens = self.max_model_len
else:
- self.max_num_batched_tokens = 8192
+ self.max_num_batched_tokens = 8192  # setting this to max_model_len can easily cause OOM
all_dict = asdict(self)
all_dict["model_cfg"] = model_cfg
@@ -908,6 +928,7 @@ class EngineArgs:
mm_processor_kwargs=self.mm_processor_kwargs,
enable_mm=self.enable_mm,
reasoning_parser=self.reasoning_parser,
+ tool_parser=self.tool_call_parser,
splitwise_role=self.splitwise_role,
innode_prefill_ports=self.innode_prefill_ports,
max_num_partial_prefills=self.max_num_partial_prefills,
diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py
index fb57884bf..31f7c5c70 100644
--- a/fastdeploy/engine/config.py
+++ b/fastdeploy/engine/config.py
@@ -85,6 +85,7 @@ class Config:
max_long_partial_prefills: int = 1,
long_prefill_token_threshold: int = 0,
reasoning_parser: str = None,
+ tool_parser: str = None,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
enable_logprob: bool = False,
@@ -165,6 +166,7 @@ class Config:
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold
self.reasoning_parser = reasoning_parser
+ self.tool_parser = tool_parser
self.graph_optimization_config = graph_optimization_config
self.early_stop_config = early_stop_config
self.guided_decoding_backend = guided_decoding_backend
@@ -236,10 +238,10 @@ class Config:
if self.cache_config.enable_chunked_prefill:
self.max_num_batched_tokens = 2048
else:
- if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+ if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
self.max_num_batched_tokens = self.max_model_len
else:
- self.max_num_batched_tokens = 8192
+ self.max_num_batched_tokens = 8192  # setting this to max_model_len can easily cause OOM
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -287,7 +289,7 @@ class Config:
)
if not self.cache_config.enable_chunked_prefill:
- if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+ if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
assert self.max_num_batched_tokens >= self.max_model_len, (
f"max_num_batched_tokens: {self.max_num_batched_tokens} "
f"should be larger than or equal to max_model_len: {self.max_model_len}"
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index 88067ed06..5ac7e6ab6 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -106,6 +106,7 @@ class LLMEngine:
cfg.limit_mm_per_prompt,
cfg.mm_processor_kwargs,
cfg.enable_mm,
+ cfg.tool_parser,
)
self.start_queue_service()
diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index acf717547..b9fa895e6 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -24,6 +24,7 @@ from typing import Any, Dict, Optional, Union
import numpy as np
from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.openai.protocol import ToolCall
from fastdeploy.utils import data_processor_logger
from fastdeploy.worker.output import LogprobsLists, SampleLogprobs
@@ -249,6 +250,7 @@ class CompletionOutput:
draft_token_ids: list[int] = None
text: Optional[str] = None
reasoning_content: Optional[str] = None
+ tool_calls: Optional[list[ToolCall]] = None
def to_dict(self):
"""
diff --git a/fastdeploy/entrypoints/chat_utils.py b/fastdeploy/entrypoints/chat_utils.py
index 4f7357e11..059ecee01 100644
--- a/fastdeploy/entrypoints/chat_utils.py
+++ b/fastdeploy/entrypoints/chat_utils.py
@@ -14,6 +14,7 @@
# limitations under the License.
"""
+import uuid
from copy import deepcopy
from typing import List, Literal, Union
from urllib.parse import urlparse
@@ -156,3 +157,7 @@ def parse_chat_messages(messages):
conversation.append({"role": role, "content": parsed_content})
return conversation
+
+
+def random_tool_call_id() -> str:
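+ """Generate a unique tool call id, e.g. "chatcmpl-tool-" followed by 32 hex characters."""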
+ return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index 7c92390bc..a1c787162 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -45,6 +45,7 @@ class EngineClient:
data_parallel_size=1,
enable_logprob=False,
workers=1,
+ tool_parser=None,
):
input_processor = InputPreprocessor(
tokenizer,
@@ -52,6 +53,7 @@ class EngineClient:
limit_mm_per_prompt,
mm_processor_kwargs,
enable_mm,
+ tool_parser,
)
self.enable_logprob = enable_logprob
self.enable_mm = enable_mm
diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py
index 3e150abf2..8291c974a 100644
--- a/fastdeploy/entrypoints/llm.py
+++ b/fastdeploy/entrypoints/llm.py
@@ -28,8 +28,7 @@ from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.sampling_params import SamplingParams
-
-# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.utils import llm_logger, retrive_model_from_server
from fastdeploy.worker.output import Logprob, LogprobsLists
@@ -73,6 +72,9 @@ class LLM:
**kwargs,
):
model = retrive_model_from_server(model, revision)
+ tool_parser_plugin = kwargs.get("tool_parser_plugin")
+ if tool_parser_plugin:
+ ToolParserManager.import_tool_parser(tool_parser_plugin)
engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 6562bfac3..dd2c1fb69 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -41,6 +41,7 @@ from fastdeploy.entrypoints.openai.protocol import (
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.metrics.metrics import (
EXCLUDE_LABELS,
cleanup_prometheus_files,
@@ -73,7 +74,8 @@ parser.add_argument("--max-concurrency", default=512, type=int, help="max concur
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.model = retrive_model_from_server(args.model, args.revision)
-
+if args.tool_parser_plugin:
+ ToolParserManager.import_tool_parser(args.tool_parser_plugin)
llm_engine = None
@@ -126,6 +128,7 @@ async def lifespan(app: FastAPI):
args.data_parallel_size,
args.enable_logprob,
args.workers,
+ args.tool_call_parser,
)
app.state.dynamic_load_weight = args.dynamic_load_weight
chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 678ae8dd0..2049fb971 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -72,7 +72,6 @@ class ToolCall(BaseModel):
id: str = None
type: Literal["function"] = "function"
function: FunctionCall
- index: int
class DeltaFunctionCall(BaseModel):
@@ -96,6 +95,18 @@ class DeltaToolCall(BaseModel):
function: Optional[DeltaFunctionCall] = None
+class ExtractedToolCallInformation(BaseModel):
+ # indicate if tools were called
+ tools_called: bool
+
+ # extracted tool calls
+ tool_calls: Optional[list[ToolCall]] = None
+
+ # content: per the OpenAI spec, content and tool calls are rarely returned
+ # together, but some models do emit both intentionally
+ content: Optional[str] = None
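+
+ # Illustrative example: ExtractedToolCallInformation(tools_called=True,
+ #     tool_calls=[ToolCall(...)], content="")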
+
+
class FunctionDefinition(BaseModel):
"""
Function definition.
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index 829f39f3d..0ee0a3423 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -141,6 +141,7 @@ class OpenAIServingChat:
previous_num_tokens = 0
num_prompt_tokens = 0
num_choices = 1
+ tool_called = False
max_streaming_response_tokens = (
request.max_streaming_response_tokens
if request.max_streaming_response_tokens is not None
@@ -244,20 +245,28 @@ class OpenAIServingChat:
output = res["outputs"]
delta_text = output["text"]
output_top_logprobs = output["top_logprobs"]
+ previous_num_tokens += len(output["token_ids"])
logprobs_res: Optional[LogProbs] = None
if request.logprobs and output_top_logprobs is not None:
logprobs_res = self._create_chat_logprobs(
output_top_logprobs, request.logprobs, request.top_logprobs
)
-
- previous_num_tokens += len(output["token_ids"])
- delta_message = DeltaMessage(
- content=delta_text,
- reasoning_content=output.get("reasoning_content"),
- prompt_token_ids=None,
- completion_token_ids=None,
- tool_calls=output.get("tool_call_content", []),
- )
+ if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
+ tool_delta_message = output["tool_delta_message"]
+ if tool_delta_message is None:
+ continue
+ delta_message = tool_delta_message
+ delta_message.reasoning_content = output.get("reasoning_content")
+ if delta_message.tool_calls:
+ tool_called = True
+ else:
+ delta_message = DeltaMessage(
+ content=delta_text,
+ reasoning_content=output.get("reasoning_content"),
+ prompt_token_ids=None,
+ completion_token_ids=None,
+ tool_calls=None,
+ )
choice = ChatCompletionResponseStreamChoice(
index=0,
@@ -274,10 +283,7 @@ class OpenAIServingChat:
max_tokens = request.max_completion_tokens or request.max_tokens
if has_no_token_limit or previous_num_tokens != max_tokens:
choice.finish_reason = "stop"
- if (
- self.engine_client.reasoning_parser == "ernie_x1"
- and output.get("finish_reason", "") == "tool_calls"
- ):
+ if tool_called:
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
@@ -414,7 +420,7 @@ class OpenAIServingChat:
role="assistant",
content=output["text"],
reasoning_content=output.get("reasoning_content"),
- tool_calls=output.get("tool_call_content"),
+ tool_calls=output.get("tool_call"),
prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
completion_token_ids=completion_token_ids if request.return_token_ids else None,
text_after_process=text_after_process if request.return_token_ids else None,
@@ -434,7 +440,7 @@ class OpenAIServingChat:
max_tokens = request.max_completion_tokens or request.max_tokens
if has_no_token_limit or previous_num_tokens != max_tokens:
choice.finish_reason = "stop"
- if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
+ if output.get("tool_call"):
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index b68407942..bec869699 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -240,9 +240,9 @@ class OpenAIServingCompletion:
dealer.close()
self.engine_client.semaphore.release()
- def calc_finish_reason(self, max_tokens, token_num, output):
+ def calc_finish_reason(self, max_tokens, token_num, output, tool_called):
if max_tokens is None or token_num != max_tokens:
- if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
+ if tool_called or output.get("tool_call"):
return "tool_calls"
else:
return "stop"
@@ -271,6 +271,7 @@ class OpenAIServingCompletion:
output_tokens = [0] * num_choices
inference_start_time = [0] * num_choices
first_iteration = [True] * num_choices
+ tool_called = False
max_streaming_response_tokens = (
request.max_streaming_response_tokens
if request.max_streaming_response_tokens is not None
@@ -339,24 +340,41 @@ class OpenAIServingCompletion:
if request.logprobs and output_top_logprobs is not None:
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
- choices.append(
- CompletionResponseStreamChoice(
+ output_tokens[idx] += 1
+ if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
+ tool_delta_message = output["tool_delta_message"]
+ if tool_delta_message is None:
+ continue
+ delta_message = CompletionResponseStreamChoice(
index=idx,
text=output["text"],
- prompt_token_ids=None,
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
- raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
- tool_calls=output.get("tool_call_content"),
+ tool_calls=tool_delta_message.tool_calls,
reasoning_content=output.get("reasoning_content"),
arrival_time=arrival_time,
logprobs=logprobs_res,
)
- )
+ if tool_delta_message.tool_calls:
+ tool_called = True
+ else:
+ delta_message = CompletionResponseStreamChoice(
+ index=idx,
+ text=output["text"],
+ prompt_token_ids=None,
+ completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
+ tool_calls=None,
+ raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
+ reasoning_content=output.get("reasoning_content"),
+ arrival_time=arrival_time,
+ logprobs=logprobs_res,
+ )
+
+ choices.append(delta_message)
- output_tokens[idx] += 1
if res["finished"]:
choices[-1].finish_reason = self.calc_finish_reason(
- request.max_tokens, output_tokens[idx], output
+ request.max_tokens, output_tokens[idx], output, tool_called
)
if len(choices) == max_streaming_response_tokens or res["finished"]:
@@ -445,7 +463,7 @@ class OpenAIServingCompletion:
token_ids = output["token_ids"]
output_text = output["text"]
- finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output)
+ finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output, False)
choice_data = CompletionResponseChoice(
token_ids=token_ids,
@@ -456,7 +474,7 @@ class OpenAIServingCompletion:
raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
text_after_process=text_after_process_list[idx] if request.return_token_ids else None,
reasoning_content=output.get("reasoning_content"),
- tool_calls=output.get("tool_call_content"),
+ tool_calls=output.get("tool_call"),
logprobs=aggregated_logprobs,
finish_reason=finish_reason,
)
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/__init__.py b/fastdeploy/entrypoints/openai/tool_parsers/__init__.py
new file mode 100644
index 000000000..2078a8c9f
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/tool_parsers/__init__.py
@@ -0,0 +1,24 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from .abstract_tool_parser import ToolParser, ToolParserManager
+from .ernie_x1_tool_parser import ErnieX1ToolParser
+
+__all__ = [
+ "ToolParser",
+ "ToolParserManager",
+ "ErnieX1ToolParser",
+]
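+
+# Note (illustrative): a parser registered here, or via a user plugin imported with
+# ToolParserManager.import_tool_parser, is looked up by name through
+# ToolParserManager.get_tool_parser(...), which is what --tool-call-parser selects.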
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py
new file mode 100644
index 000000000..d6ac8f81a
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/tool_parsers/abstract_tool_parser.py
@@ -0,0 +1,159 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import os
+from collections.abc import Sequence
+from functools import cached_property
+from typing import Callable, Optional, Union
+
+from fastdeploy.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ DeltaMessage,
+ ExtractedToolCallInformation,
+)
+from fastdeploy.utils import data_processor_logger, import_from_path, is_list_of
+
+
+class ToolParser:
+ """
+ Abstract ToolParser class that should not be used directly. Provided
+ properties and methods should be used in
+ derived classes.
+ """
+
+ def __init__(self, tokenizer):
+ self.prev_tool_call_arr: list[dict] = []
+ # the index of the tool call that is currently being parsed
+ self.current_tool_id: int = -1
+ self.current_tool_name_sent: bool = False
+ self.streamed_args_for_tool: list[str] = []
+
+ self.model_tokenizer = tokenizer
+
+ @cached_property
+ def vocab(self) -> dict[str, int]:
+ # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
+ # whereas all tokenizers have .get_vocab()
+ return self.model_tokenizer.get_vocab()
+
+ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+ """
+ Method used to adjust the request parameters before processing.
+ """
+ return request
+
+ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+ """
+ Static method that should be implemented for extracting tool calls from
+ a complete model-generated string.
+ Used for non-streaming responses where we have the entire model response
+ available before sending to the client.
+ Static because it's stateless.
+ """
+ raise NotImplementedError("AbstractToolParser.extract_tool_calls has not been implemented!")
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+ """
+ Instance method that should be implemented for extracting tool calls
+ from an incomplete response; for use when handling tool calls and
+ streaming. Has to be an instance method because it requires state -
+ the current tokens/diffs, but also the information about what has
+ previously been parsed and extracted (see constructor)
+ """
+ raise NotImplementedError("AbstractToolParser.extract_tool_calls_streaming has not been " "implemented!")
+
+
+class ToolParserManager:
+ tool_parsers: dict[str, type] = {}
+
+ @classmethod
+ def get_tool_parser(cls, name) -> type:
+ """
+ Get tool parser by name which is registered by `register_module`.
+
+ Raise a KeyError exception if the name is not registered.
+ """
+ if name in cls.tool_parsers:
+ return cls.tool_parsers[name]
+
+ raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
+
+ @classmethod
+ def _register_module(
+ cls, module: type, module_name: Optional[Union[str, list[str]]] = None, force: bool = True
+ ) -> None:
+ if not issubclass(module, ToolParser):
+ raise TypeError(f"module must be subclass of ToolParser, but got {type(module)}")
+ if module_name is None:
+ module_name = module.__name__
+ if isinstance(module_name, str):
+ module_name = [module_name]
+ for name in module_name:
+ if not force and name in cls.tool_parsers:
+ existed_module = cls.tool_parsers[name]
+ raise KeyError(f"{name} is already registered " f"at {existed_module.__module__}")
+ cls.tool_parsers[name] = module
+
+ @classmethod
+ def register_module(
+ cls, name: Optional[Union[str, list[str]]] = None, force: bool = True, module: Union[type, None] = None
+ ) -> Union[type, Callable]:
+ """
+ Register module with the given name or name list. It can be used as a
+ decorator (with module as None) or called as a normal function (with
+ module not None).
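+
+ Example (illustrative; assumes a user-defined MyToolParser subclass):
+
+ @ToolParserManager.register_module("my_parser")
+ class MyToolParser(ToolParser):
+ ...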
+ """
+ if not isinstance(force, bool):
+ raise TypeError(f"force must be a boolean, but got {type(force)}")
+
+ # raise the error ahead of time
+ if not (name is None or isinstance(name, str) or is_list_of(name, str)):
+ raise TypeError("name must be None, an instance of str, or a sequence of str, " f"but got {type(name)}")
+
+ # use it as a normal method: x.register_module(module=SomeClass)
+ if module is not None:
+ cls._register_module(module=module, module_name=name, force=force)
+ return module
+
+ # use it as a decorator: @x.register_module()
+ def _register(module):
+ cls._register_module(module=module, module_name=name, force=force)
+ return module
+
+ return _register
+
+ @classmethod
+ def import_tool_parser(cls, plugin_path: str) -> None:
+ """
+ Import a user-defined tool parser from the Python file that defines
+ and registers it.
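+
+ Example (illustrative): a file passed via --tool-parser-plugin, e.g.
+ /path/to/my_parser.py, is imported here so that any classes it registers
+ with @ToolParserManager.register_module become selectable through
+ --tool-call-parser.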
+ """
+ module_name = os.path.splitext(os.path.basename(plugin_path))[0]
+
+ try:
+ import_from_path(module_name, plugin_path)
+ except Exception:
+ data_processor_logger.exception("Failed to load module '%s' from %s.", module_name, plugin_path)
+ return
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
new file mode 100644
index 000000000..cec1f6840
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py
@@ -0,0 +1,320 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+import uuid
+from collections.abc import Sequence
+from typing import Union
+
+import partial_json_parser
+
+from fastdeploy.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ DeltaFunctionCall,
+ DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall,
+ ToolCall,
+)
+from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser,
+ ToolParserManager,
+)
+from fastdeploy.utils import data_processor_logger
+
+
+def random_tool_call_id() -> str:
+ """Generate a random tool call ID"""
+ return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
+
+
+@ToolParserManager.register_module("ernie_x1")
+class ErnieX1ToolParser(ToolParser):
+ """
+ Tool parser for Ernie model version 4.5.1.
+ This parser handles tool calls with newline formats.
+ """
+
+ def __init__(self, tokenizer):
+ super().__init__(tokenizer)
+
+ self.prev_tool_call_arr: list[dict] = []
+ self.current_tool_id: int = -1
+ self.current_tool_name_sent: bool = False
+ self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
+ self.buffer: str = "" # buffer for accumulating unprocessed streaming content
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ToolCallParser constructor during construction."
+ )
+
+ def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+ """
+ Extract the tool calls from a complete model response.
+ Supports XML-style formats with newlines:
+ - XML format: <tool_call>\n{"name": "...", "arguments": {...}}\n</tool_call>\n...
+
+ Handles boundary cases:
+ 1. Only name and partial arguments: {"name": "get_weather", "arguments": {"location": "北京"
+ 2. Only partial name: {"name": "get_we
+ 3. Only name and arguments field without content: {"name": "get_weather", "argume
+ """
+
+ try:
+ tool_calls = []
+
+ # Check for invalid <response> tags appearing before <tool_call>
+ if re.search(r"<response>[\s\S]*?</response>\s*(?=<tool_call>)", model_output):
+ data_processor_logger.error("Invalid format: <response> tags found before <tool_call>")
+ return ExtractedToolCallInformation(tools_called=False, content=model_output)
+
+ function_call_arr = []
+ remaining_text = model_output
+
+ while True:
+ # Find the next tool_call block
+ tool_call_pos = remaining_text.find("<tool_call>")
+ if tool_call_pos == -1:
+ break
+
+ # Extract the content after the <tool_call> start tag
+ tool_content_start = tool_call_pos + len("<tool_call>")
+ tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
+
+ tool_json = ""
+ if tool_content_end == -1:
+ # Handle an unclosed tool_call block (truncated output)
+ tool_json = remaining_text[tool_content_start:].strip()
+ remaining_text = "" # nothing more to process
+ else:
+ # Handle a complete tool_call block
+ tool_json = remaining_text[tool_content_start:tool_content_end].strip()
+ remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
+
+ if not tool_json:
+ continue
+
+ # Normalize the JSON content
+ tool_json = tool_json.strip()
+ if not tool_json.startswith("{"):
+ tool_json = "{" + tool_json
+ if not tool_json.endswith("}"):
+ tool_json = tool_json + "}"
+
+ try:
+ # First try standard JSON parsing
+ try:
+ tool_data = json.loads(tool_json)
+
+ if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
+ function_call_arr.append(
+ {
+ "name": tool_data["name"],
+ "arguments": tool_data["arguments"],
+ "_is_complete": True, # 明确标记为完整解析
+ }
+ )
+ continue
+ except json.JSONDecodeError:
+ pass
+
+ # Fall back to partial_json_parser when standard parsing fails
+ from partial_json_parser.core.options import Allow
+
+ try:
+ tool_data = {}
+ flags = Allow.ALL & ~Allow.STR
+
+ # Parse the name field
+ name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
+ if name_match:
+ tool_data["name"] = name_match.group(1)
+
+ # Parse the arguments field
+ args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
+ if args_match:
+ try:
+ tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
+ except Exception:
+ tool_data["arguments"] = None
+
+ if isinstance(tool_data, dict):
+ function_call_arr.append(
+ {
+ "name": tool_data.get("name", ""),
+ "arguments": tool_data.get("arguments", {}),
+ "_is_partial": True, # 标记为部分解析
+ }
+ )
+ except Exception as e:
+ data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
+ continue
+ except Exception as e:
+ data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
+ continue
+
+ if not function_call_arr:
+ data_processor_logger.error("No valid tool calls found")
+ return ExtractedToolCallInformation(tools_called=False, content=model_output)
+
+ tool_calls = []
+ all_complete = True # starts as True; becomes False if any tool call is incomplete
+
+ for tool_call in function_call_arr:
+ # Record the parsing status of this tool call
+ is_complete = tool_call.get("_is_complete", False)
+ is_partial = tool_call.get("_is_partial", False)
+
+ # A single incomplete call makes the whole result incomplete
+ if not is_complete or is_partial:
+ all_complete = False
+
+ # Serialize the arguments
+ tool_args = tool_call.get("arguments", {})
+ if not isinstance(tool_args, dict):
+ tool_args = {}
+
+ try:
+ args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
+ except Exception:
+ args_str = "{}"
+
+ tool_calls.append(
+ ToolCall(
+ type="function",
+ id=random_tool_call_id(),
+ function=FunctionCall(
+ name=tool_call.get("name", ""),
+ arguments=args_str,
+ ),
+ )
+ )
+
+ # Only return tools_called=True when every tool call is explicitly marked complete
+ return ExtractedToolCallInformation(
+ tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
+ )
+
+ except Exception as e:
+ data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
+ return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: dict,
+ ) -> Union[DeltaMessage, None]:
+ # Ignore empty chunks
+ if len(delta_text.strip()) == 0:
+ return None
+
+ try:
+ delta = None
+ # Accumulate delta_text into the buffer
+ self.buffer += delta_text
+
+ # Detect a new tool_call starting in this delta
+ if "<tool_call>" in delta_text and "<tool_call>" not in previous_text:
+ self.current_tool_id = (
+ max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
+ )
+ self.current_tool_name_sent = False
+ if len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+ data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
+
+ # Incremental parsing logic
+
+ # 1. Try to parse the name field
+ if not self.current_tool_name_sent and '"name"' in self.buffer:
+ name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
+ if name_match:
+ name = name_match.group(1)
+ if name:
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ type="function",
+ id=random_tool_call_id(),
+ function=DeltaFunctionCall(name=name).model_dump(exclude_none=True),
+ )
+ ]
+ )
+ print("delta name:", delta)
+ # 删除已处理的name部分
+ self.buffer = self.buffer[name_match.end() :]
+ self.current_tool_name_sent = True
+ return delta
+ # 2. Try to parse the arguments field
+ if '"arguments"' in self.buffer:
+ args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
+ if args_match:
+ args_content = args_match.group(1)
+ # Handle extra closing braces
+ open_braces = args_content.count("{")
+ close_braces = args_content.count("}")
+ if close_braces > open_braces:
+ args_content = args_content[: args_content.rfind("}")]
+ try:
+ # Incrementally parse the arguments
+ parsed_args = json.loads(args_content)
+ if isinstance(parsed_args, dict):
+ args_json = json.dumps(parsed_args, ensure_ascii=False)
+ if len(args_json) > len(self.streamed_args_for_tool[self.current_tool_id]):
+ argument_diff = args_json[len(self.streamed_args_for_tool[self.current_tool_id]) :]
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(arguments=argument_diff).model_dump(
+ exclude_none=True
+ ),
+ )
+ ]
+ )
+ print("delta argument:", delta)
+ # 删除已处理部分
+ processed_pos = args_match.start() + len('"arguments":')
+ self.buffer = (
+ self.buffer[:processed_pos] + self.buffer[processed_pos + len(args_json) :]
+ )
+ self.streamed_args_for_tool[self.current_tool_id] = args_json
+ return delta
+ except Exception as e:
+ data_processor_logger.debug(f"Partial arguments parsing: {str(e)}")
+
+ if "" in self.buffer:
+ end_pos = self.buffer.find("")
+ self.buffer = self.buffer[end_pos + len("") :]
+
+ # 完成当前工具调用处理
+ self.current_tool_id += 1
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool.append("")
+
+ return delta
+
+ except Exception as e:
+ data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
+ return None
diff --git a/fastdeploy/entrypoints/openai/tool_parsers/utils.py b/fastdeploy/entrypoints/openai/tool_parsers/utils.py
new file mode 100644
index 000000000..b7dff3c58
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/tool_parsers/utils.py
@@ -0,0 +1,137 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import json
+from json import JSONDecodeError, JSONDecoder
+from typing import Any
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+
+def find_common_prefix(s1: str, s2: str) -> str:
+ """
+ Finds a common prefix that is shared between two strings, if there is one.
+ Order of arguments is NOT important.
+
+ This function is provided as a UTILITY for extracting information from JSON
+ generated by partial_json_parser, to help in ensuring that the right tokens
+ are returned in streaming, so that close-quotes, close-brackets and
+ close-braces are not returned prematurely.
+
+ e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
+ '{"fruit": "ap'
+ """
+ prefix = ""
+ min_length = min(len(s1), len(s2))
+ for i in range(0, min_length):
+ if s1[i] == s2[i]:
+ prefix += s1[i]
+ else:
+ break
+ return prefix
+
+
+def find_common_suffix(s1: str, s2: str) -> str:
+ """
+ Finds a common suffix shared between two strings, if there is one. Order of
+ arguments is NOT important.
+ Stops when the suffix ends OR it hits an alphanumeric character
+
+ e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
+ """
+ suffix = ""
+ min_length = min(len(s1), len(s2))
+ for i in range(1, min_length + 1):
+ if s1[-i] == s2[-i] and not s1[-i].isalnum():
+ suffix = s1[-i] + suffix
+ else:
+ break
+ return suffix
+
+
+def extract_intermediate_diff(curr: str, old: str) -> str:
+ """
+ Given two strings, extract the difference in the middle between two strings
+ that are known to have a common prefix and/or suffix.
+
+ This function is provided as a UTILITY for extracting information from JSON
+ generated by partial_json_parser, to help in ensuring that the right tokens
+ are returned in streaming, so that close-quotes, close-brackets and
+ close-braces are not returned prematurely. The order of arguments IS
+ important - the new version of the partially-parsed JSON must be the first
+ argument, and the second argument must be from the previous generation.
+
+ What it returns, is tokens that should be streamed to the client.
+
+ e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
+ -> 'ple'
+
+ """
+ suffix = find_common_suffix(curr, old)
+
+ old = old[::-1].replace(suffix[::-1], "", 1)[::-1]
+ prefix = find_common_prefix(curr, old)
+ diff = curr
+ if len(suffix):
+ diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1]
+
+ if len(prefix):
+ # replace the prefix only once in case it's mirrored
+ diff = diff.replace(prefix, "", 1)
+
+ return diff
+
+
+def find_all_indices(string: str, substring: str) -> list[int]:
+ """
+ Find all (starting) indices of a substring in a given string. Useful for
+ tool call extraction
+ """
+ indices = []
+ index = -1
+ while True:
+ index = string.find(substring, index + 1)
+ if index == -1:
+ break
+ indices.append(index)
+ return indices
+
+
+# partial_json_parser doesn't support extra data and
+# JSONDecoder.raw_decode doesn't support partial JSON
+def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
+ try:
+ return (partial_json_parser.loads(input_str, flags), len(input_str))
+ except JSONDecodeError as e:
+ if "Extra data" in e.msg:
+ dec = JSONDecoder()
+ return dec.raw_decode(input_str)
+ raise
+
+
+def is_complete_json(input_str: str) -> bool:
+ try:
+ json.loads(input_str)
+ return True
+ except JSONDecodeError:
+ return False
+
+
+def consume_space(i: int, s: str) -> int:
+ while i < len(s) and s[i].isspace():
+ i += 1
+ return i
diff --git a/fastdeploy/input/ernie_processor.py b/fastdeploy/input/ernie_processor.py
index a268ad562..7f55c26ce 100644
--- a/fastdeploy/input/ernie_processor.py
+++ b/fastdeploy/input/ernie_processor.py
@@ -43,7 +43,7 @@ class ErnieProcessor(BaseDataProcessor):
pad_token_id (int): 存储填充符号的token ID。
"""
- def __init__(self, model_name_or_path, reasoning_parser_obj=None):
+ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
self.model_name_or_path = model_name_or_path
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
@@ -63,6 +63,7 @@ class ErnieProcessor(BaseDataProcessor):
self.reasoning_parser = None
if reasoning_parser_obj:
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
+ self.tool_parser_obj = tool_parser_obj
+ self.tool_parsers = dict()
def _init_config(self):
self.use_hf_tokenizer = int(envs.FD_USE_HF_TOKENIZER) == 1
@@ -204,6 +205,12 @@ class ErnieProcessor(BaseDataProcessor):
response_dict.outputs.reasoning_content = reasoning_content
else:
response_dict.outputs.text = full_text
+ if self.tool_parser_obj:
+ tool_parser = self.tool_parser_obj(self.tokenizer)
+ tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+ if tool_call_info.tools_called:
+ response_dict.outputs.tool_calls = tool_call_info.tool_calls
+ response_dict.outputs.text = tool_call_info.content
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
return None
@@ -244,12 +251,20 @@ class ErnieProcessor(BaseDataProcessor):
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
if is_end:
full_text = previous_texts + delta_text
- if enable_thinking and self.reasoning_parser:
+ if self.reasoning_parser and (
+ enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
+ ):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = full_text
+ if self.tool_parser_obj:
+ tool_parser = self.tool_parser_obj(self.tokenizer)
+ tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+ if tool_call_info.tools_called:
+ response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
+ response_dict["outputs"]["text"] = tool_call_info.content
response_dict["outputs"]["raw_prediction"] = full_text
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
@@ -274,7 +289,9 @@ class ErnieProcessor(BaseDataProcessor):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
- if enable_thinking and self.reasoning_parser:
+ if self.reasoning_parser and (
+ enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
+ ):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
@@ -287,10 +304,25 @@ class ErnieProcessor(BaseDataProcessor):
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = delta_text
- response_dict["outputs"]["raw_prediction"] = delta_text
+ if self.tool_parser_obj:
+ if req_id not in self.tool_parsers:
+ self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
+ tool_parser = self.tool_parsers[req_id]
+ tool_call = tool_parser.extract_tool_calls_streaming(
+ previous_texts,
+ previous_texts + delta_text,
+ delta_text,
+ previous_token_ids,
+ previous_token_ids + token_ids,
+ token_ids,
+ response_dict,
+ )
+ response_dict["outputs"]["tool_delta_message"] = tool_call
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
+ if req_id in self.tool_parsers:
+ del self.tool_parsers[req_id]
return response_dict
def messages2ids(self, request_or_messages):
diff --git a/fastdeploy/input/ernie_vl_processor.py b/fastdeploy/input/ernie_vl_processor.py
index d2975c697..21a96e92b 100644
--- a/fastdeploy/input/ernie_vl_processor.py
+++ b/fastdeploy/input/ernie_vl_processor.py
@@ -34,6 +34,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
limit_mm_per_prompt=None,
mm_processor_kwargs=None,
reasoning_parser_obj=None,
+ tool_parser_obj=None,
):
self.use_hf_tokenizer = False
@@ -53,6 +54,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
self.image_patch_id = self.ernie_processor.image_patch_id
self.spatial_conv_size = self.ernie_processor.spatial_conv_size
+ self.tool_parsers = dict()
self.decode_status = dict()
self._load_tokenizer()
self.eos_token_ids = [self.tokenizer.eos_token_id]
@@ -62,6 +64,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
self.reasoning_parser = None
if reasoning_parser_obj:
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
+ self.tool_parser_obj = tool_parser_obj
# Generation config
try:
diff --git a/fastdeploy/input/preprocess.py b/fastdeploy/input/preprocess.py
index 8edd4eb4b..5c1e2e802 100644
--- a/fastdeploy/input/preprocess.py
+++ b/fastdeploy/input/preprocess.py
@@ -18,6 +18,7 @@ from typing import Any, Dict, Optional
from fastdeploy.config import ErnieArchitectures
from fastdeploy.engine.config import ModelConfig
+from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.reasoning import ReasoningParserManager
@@ -48,6 +49,7 @@ class InputPreprocessor:
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
enable_mm: bool = False,
+ tool_parser: str = None,
) -> None:
self.model_name_or_path = model_name_or_path
@@ -55,6 +57,7 @@ class InputPreprocessor:
self.enable_mm = enable_mm
self.limit_mm_per_prompt = limit_mm_per_prompt
self.mm_processor_kwargs = mm_processor_kwargs
+ self.tool_parser = tool_parser
def create_processor(self):
"""
@@ -68,8 +71,11 @@ class InputPreprocessor:
DataProcessor or MultiModalRegistry.Processor (Union[DataProcessor, MultiModalRegistry.Processor]): 数据处理器。
"""
reasoning_parser_obj = None
+ tool_parser_obj = None
if self.reasoning_parser:
reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
+ if self.tool_parser:
+ tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser)
architectures = ModelConfig({"model": self.model_name_or_path}).architectures[0]
if not self.enable_mm:
if not ErnieArchitectures.contains_ernie_arch(architectures):
@@ -78,6 +84,7 @@ class InputPreprocessor:
self.processor = DataProcessor(
model_name_or_path=self.model_name_or_path,
reasoning_parser_obj=reasoning_parser_obj,
+ tool_parser_obj=tool_parser_obj,
)
else:
from fastdeploy.input.ernie_processor import ErnieProcessor
@@ -85,6 +92,7 @@ class InputPreprocessor:
self.processor = ErnieProcessor(
model_name_or_path=self.model_name_or_path,
reasoning_parser_obj=reasoning_parser_obj,
+ tool_parser_obj=tool_parser_obj,
)
else:
if not architectures.startswith("Ernie4_5_VLMoeForConditionalGeneration"):
@@ -97,5 +105,6 @@ class InputPreprocessor:
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
+ tool_parser_obj=tool_parser_obj,
)
return self.processor
diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py
index eec346341..4bffee280 100644
--- a/fastdeploy/input/text_processor.py
+++ b/fastdeploy/input/text_processor.py
@@ -148,7 +148,7 @@ class BaseDataProcessor(ABC):
class DataProcessor(BaseDataProcessor):
- def __init__(self, model_name_or_path, reasoning_parser_obj=None):
+ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
"""
Initializes the DecodeStatus object.
@@ -168,6 +168,7 @@ class DataProcessor(BaseDataProcessor):
self._init_config()
self.decode_status = dict()
+ self.tool_parsers = dict()
self.tokenizer = self._load_tokenizer()
data_processor_logger.info(
f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
@@ -180,6 +181,7 @@ class DataProcessor(BaseDataProcessor):
self.eos_token_id_len = len(self.eos_token_ids)
self.pad_token_id = self.get_pad_id()
self.reasoning_parser = None
+ self.tool_parser_obj = tool_parser_obj
if reasoning_parser_obj:
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tokenizer.pad_token_id = self.pad_token_id
@@ -329,6 +331,12 @@ class DataProcessor(BaseDataProcessor):
else:
# 模型不支持思考,并且没单独设置enable_thinking为false
response_dict.outputs.text = full_text
+ if self.tool_parser_obj:
+ tool_parser = self.tool_parser_obj(self.tokenizer)
+ tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+ if tool_call_info.tools_called:
+ response_dict.outputs.tool_calls = tool_call_info.tool_calls
+ response_dict.outputs.text = tool_call_info.content
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
return response_dict
@@ -360,6 +368,12 @@ class DataProcessor(BaseDataProcessor):
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = full_text
+ if self.tool_parser_obj:
+ tool_parser = self.tool_parser_obj(self.tokenizer)
+ tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+ if tool_call_info.tools_called:
+ response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
+ response_dict["outputs"]["text"] = tool_call_info.content
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
return response_dict
@@ -397,9 +411,25 @@ class DataProcessor(BaseDataProcessor):
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = delta_text
+ if self.tool_parser_obj and not is_end:
+ if req_id not in self.tool_parsers:
+ self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
+ tool_parser = self.tool_parsers[req_id]
+ tool_call = tool_parser.extract_tool_calls_streaming(
+ previous_texts,
+ previous_texts + delta_text,
+ delta_text,
+ previous_token_ids,
+ previous_token_ids + token_ids,
+ token_ids,
+ response_dict,
+ )
+ response_dict["outputs"]["tool_delta_message"] = tool_call
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
+ if req_id in self.tool_parsers:
+ del self.tool_parsers[req_id]
return response_dict
def process_response_dict(self, response_dict, **kwargs):
diff --git a/fastdeploy/reasoning/__init__.py b/fastdeploy/reasoning/__init__.py
index aa7d65e50..51f59776e 100644
--- a/fastdeploy/reasoning/__init__.py
+++ b/fastdeploy/reasoning/__init__.py
@@ -16,6 +16,7 @@
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser
+from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
from .qwen3_reasoning_parsers import Qwen3ReasoningParser
__all__ = [
@@ -23,4 +24,5 @@ __all__ = [
"ReasoningParserManager",
"ErnieVLReasoningParser",
"Qwen3ReasoningParser",
+ "ErnieX1ReasoningParser",
]
diff --git a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
new file mode 100644
index 000000000..458505252
--- /dev/null
+++ b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Sequence
+from typing import Tuple
+
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
+from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
+
+
+@ReasoningParserManager.register_module("ernie_x1")
+class ErnieX1ReasoningParser(ReasoningParser):
+ """
+ Reasoning parser for ernie_x1 model with stricter boundary checking.
+
+ The parsing approach:
+ 1. For thinking content: wait for \n, then check for the </think> tag
+ 2. For response content: check for the <response> tag first, then wait for \n
+ 3. Handle newlines in the content more precisely
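+
+ Illustrative output layout assumed by this parser:
+ reasoning...\n</think>\n\n<response>\nfinal answer\n</response>
+ or, when calling tools:
+ reasoning...\n</think>\n\n<tool_call>\n{...}\n</tool_call>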
+ """
+
+ def __init__(self, tokenizer):
+ super().__init__(tokenizer)
+ self.think_end_token = "</think>"
+ self.response_start_token = "<response>"
+ self.response_end_token = "</response>"
+ self.tool_call_start_token = "<tool_call>"
+ self.tool_call_end_token = "</tool_call>"
+
+ if not self.model_tokenizer:
+ raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")
+
+ self.think_end_token_id = self.vocab.get("</think>")
+ if self.think_end_token_id is None:
+ raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
+
+ def extract_reasoning_content_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ ) -> tuple[str, str]:
+ """
+ Streaming parsing, implemented as follows:
+ 1. All initial content is treated as thinking content
+ 2. On \n, check whether </think> follows
+ 3. After thinking ends, check whether <tool_call> or <response> follows
+ 4. For <response> content, handle newlines and the closing tag
+ """
+ # Still in the thinking phase
+ if not previous_text.endswith(self.think_end_token):
+ # Thinking ends when </think> follows a \n or appears directly
+ if (previous_text.endswith("\n") and delta_text == self.think_end_token) or (
+ not previous_text.endswith("\n") and delta_text == self.think_end_token
+ ):
+ return "", ""
+ # Otherwise keep returning thinking content
+ return delta_text, ""
+
+ # After thinking ends, check for tool_call vs response
+ remaining_text = previous_text + delta_text
+ after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
+
+ # Skip newlines after </think>
+ after_think = after_think.lstrip("\n")
+
+ # Handle the tool_call case
+ if after_think.startswith(self.tool_call_start_token):
+ return "", ""
+
+ # Handle the response case
+ if after_think.startswith(self.response_start_token):
+ response_content = after_think[len(self.response_start_token) :]
+ # Skip newlines after <response>
+ response_content = response_content.lstrip("\n")
+
+ # Check whether the response has ended
+ if response_content.endswith(self.response_end_token):
+ return "", ""
+
+ # Return the response content (use delta_text to keep the output streaming)
+ return "", delta_text
+
+ # Default: return no content
+ return "", ""
+
+ def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
+ """
+ Batch version of the enhanced parser.
+ Modified to preserve newlines in both reasoning and response content,
+ only removing the single newline before closing tags.
+ """
+ reasoning_content = ""
+ response_content = ""
+
+ think_end_pos = model_output.find(self.think_end_token)
+ if think_end_pos != -1:
+ # Extract thinking content - only remove the last newline before </think>
+ reasoning_content = model_output[:think_end_pos]
+ if think_end_pos > 0 and reasoning_content[-1] == "\n":
+ reasoning_content = reasoning_content[:-1]
+
+ remaining = model_output[think_end_pos + len(self.think_end_token) :]
+
+ # Skip newlines after </think>
+ remaining = remaining.lstrip("\n")
+
+ # Check for response or tool_call
+ if remaining.startswith(self.response_start_token):
+ response_pos = len(self.response_start_token)
+ remaining = remaining[response_pos:].lstrip("\n")
+ response_end_pos = remaining.find(self.response_end_token)
+ if response_end_pos != -1:
+ # Only strip the last newline before </response>, not all of them
+ if response_end_pos > 0 and remaining[response_end_pos - 1] == "\n":
+ response_content = remaining[: response_end_pos - 1]
+ else:
+ response_content = remaining[:response_end_pos]
+ else:
+ # If no </response> is found, return the rest as response content
+ response_content = remaining
+ elif remaining.startswith(self.tool_call_start_token):
+ pass # No response content
+ else:
+ # No thinking content found, return the whole input as reasoning
+ reasoning_content = model_output
+ response_content = ""
+ return reasoning_content, response_content
+
+
+import unittest
+from unittest.mock import MagicMock
+
+
+class TestErnieX1ReasoningParser(unittest.TestCase):
+ def setUp(self):
+ self.tokenizer = MagicMock()
+ self.tokenizer.vocab = {
+ "\n\n\n": 1001,
+ "\n": 1002,
+ "\n\n": 1003,
+ "\n": 1004,
+ "\n\n": 1005,
+ }
+ self.parser = ErnieX1ReasoningParser(self.tokenizer)
+
+ def test_streaming_with_think_and_response(self):
+ # Standard case: </think>\n\n<response>\n\ncontent\n</response>\n
+ prev_text = "thinking"
+ delta_text = "</think>\n\n<response>\n\nanswer\n</response>\n"
+ result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
+ self.assertEqual(result, ("thinking", "answer"))
+
+ def test_streaming_with_think_and_tool_call(self):
+ # Tool call case
+ prev_text = "thinking"
+ delta_text = "</think>\n\n<tool_call>\n\ndetails\n</tool_call>\n"
+ result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
+ self.assertEqual(result, ("thinking", ""))
+
+ def test_streaming_with_think_no_newline(self):
+ # Case without a leading newline
+ prev_text = "thinking"
+ delta_text = "</think>\n<response>\nanswer\n</response>"
+ result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
+ self.assertEqual(result, ("thinking", "answer"))
+
+ def test_streaming_response_without_leading_newline(self):
+ # Response content without a leading newline
+ prev_text = "thinking\n</think>\n\n<response>"
+ delta_text = "answer\n</response>\n"
+ result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
+ self.assertEqual(result, ("thinking", "answer"))
+
+ def test_streaming_response_with_middle_newline(self):
+ # Newlines in the middle of the response content
+ prev_text = "thinking\n</think>\n\n<response>\n"
+ delta_text = "line1\nline2\n</response>\n"
+ result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
+ self.assertEqual(result, ("thinking", "line1\nline2"))
+
+ def test_streaming_partial_response(self):
+ # Streaming an incomplete (partial) response
+ prev_text = "thinking\n</think>\n\n<response>\n"
+ delta_text = "partial answer"
+ result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
+ self.assertEqual(result, ("thinking", "partial answer"))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py
index 5d68c7681..70e5df129 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -23,6 +23,7 @@ import os
import random
import re
import socket
+import sys
import tarfile
import time
from datetime import datetime
@@ -591,6 +592,22 @@ def is_list_of(
assert_never(check)
+def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
+ """
+ Import a Python file according to its file path.
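+
+ Example (illustrative): import_from_path("my_parser", "/path/to/my_parser.py")
+ executes that file and registers it in sys.modules under "my_parser".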
+ """
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ if spec is None:
+ raise ModuleNotFoundError(f"No module named '{module_name}'")
+
+ assert spec.loader is not None
+
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module)
+ return module
+
+
def version():
"""
Prints the contents of the version.txt file located in the parent directory of this script.
diff --git a/requirements.txt b/requirements.txt
index 55489db3a..0e0d5ca6f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -37,3 +37,4 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
+partial_json_parser