Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 16:48:03 +08:00)
[Feature] add tool parser (#3518)
* [Feature] Pass through the `chat_template_kwargs` to the data processing module (#3421)
* fix chat_template_args
* fix args
* add offline
* add offline
* fix
* fix
* fix default enable_thinking value
* fix default enable_thinking value
* modify condition
* Revert "modify condition"
This reverts commit 26430bdeb1.
* fix unit test
* add Tool Parser (#3272)
* add tool-parser
* add tool-parser
* add tool parser
* add tool parser
* fix
* add offline
* add offline
* fix
* parsers:tool&reasoning
* rename the tool parser
* update
* fix reasoning-parser
* add requirements
* fix finish reason
* fix
* fix reasoning-parser
* fix
* fix
* fix
* fix
* fix
---------
Co-authored-by: zhuzixuan <zhuzixuan@baidu.com>
* [Feature] add tool parser (#3483)
* add tool parser
* add x1 enable_thinking
* restart ci
* fix vl reasoning parser
* modify call style
* modify call style
* add offline enablethinking
* fix completion
* fix
* fix unit test
* fix unit test
* fix unit test
* fix vl reasoning parser
* fix vl reasoning parser
* fix unit test
---------
Co-authored-by: zhuzixuan <zhuzixuan@baidu.com>
@@ -46,6 +46,8 @@ When using FastDeploy to deploy models (including offline inference and service
| ```dynamic_load_weight``` | `int` | Whether to enable dynamic weight loading, default: 0 |
| ```enable_expert_parallel``` | `bool` | Whether to enable expert parallelism |
| ```enable_logprob``` | `bool` | Whether to return log probabilities of the output tokens. If true, the log probability of each output token is returned in the message content. If logprobs are not used, this parameter can be omitted at startup. |
| ```tool_call_parser``` | `str` | Specifies the function call parser used to extract function call content from the model's output. |
| ```tool_parser_plugin``` | `str` | Specifies the file path of a tool parser to register, so that parsers not included in the code repository can be used. The code in such parsers must follow the format used in the repository. |
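Both tool-parser options can also be exercised from the offline Python entrypoint. The sketch below assumes the package-level `LLM` export used in the FastDeploy examples and that extra keyword arguments are forwarded to `EngineArgs`; the model path and plugin path are placeholders.

```python
from fastdeploy import LLM  # assumed package-level export

llm = LLM(
    model="/models/ERNIE-4.5-21B-A3B",           # placeholder checkpoint path
    tool_call_parser="ernie_x1",                 # built-in parser registered under this name
    tool_parser_plugin="/path/to/my_parser.py",  # optional: register an out-of-tree parser first
    reasoning_parser="ernie_x1",
)
```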

## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
@@ -44,6 +44,8 @@
| ```dynamic_load_weight``` | `int` | Whether to load weights dynamically; default 0 |
| ```enable_expert_parallel``` | `bool` | Whether to enable expert parallelism |
| ```enable_logprob``` | `bool` | Whether to return logprobs for output tokens. If logprobs are not used, this parameter can be omitted at startup. |
| ```tool_call_parser``` | `str` | Specifies the function call parser to use for extracting function call content from the model output |
| ```tool_parser_plugin``` | `str` | Specifies the file path of the tool parser to register, so that parsers not in the code repository can be used; the code in such parsers must follow the repository's format |

## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
@@ -95,6 +95,14 @@ class EngineArgs:
    """
    specifies the reasoning parser to use for extracting reasoning content from the model output
    """
    tool_call_parser: str = None
    """
    specifies the tool call parser to use for extracting tool calls from the model output
    """
    tool_parser_plugin: str = None
    """
    tool parser plugin used to register user-defined tool parsers
    """
    enable_mm: bool = False
    """
    Flags to enable multi-modal model
@@ -423,6 +431,18 @@ class EngineArgs:
            help="Flag specifies the reasoning parser to use for extracting "
            "reasoning content from the model output",
        )
        model_group.add_argument(
            "--tool-call-parser",
            type=str,
            default=EngineArgs.tool_call_parser,
            help="Flag specifies the tool call parser to use for extracting "
            "tool calls from the model output",
        )
        model_group.add_argument(
            "--tool-parser-plugin",
            type=str,
            default=EngineArgs.tool_parser_plugin,
            help="tool parser plugin used to register user-defined tool parsers",
        )
        model_group.add_argument(
            "--speculative-config",
            type=json.loads,
@@ -913,6 +933,7 @@ class EngineArgs:
            mm_processor_kwargs=self.mm_processor_kwargs,
            enable_mm=self.enable_mm,
            reasoning_parser=self.reasoning_parser,
            tool_parser=self.tool_call_parser,
            splitwise_role=self.splitwise_role,
            innode_prefill_ports=self.innode_prefill_ports,
            max_num_partial_prefills=self.max_num_partial_prefills,
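The same two options surface on the command line via `EngineArgs.add_cli_args`, which is how the OpenAI-compatible server consumes them. A small sketch of both paths; flag values and the model path are placeholders.

```python
import argparse

from fastdeploy.engine.args_utils import EngineArgs

# CLI path (what the api_server does): the new flags appear as
# --tool-call-parser and --tool-parser-plugin.
parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "/models/ERNIE-4.5", "--tool-call-parser", "ernie_x1"])

# Programmatic path: the dataclass fields map 1:1 to the flags above, and
# tool_call_parser is later handed to the engine Config as tool_parser (see the hunk above).
engine_args = EngineArgs(model="/models/ERNIE-4.5", tool_call_parser="ernie_x1")
```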
@@ -85,6 +85,7 @@ class Config:
        max_long_partial_prefills: int = 1,
        long_prefill_token_threshold: int = 0,
        reasoning_parser: str = None,
        tool_parser: str = None,
        guided_decoding_backend: Optional[str] = None,
        disable_any_whitespace: bool = False,
        enable_logprob: bool = False,
@@ -165,6 +166,7 @@ class Config:
        self.max_long_partial_prefills = max_long_partial_prefills
        self.long_prefill_token_threshold = long_prefill_token_threshold
        self.reasoning_parser = reasoning_parser
        self.tool_parser = tool_parser
        self.graph_optimization_config = graph_optimization_config
        self.early_stop_config = early_stop_config
        self.guided_decoding_backend = guided_decoding_backend
@@ -106,6 +106,7 @@ class LLMEngine:
            cfg.limit_mm_per_prompt,
            cfg.mm_processor_kwargs,
            cfg.enable_mm,
            cfg.tool_parser,
        )

        self.start_queue_service()
@@ -24,6 +24,7 @@ from typing import Any, Dict, Optional, Union
import numpy as np

from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.protocol import ToolCall
from fastdeploy.utils import data_processor_logger
from fastdeploy.worker.output import LogprobsLists, SampleLogprobs

@@ -249,6 +250,7 @@ class CompletionOutput:
    draft_token_ids: list[int] = None
    text: Optional[str] = None
    reasoning_content: Optional[str] = None
    tool_calls: Optional[ToolCall] = None

    def to_dict(self):
        """
@@ -14,6 +14,7 @@
# limitations under the License.
"""

import uuid
from copy import deepcopy
from typing import List, Literal, Union
from urllib.parse import urlparse
@@ -156,3 +157,7 @@ def parse_chat_messages(messages):

        conversation.append({"role": role, "content": parsed_content})
    return conversation


def random_tool_call_id() -> str:
    return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
@@ -45,6 +45,7 @@ class EngineClient:
        data_parallel_size=1,
        enable_logprob=False,
        workers=1,
        tool_parser=None,
    ):
        input_processor = InputPreprocessor(
            tokenizer,
@@ -52,6 +53,7 @@ class EngineClient:
            limit_mm_per_prompt,
            mm_processor_kwargs,
            enable_mm,
            tool_parser,
        )
        self.enable_logprob = enable_logprob
        self.enable_mm = enable_mm
@@ -28,8 +28,7 @@ from tqdm import tqdm
from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.engine.engine import LLMEngine
from fastdeploy.engine.sampling_params import SamplingParams

# from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.utils import llm_logger, retrive_model_from_server
from fastdeploy.worker.output import Logprob, LogprobsLists

@@ -73,6 +72,9 @@ class LLM:
        **kwargs,
    ):
        model = retrive_model_from_server(model, revision)
        tool_parser_plugin = kwargs.get("tool_parser_plugin")
        if tool_parser_plugin:
            ToolParserManager.import_tool_parser(tool_parser_plugin)
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
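When a parser is not shipped with the repository, `tool_parser_plugin` points at a file that registers one before the engine starts. A minimal sketch of such a plugin file; the class name and behavior are hypothetical and only illustrate the registration hook.

```python
# my_parser.py -- hypothetical plugin file passed via tool_parser_plugin
from fastdeploy.entrypoints.openai.protocol import ExtractedToolCallInformation
from fastdeploy.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager


@ToolParserManager.register_module("my_parser")
class MyToolParser(ToolParser):
    def extract_tool_calls(self, model_output, request) -> ExtractedToolCallInformation:
        # Illustrative only: report that no tool call was found.
        return ExtractedToolCallInformation(tools_called=False, content=model_output)
```

With the file imported through `--tool-parser-plugin my_parser.py` (or the `tool_parser_plugin` keyword offline), the parser is then selected with `--tool-call-parser my_parser`.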
@@ -41,6 +41,7 @@ from fastdeploy.entrypoints.openai.protocol import (
)
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.entrypoints.openai.serving_completion import OpenAIServingCompletion
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.metrics.metrics import (
    EXCLUDE_LABELS,
    cleanup_prometheus_files,
@@ -73,7 +74,8 @@ parser.add_argument("--max-concurrency", default=512, type=int, help="max concur
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.model = retrive_model_from_server(args.model, args.revision)

if args.tool_parser_plugin:
    ToolParserManager.import_tool_parser(args.tool_parser_plugin)
llm_engine = None

@@ -126,6 +128,7 @@ async def lifespan(app: FastAPI):
        args.data_parallel_size,
        args.enable_logprob,
        args.workers,
        args.tool_call_parser,
    )
    app.state.dynamic_load_weight = args.dynamic_load_weight
    chat_handler = OpenAIServingChat(engine_client, pid, args.ips, args.max_waiting_time)
@@ -72,7 +72,6 @@ class ToolCall(BaseModel):
    id: str = None
    type: Literal["function"] = "function"
    function: FunctionCall
    index: int


class DeltaFunctionCall(BaseModel):
@@ -96,6 +95,18 @@ class DeltaToolCall(BaseModel):
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: Optional[list[ToolCall]] = None

    # content - per the OpenAI spec, content AND tool calls are rarely both returned,
    # but some models will do this intentionally
    content: Optional[str] = None


class FunctionDefinition(BaseModel):
    """
    Function definition.
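For reference, a finished parse and a streaming delta are plain Pydantic objects built from the models above. A small sketch with made-up values; it assumes `index` is no longer required on `ToolCall` after this change.

```python
from fastdeploy.entrypoints.openai.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)

# Non-streaming: one fully parsed call plus whatever content is left over.
full = ExtractedToolCallInformation(
    tools_called=True,
    tool_calls=[
        ToolCall(
            id="chatcmpl-tool-123",
            function=FunctionCall(name="get_weather", arguments='{"location": "Beijing"}'),
        )
    ],
    content="",
)

# Streaming: only the newly generated argument fragment rides in the delta.
delta = DeltaMessage(
    tool_calls=[DeltaToolCall(index=0, function=DeltaFunctionCall(arguments='{"location": "Bei'))]
)
```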
@@ -141,6 +141,7 @@ class OpenAIServingChat:
|
||||
previous_num_tokens = 0
|
||||
num_prompt_tokens = 0
|
||||
num_choices = 1
|
||||
tool_called = False
|
||||
max_streaming_response_tokens = (
|
||||
request.max_streaming_response_tokens
|
||||
if request.max_streaming_response_tokens is not None
|
||||
@@ -246,20 +247,29 @@ class OpenAIServingChat:
|
||||
output = res["outputs"]
|
||||
delta_text = output["text"]
|
||||
output_top_logprobs = output["top_logprobs"]
|
||||
previous_num_tokens += len(output["token_ids"])
|
||||
logprobs_res: Optional[LogProbs] = None
|
||||
if request.logprobs and output_top_logprobs is not None:
|
||||
logprobs_res = self._create_chat_logprobs(
|
||||
output_top_logprobs, request.logprobs, request.top_logprobs
|
||||
)
|
||||
|
||||
previous_num_tokens += len(output["token_ids"])
|
||||
delta_message = DeltaMessage(
|
||||
content=delta_text,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
reasoning_content="",
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=None,
|
||||
tool_calls=output.get("tool_call_content", []),
|
||||
tool_calls=None,
|
||||
)
|
||||
if not res["finished"] and "delta_message" in output:
|
||||
delta_message_output = output["delta_message"]
|
||||
if delta_message_output is None:
|
||||
continue
|
||||
delta_message.content = delta_message_output.content or ""
|
||||
delta_message.reasoning_content = delta_message_output.reasoning_content or ""
|
||||
if delta_message_output.tool_calls:
|
||||
delta_message.tool_calls = delta_message_output.tool_calls
|
||||
tool_called = True
|
||||
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
@@ -277,10 +287,7 @@ class OpenAIServingChat:
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
if has_no_token_limit or previous_num_tokens != max_tokens:
|
||||
choice.finish_reason = "stop"
|
||||
if (
|
||||
self.engine_client.reasoning_parser == "ernie_x1"
|
||||
and output.get("finish_reason", "") == "tool_calls"
|
||||
):
|
||||
if tool_called:
|
||||
choice.finish_reason = "tool_calls"
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
@@ -421,7 +428,7 @@ class OpenAIServingChat:
|
||||
role="assistant",
|
||||
content=output["text"],
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
tool_calls=output.get("tool_call_content"),
|
||||
tool_calls=output.get("tool_call"),
|
||||
prompt_token_ids=prompt_token_ids if request.return_token_ids else None,
|
||||
completion_token_ids=completion_token_ids if request.return_token_ids else None,
|
||||
text_after_process=text_after_process if request.return_token_ids else None,
|
||||
@@ -443,7 +450,7 @@ class OpenAIServingChat:
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
if has_no_token_limit or previous_num_tokens != max_tokens:
|
||||
choice.finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
|
||||
if output.get("tool_call"):
|
||||
choice.finish_reason = "tool_calls"
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
|
@@ -240,9 +240,9 @@ class OpenAIServingCompletion:
|
||||
dealer.close()
|
||||
self.engine_client.semaphore.release()
|
||||
|
||||
def calc_finish_reason(self, max_tokens, token_num, output):
|
||||
def calc_finish_reason(self, max_tokens, token_num, output, tool_called):
|
||||
if max_tokens is None or token_num != max_tokens:
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
|
||||
if tool_called or output.get("tool_call"):
|
||||
return "tool_calls"
|
||||
else:
|
||||
return "stop"
|
||||
@@ -271,6 +271,7 @@ class OpenAIServingCompletion:
|
||||
output_tokens = [0] * num_choices
|
||||
inference_start_time = [0] * num_choices
|
||||
first_iteration = [True] * num_choices
|
||||
tool_called = [False] * num_choices
|
||||
max_streaming_response_tokens = (
|
||||
request.max_streaming_response_tokens
|
||||
if request.max_streaming_response_tokens is not None
|
||||
@@ -343,25 +344,34 @@ class OpenAIServingCompletion:
|
||||
if request.logprobs and output_top_logprobs is not None:
|
||||
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
|
||||
|
||||
choices.append(
|
||||
CompletionResponseStreamChoice(
|
||||
output_tokens[idx] += 1
|
||||
delta_message = CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
|
||||
tool_calls=None,
|
||||
raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
|
||||
completion_tokens=output.get("raw_prediction") if request.return_token_ids else None,
|
||||
tool_calls=output.get("tool_call_content"),
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
reasoning_content="",
|
||||
arrival_time=arrival_time,
|
||||
logprobs=logprobs_res,
|
||||
)
|
||||
)
|
||||
output_tokens[idx] += 1
|
||||
if not res["finished"] and "delta_message" in output:
|
||||
delta_message_output = output["delta_message"]
|
||||
if delta_message_output is None:
|
||||
continue
|
||||
delta_message.text = delta_message_output.content or ""
|
||||
delta_message.reasoning_content = delta_message_output.reasoning_content or ""
|
||||
if delta_message_output.tool_calls:
|
||||
delta_message.tool_calls = delta_message_output.tool_calls
|
||||
tool_called[idx] = True
|
||||
|
||||
choices.append(delta_message)
|
||||
|
||||
if res["finished"]:
|
||||
choices[-1].finish_reason = self.calc_finish_reason(
|
||||
request.max_tokens, output_tokens[idx], output
|
||||
request.max_tokens, output_tokens[idx], output, tool_called[idx]
|
||||
)
|
||||
send_idx = output.get("send_idx")
|
||||
# Only log when send_idx is explicitly 0
|
||||
@@ -460,7 +470,7 @@ class OpenAIServingCompletion:
|
||||
token_ids = output["token_ids"]
|
||||
output_text = output["text"]
|
||||
|
||||
finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output)
|
||||
finish_reason = self.calc_finish_reason(request.max_tokens, final_res["output_token_ids"], output, False)
|
||||
|
||||
choice_data = CompletionResponseChoice(
|
||||
token_ids=token_ids,
|
||||
@@ -473,7 +483,7 @@ class OpenAIServingCompletion:
|
||||
text_after_process=text_after_process_list[idx] if request.return_token_ids else None,
|
||||
prompt_tokens=text_after_process_list[idx] if request.return_token_ids else None,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
tool_calls=output.get("tool_call_content"),
|
||||
tool_calls=output.get("tool_call"),
|
||||
logprobs=aggregated_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
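Both the chat and completion handlers now derive the finish reason from the token budget, the model's reported finish reason, and whether a tool call was actually streamed. Distilled into a standalone helper, the rule looks roughly like this (a simplified sketch, not the exact code above):

```python
from typing import Optional


def finish_reason_sketch(
    max_tokens: Optional[int],
    token_num: int,
    model_finish_reason: str,
    tool_called: bool,
    is_ernie_x1: bool,
) -> str:
    """Simplified view of the finish-reason rules in serving_chat / serving_completion."""
    if max_tokens is None or token_num != max_tokens:
        # Token budget not exhausted: "tool_calls" only if the ernie_x1 reasoning parser
        # flagged a tool call AND one was actually parsed; otherwise "stop".
        if is_ernie_x1 and model_finish_reason == "tool_calls" and tool_called:
            return "tool_calls"
        return "stop"
    # Generation stopped because it hit max_tokens.
    return "length"
```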
fastdeploy/entrypoints/openai/tool_parsers/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from .abstract_tool_parser import ToolParser, ToolParserManager
from .ernie_x1_tool_parser import ErnieX1ToolParser

__all__ = [
    "ToolParser",
    "ToolParserManager",
    "ErnieX1ToolParser",
]
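Registered parsers are looked up by name; unknown names fail fast with a `KeyError`. A quick sketch:

```python
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager

parser_cls = ToolParserManager.get_tool_parser("ernie_x1")  # -> ErnieX1ToolParser

try:
    ToolParserManager.get_tool_parser("does_not_exist")
except KeyError as err:
    print(err)  # tool helper: 'does_not_exist' not found in tool_parsers
```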
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from functools import cached_property
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation,
|
||||
)
|
||||
from fastdeploy.utils import data_processor_logger, import_from_path, is_list_of
|
||||
|
||||
|
||||
class ToolParser:
|
||||
"""
|
||||
Abstract ToolParser class that should not be used directly. Provided
|
||||
properties and methods should be used in
|
||||
derived classes.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
# the index of the tool call that is currently being parsed
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict[str, int]:
|
||||
# NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
|
||||
# whereas all tokenizers have .get_vocab()
|
||||
return self.model_tokenizer.get_vocab()
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
"""
|
||||
        Static method that is used to adjust the request parameters.
|
||||
"""
|
||||
return request
|
||||
|
||||
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Static method that should be implemented for extracting tool calls from
|
||||
a complete model-generated string.
|
||||
Used for non-streaming responses where we have the entire model response
|
||||
available before sending to the client.
|
||||
Static because it's stateless.
|
||||
"""
|
||||
raise NotImplementedError("AbstractToolParser.extract_tool_calls has not been implemented!")
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Instance method that should be implemented for extracting tool calls
|
||||
from an incomplete response; for use when handling tool calls and
|
||||
streaming. Has to be an instance method because it requires state -
|
||||
the current tokens/diffs, but also the information about what has
|
||||
previously been parsed and extracted (see constructor)
|
||||
"""
|
||||
raise NotImplementedError("AbstractToolParser.extract_tool_calls_streaming has not been " "implemented!")
|
||||
|
||||
|
||||
class ToolParserManager:
|
||||
tool_parsers: dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def get_tool_parser(cls, name) -> type:
|
||||
"""
|
||||
Get tool parser by name which is registered by `register_module`.
|
||||
|
||||
Raise a KeyError exception if the name is not registered.
|
||||
"""
|
||||
if name in cls.tool_parsers:
|
||||
return cls.tool_parsers[name]
|
||||
|
||||
raise KeyError(f"tool helper: '{name}' not found in tool_parsers")
|
||||
|
||||
@classmethod
|
||||
def _register_module(
|
||||
cls, module: type, module_name: Optional[Union[str, list[str]]] = None, force: bool = True
|
||||
) -> None:
|
||||
if not issubclass(module, ToolParser):
|
||||
raise TypeError(f"module must be subclass of ToolParser, but got {type(module)}")
|
||||
if module_name is None:
|
||||
module_name = module.__name__
|
||||
if isinstance(module_name, str):
|
||||
module_name = [module_name]
|
||||
for name in module_name:
|
||||
if not force and name in cls.tool_parsers:
|
||||
existed_module = cls.tool_parsers[name]
|
||||
raise KeyError(f"{name} is already registered " f"at {existed_module.__module__}")
|
||||
cls.tool_parsers[name] = module
|
||||
|
||||
@classmethod
|
||||
def register_module(
|
||||
cls, name: Optional[Union[str, list[str]]] = None, force: bool = True, module: Union[type, None] = None
|
||||
) -> Union[type, Callable]:
|
||||
"""
|
||||
        Register module with the given name or name list. It can be used as a
        decorator (with module as None) or as a normal function (with module not
        None).
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f"force must be a boolean, but got {type(force)}")
|
||||
|
||||
# raise the error ahead of time
|
||||
if not (name is None or isinstance(name, str) or is_list_of(name, str)):
|
||||
raise TypeError("name must be None, an instance of str, or a sequence of str, " f"but got {type(name)}")
|
||||
|
||||
# use it as a normal method: x.register_module(module=SomeClass)
|
||||
if module is not None:
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
# use it as a decorator: @x.register_module()
|
||||
def _register(module):
|
||||
cls._register_module(module=module, module_name=name, force=force)
|
||||
return module
|
||||
|
||||
return _register
|
||||
|
||||
@classmethod
|
||||
def import_tool_parser(cls, plugin_path: str) -> None:
|
||||
"""
|
||||
Import a user-defined tool parser by the path of the tool parser define
|
||||
file.
|
||||
"""
|
||||
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
|
||||
|
||||
try:
|
||||
import_from_path(module_name, plugin_path)
|
||||
except Exception:
|
||||
data_processor_logger.exception("Failed to load module '%s' from %s.", module_name, plugin_path)
|
||||
return
|
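`register_module` can be applied either as a decorator or called directly once the class exists elsewhere. A minimal sketch; the class names are hypothetical and their bodies are omitted.

```python
from fastdeploy.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager


# As a decorator (the built-in ernie_x1 parser is registered this way):
@ToolParserManager.register_module("demo_parser")
class DemoParser(ToolParser):
    pass


# As a plain call, e.g. when the class is defined elsewhere:
class AnotherParser(ToolParser):
    pass


ToolParserManager.register_module(name="another_parser", module=AnotherParser)
```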
@@ -0,0 +1,347 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import partial_json_parser
|
||||
|
||||
|
||||
def random_tool_call_id() -> str:
|
||||
"""Generate a random tool call ID"""
|
||||
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
|
||||
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from fastdeploy.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
from fastdeploy.utils import data_processor_logger
|
||||
|
||||
|
||||
@ToolParserManager.register_module("ernie_x1")
|
||||
class ErnieX1ToolParser(ToolParser):
|
||||
"""
|
||||
Tool parser for Ernie model version 4.5.1.
|
||||
This parser handles tool calls with newline formats.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
|
||||
self.buffer: str = "" # buffer for accumulating unprocessed streaming content
|
||||
self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
|
||||
            raise RuntimeError(
                "Ernie X1 tool parser could not locate tool call start/end tokens in the tokenizer!"
            )
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolCallParser constructor during construction."
|
||||
)
|
||||
|
||||
def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
Supports XML-style formats with newlines:
|
||||
- XML format: <think>\n...\n</think>\n\n\n<tool_call>\n{...}\n</tool_call>\n...
|
||||
|
||||
Handles boundary cases:
|
||||
1. Only name and partial arguments: {"name": "get_weather", "arguments": {"location": "北京"
|
||||
2. Only partial name: {"name": "get_we
|
||||
3. Only name and arguments field without content: {"name": "get_weather", "argume
|
||||
"""
|
||||
|
||||
try:
|
||||
tool_calls = []
|
||||
|
||||
# Check for invalid <response> tags before tool calls
|
||||
if re.search(r"<response>[\s\S]*?</response>\s*(?=<tool_call>)", model_output):
|
||||
data_processor_logger.error("Invalid format: <response> tags found before <tool_call>")
|
||||
return ExtractedToolCallInformation(tools_called=False, content=model_output)
|
||||
|
||||
function_call_arr = []
|
||||
remaining_text = model_output
|
||||
|
||||
while True:
|
||||
# 查找下一个tool_call块
|
||||
tool_call_pos = remaining_text.find("<tool_call>")
|
||||
if tool_call_pos == -1:
|
||||
break
|
||||
|
||||
# 提取tool_call开始位置后的内容
|
||||
tool_content_start = tool_call_pos + len("<tool_call>")
|
||||
tool_content_end = remaining_text.find("</tool_call>", tool_content_start)
|
||||
|
||||
tool_json = ""
|
||||
if tool_content_end == -1:
|
||||
# 处理未闭合的tool_call块(截断情况)
|
||||
tool_json = remaining_text[tool_content_start:].strip()
|
||||
remaining_text = "" # 没有更多内容需要处理
|
||||
else:
|
||||
# 处理完整的tool_call块
|
||||
tool_json = remaining_text[tool_content_start:tool_content_end].strip()
|
||||
remaining_text = remaining_text[tool_content_end + len("</tool_call>") :]
|
||||
|
||||
if not tool_json:
|
||||
continue
|
||||
|
||||
# 处理JSON内容
|
||||
tool_json = tool_json.strip()
|
||||
if not tool_json.startswith("{"):
|
||||
tool_json = "{" + tool_json
|
||||
if not tool_json.endswith("}"):
|
||||
tool_json = tool_json + "}"
|
||||
|
||||
try:
|
||||
# 首先尝试标准JSON解析
|
||||
try:
|
||||
tool_data = json.loads(tool_json)
|
||||
|
||||
if isinstance(tool_data, dict) and "name" in tool_data and "arguments" in tool_data:
|
||||
function_call_arr.append(
|
||||
{
|
||||
"name": tool_data["name"],
|
||||
"arguments": tool_data["arguments"],
|
||||
"_is_complete": True, # 明确标记为完整解析
|
||||
}
|
||||
)
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 标准解析失败时尝试partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
try:
|
||||
tool_data = {}
|
||||
flags = Allow.ALL & ~Allow.STR
|
||||
|
||||
# 解析name字段
|
||||
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json)
|
||||
if name_match:
|
||||
tool_data["name"] = name_match.group(1)
|
||||
|
||||
# 解析arguments字段
|
||||
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json)
|
||||
if args_match:
|
||||
try:
|
||||
tool_data["arguments"] = partial_json_parser.loads(args_match.group(1), flags=flags)
|
||||
except:
|
||||
tool_data["arguments"] = None
|
||||
|
||||
if isinstance(tool_data, dict):
|
||||
function_call_arr.append(
|
||||
{
|
||||
"name": tool_data.get("name", ""),
|
||||
"arguments": tool_data.get("arguments", {}),
|
||||
"_is_partial": True, # 标记为部分解析
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
|
||||
continue
|
||||
except Exception as e:
|
||||
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
|
||||
continue
|
||||
|
||||
if not function_call_arr:
|
||||
data_processor_logger.error("No valid tool calls found")
|
||||
return ExtractedToolCallInformation(tools_called=False, content=model_output)
|
||||
|
||||
tool_calls = []
|
||||
all_complete = True # 初始设为True,只要有一个不完整就变为False
|
||||
|
||||
for tool_call in function_call_arr:
|
||||
# 记录工具调用解析状态
|
||||
is_complete = tool_call.get("_is_complete", False)
|
||||
is_partial = tool_call.get("_is_partial", False)
|
||||
|
||||
# 只要有一个不完整就认为整体不完整
|
||||
if not is_complete or is_partial:
|
||||
all_complete = False
|
||||
|
||||
# 处理参数序列化
|
||||
tool_args = tool_call.get("arguments", {})
|
||||
if not isinstance(tool_args, dict):
|
||||
tool_args = {}
|
||||
|
||||
try:
|
||||
args_str = json.dumps(tool_args, ensure_ascii=False) if tool_args else "{}"
|
||||
except:
|
||||
args_str = "{}"
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=FunctionCall(
|
||||
name=tool_call.get("name", ""),
|
||||
arguments=args_str,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# 只有当所有工具调用都明确标记为complete时才返回tools_called=True
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content=""
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
|
||||
return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: dict,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
return DeltaMessage(content=delta_text)
|
||||
# 忽略空chunk
|
||||
if len(delta_text.strip()) == 0:
|
||||
return None
|
||||
|
||||
try:
|
||||
delta = None
|
||||
# 使用buffer累积delta_text内容
|
||||
self.buffer += delta_text
|
||||
|
||||
# 处理增量中的新tool_call开始
|
||||
if "<tool_call>" in delta_text:
|
||||
self.current_tool_id = (
|
||||
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
|
||||
)
|
||||
self.current_tool_name_sent = False
|
||||
if len(self.streamed_args_for_tool) <= self.current_tool_id:
|
||||
self.streamed_args_for_tool.append("")
|
||||
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
|
||||
|
||||
# 1. 尝试解析name字段
|
||||
if not self.current_tool_name_sent and '"name"' in self.buffer:
|
||||
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
|
||||
if name_match:
|
||||
name = name_match.group(1)
|
||||
if name:
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=random_tool_call_id(),
|
||||
function=DeltaFunctionCall(name=name).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
# 删除已处理的name部分
|
||||
self.buffer = self.buffer[name_match.end() :]
|
||||
self.current_tool_name_sent = True
|
||||
return delta
|
||||
# 2. 尝试解析arguments字段
|
||||
if '"arguments"' in self.buffer:
|
||||
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
|
||||
if args_match:
|
||||
args_content = args_match.group(1)
|
||||
try:
|
||||
# 检查是否到达arguments结尾(括号完全匹配)
|
||||
if "}}" in args_content:
|
||||
# 逐个字符检查括号匹配状态
|
||||
matched_pos = -1
|
||||
for i, ch in enumerate(delta_text):
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
|
||||
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配
|
||||
matched_pos = i
|
||||
break
|
||||
|
||||
if matched_pos >= 0:
|
||||
# 找到匹配点,清理buffer并返回
|
||||
truncate_text = delta_text[: matched_pos + 1]
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=truncate_text).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
self.buffer = self.buffer[args_match.end() :]
|
||||
return delta
|
||||
else:
|
||||
# 没有完全匹配,继续累积
|
||||
return None
|
||||
else:
|
||||
# 增量返回当前可解析的部分
|
||||
for ch in delta_text:
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
return delta
|
||||
except Exception as e:
|
||||
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
|
||||
return None
|
||||
if "</tool_call>" in self.buffer:
|
||||
end_pos = self.buffer.find("</tool_call>")
|
||||
self.buffer = self.buffer[end_pos + len("</tool_call>") :]
|
||||
|
||||
# 完成当前工具调用处理
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
|
||||
return None
|
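End to end, the non-streaming path takes the raw model text and returns the parsed calls. The sketch below uses the tag format from the docstring above; it assumes a tokenizer whose vocabulary contains the `<tool_call>` / `</tool_call>` tokens, and the model path is a placeholder.

```python
from transformers import AutoTokenizer

from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager

tokenizer = AutoTokenizer.from_pretrained("/models/ERNIE-4.5")  # placeholder path
parser = ToolParserManager.get_tool_parser("ernie_x1")(tokenizer)

model_output = (
    "<think>\nneed the forecast\n</think>\n\n"
    '<tool_call>\n{"name": "get_weather", "arguments": {"location": "Beijing"}}\n</tool_call>'
)
info = parser.extract_tool_calls(model_output, request=None)
print(info.tools_called)                      # True only when every call parsed completely
print(info.tool_calls[0].function.arguments)  # '{"location": "Beijing"}'
```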
fastdeploy/entrypoints/openai/tool_parsers/utils.py (new file, 137 lines)
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
|
||||
def find_common_prefix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common prefix that is shared between two strings, if there is one.
|
||||
Order of arguments is NOT important.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely.
|
||||
|
||||
e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
|
||||
'{"fruit": "ap'
|
||||
"""
|
||||
prefix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def find_common_suffix(s1: str, s2: str) -> str:
|
||||
"""
|
||||
Finds a common suffix shared between two strings, if there is one. Order of
|
||||
arguments is NOT important.
|
||||
Stops when the suffix ends OR it hits an alphanumeric character
|
||||
|
||||
e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
|
||||
"""
|
||||
suffix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(1, min_length + 1):
|
||||
if s1[-i] == s2[-i] and not s1[-i].isalnum():
|
||||
suffix = s1[-i] + suffix
|
||||
else:
|
||||
break
|
||||
return suffix
|
||||
|
||||
|
||||
def extract_intermediate_diff(curr: str, old: str) -> str:
|
||||
"""
|
||||
Given two strings, extract the difference in the middle between two strings
|
||||
that are known to have a common prefix and/or suffix.
|
||||
|
||||
This function is provided as a UTILITY for extracting information from JSON
|
||||
generated by partial_json_parser, to help in ensuring that the right tokens
|
||||
are returned in streaming, so that close-quotes, close-brackets and
|
||||
close-braces are not returned prematurely. The order of arguments IS
|
||||
important - the new version of the partially-parsed JSON must be the first
|
||||
argument, and the second argument must be from the previous generation.
|
||||
|
||||
What it returns, is tokens that should be streamed to the client.
|
||||
|
||||
e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
|
||||
-> 'ple'
|
||||
|
||||
"""
|
||||
suffix = find_common_suffix(curr, old)
|
||||
|
||||
old = old[::-1].replace(suffix[::-1], "", 1)[::-1]
|
||||
prefix = find_common_prefix(curr, old)
|
||||
diff = curr
|
||||
if len(suffix):
|
||||
diff = diff[::-1].replace(suffix[::-1], "", 1)[::-1]
|
||||
|
||||
if len(prefix):
|
||||
# replace the prefix only once in case it's mirrored
|
||||
diff = diff.replace(prefix, "", 1)
|
||||
|
||||
return diff
|
||||
|
||||
|
||||
def find_all_indices(string: str, substring: str) -> list[int]:
|
||||
"""
|
||||
Find all (starting) indices of a substring in a given string. Useful for
|
||||
tool call extraction
|
||||
"""
|
||||
indices = []
|
||||
index = -1
|
||||
while True:
|
||||
index = string.find(substring, index + 1)
|
||||
if index == -1:
|
||||
break
|
||||
indices.append(index)
|
||||
return indices
|
||||
|
||||
|
||||
# partial_json_parser doesn't support extra data and
|
||||
# JSONDecoder.raw_decode doesn't support partial JSON
|
||||
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
def is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
def consume_space(i: int, s: str) -> int:
|
||||
while i < len(s) and s[i].isspace():
|
||||
i += 1
|
||||
return i
|
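The helpers above are easiest to see on a tiny worked example; expected values are shown in the comments.

```python
from fastdeploy.entrypoints.openai.tool_parsers.utils import (
    extract_intermediate_diff,
    find_all_indices,
    find_common_prefix,
    find_common_suffix,
    is_complete_json,
)

prev = '{"fruit": "ap"}'
curr = '{"fruit": "apple"}'

print(find_common_prefix(curr, prev))         # '{"fruit": "ap'
print(find_common_suffix(curr, prev))         # '"}'
print(extract_intermediate_diff(curr, prev))  # 'ple'  -> the fragment safe to stream
print(find_all_indices("<tool_call>a<tool_call>b", "<tool_call>"))  # [0, 12]
print(is_complete_json('{"name": "get_we'))   # False
```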
@@ -43,13 +43,14 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
pad_token_id (int): 存储填充符号的token ID。
|
||||
"""
|
||||
|
||||
def __init__(self, model_name_or_path, reasoning_parser_obj=None):
|
||||
def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
|
||||
|
||||
self.model_name_or_path = model_name_or_path
|
||||
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
|
||||
self._init_config()
|
||||
|
||||
self.decode_status = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.thinking_parser_dict = dict()
|
||||
self._load_tokenizer()
|
||||
data_processor_logger.info(
|
||||
@@ -61,6 +62,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
self.eos_token_id_len = len(self.eos_token_ids)
|
||||
self.pad_token_id = self.get_pad_id()
|
||||
self.reasoning_parser = None
|
||||
self.tool_parser_obj = tool_parser_obj
|
||||
if reasoning_parser_obj:
|
||||
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
|
||||
|
||||
@@ -133,6 +135,8 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
request.set("temperature", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request.enable_thinking = True
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
return request
|
||||
|
||||
@@ -194,6 +198,8 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
request["temperature"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request["enable_thinking"] = True
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
return request
|
||||
@@ -221,6 +227,12 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
response_dict.outputs.reasoning_content = reasoning_content
|
||||
else:
|
||||
response_dict.outputs.text = full_text
|
||||
if self.tool_parser_obj:
|
||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||
if tool_call_info.tools_called:
|
||||
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
||||
response_dict.outputs.text = tool_call_info.content
|
||||
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
|
||||
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
|
||||
return None
|
||||
@@ -261,12 +273,20 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
if is_end:
|
||||
full_text = previous_texts + delta_text
|
||||
if enable_thinking and self.reasoning_parser:
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["text"] = full_text
|
||||
if self.tool_parser_obj:
|
||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||
if tool_call_info.tools_called:
|
||||
response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
|
||||
response_dict["outputs"]["text"] = tool_call_info.content
|
||||
response_dict["outputs"]["raw_prediction"] = full_text
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
@@ -292,8 +312,10 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
response_dict["outputs"]["raw_prediction"] = delta_text
|
||||
if enable_thinking and self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
@@ -301,14 +323,28 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
||||
if self.tool_parser_obj:
|
||||
if req_id not in self.tool_parser_dict:
|
||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parser_dict[req_id]
|
||||
tool_call_delta_message = tool_parser.extract_tool_calls_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
response_dict,
|
||||
)
|
||||
if tool_call_delta_message is None or tool_call_delta_message.tool_calls:
|
||||
response_dict["outputs"]["delta_message"] = tool_call_delta_message
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
response_dict["outputs"]["raw_prediction"] = delta_text
|
||||
if is_end:
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.tool_parser_dict:
|
||||
del self.tool_parser_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def messages2ids(self, request_or_messages):
|
||||
|
@@ -34,6 +34,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
|
||||
limit_mm_per_prompt=None,
|
||||
mm_processor_kwargs=None,
|
||||
reasoning_parser_obj=None,
|
||||
tool_parser_obj=None,
|
||||
):
|
||||
self.use_hf_tokenizer = False
|
||||
|
||||
@@ -53,6 +54,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
|
||||
self.image_patch_id = self.ernie_processor.image_patch_id
|
||||
self.spatial_conv_size = self.ernie_processor.spatial_conv_size
|
||||
|
||||
self.tool_parser_dict = dict()
|
||||
self.decode_status = dict()
|
||||
self._load_tokenizer()
|
||||
self.eos_token_ids = [self.tokenizer.eos_token_id]
|
||||
@@ -60,6 +62,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
|
||||
self.pad_token_id = self.get_pad_id()
|
||||
self.limit_mm_per_prompt = self._parse_limits(limit_mm_per_prompt)
|
||||
self.reasoning_parser = None
|
||||
self.tool_parser_obj = tool_parser_obj
|
||||
if reasoning_parser_obj:
|
||||
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
|
||||
|
||||
|
@@ -18,6 +18,7 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from fastdeploy.config import ErnieArchitectures
|
||||
from fastdeploy.engine.config import ModelConfig
|
||||
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from fastdeploy.reasoning import ReasoningParserManager
|
||||
|
||||
|
||||
@@ -48,6 +49,7 @@ class InputPreprocessor:
|
||||
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
enable_mm: bool = False,
|
||||
tool_parser: str = None,
|
||||
) -> None:
|
||||
|
||||
self.model_name_or_path = model_name_or_path
|
||||
@@ -55,6 +57,7 @@ class InputPreprocessor:
|
||||
self.enable_mm = enable_mm
|
||||
self.limit_mm_per_prompt = limit_mm_per_prompt
|
||||
self.mm_processor_kwargs = mm_processor_kwargs
|
||||
self.tool_parser = tool_parser
|
||||
|
||||
def create_processor(self):
|
||||
"""
|
||||
@@ -68,8 +71,11 @@ class InputPreprocessor:
|
||||
DataProcessor or MultiModalRegistry.Processor (Union[DataProcessor, MultiModalRegistry.Processor]): 数据处理器。
|
||||
"""
|
||||
reasoning_parser_obj = None
|
||||
tool_parser_obj = None
|
||||
if self.reasoning_parser:
|
||||
reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
|
||||
if self.tool_parser:
|
||||
tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser)
|
||||
architectures = ModelConfig({"model": self.model_name_or_path}).architectures[0]
|
||||
if not self.enable_mm:
|
||||
if not ErnieArchitectures.contains_ernie_arch(architectures):
|
||||
@@ -78,6 +84,7 @@ class InputPreprocessor:
|
||||
self.processor = DataProcessor(
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
reasoning_parser_obj=reasoning_parser_obj,
|
||||
tool_parser_obj=tool_parser_obj,
|
||||
)
|
||||
else:
|
||||
from fastdeploy.input.ernie_processor import ErnieProcessor
|
||||
@@ -85,6 +92,7 @@ class InputPreprocessor:
|
||||
self.processor = ErnieProcessor(
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
reasoning_parser_obj=reasoning_parser_obj,
|
||||
tool_parser_obj=tool_parser_obj,
|
||||
)
|
||||
else:
|
||||
if not architectures.startswith("Ernie4_5_VLMoeForConditionalGeneration"):
|
||||
@@ -97,5 +105,6 @@ class InputPreprocessor:
|
||||
limit_mm_per_prompt=self.limit_mm_per_prompt,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
reasoning_parser_obj=reasoning_parser_obj,
|
||||
tool_parser_obj=tool_parser_obj,
|
||||
)
|
||||
return self.processor
|
||||
|
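Inside `create_processor`, the parser names are resolved to classes through the two registries before the data processor is built; the classes themselves are only instantiated later with the loaded tokenizer. A distilled sketch with illustrative variable names:

```python
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.reasoning import ReasoningParserManager

reasoning_parser_name = "ernie_x1"  # illustrative values
tool_parser_name = "ernie_x1"

reasoning_parser_obj = (
    ReasoningParserManager.get_reasoning_parser(reasoning_parser_name) if reasoning_parser_name else None
)
tool_parser_obj = ToolParserManager.get_tool_parser(tool_parser_name) if tool_parser_name else None
# Both classes are then passed to ErnieProcessor / DataProcessor as the
# reasoning_parser_obj= and tool_parser_obj= keyword arguments.
```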
@@ -148,7 +148,7 @@ class BaseDataProcessor(ABC):
|
||||
|
||||
|
||||
class DataProcessor(BaseDataProcessor):
|
||||
def __init__(self, model_name_or_path, reasoning_parser_obj=None):
|
||||
def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
|
||||
"""
|
||||
Initializes the DecodeStatus object.
|
||||
|
||||
@@ -168,6 +168,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
self._init_config()
|
||||
|
||||
self.decode_status = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
data_processor_logger.info(
|
||||
f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
|
||||
@@ -180,6 +181,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
self.eos_token_id_len = len(self.eos_token_ids)
|
||||
self.pad_token_id = self.get_pad_id()
|
||||
self.reasoning_parser = None
|
||||
self.tool_parser_obj = tool_parser_obj
|
||||
if reasoning_parser_obj:
|
||||
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
|
||||
self.tokenizer.pad_token_id = self.pad_token_id
|
||||
@@ -345,6 +347,12 @@ class DataProcessor(BaseDataProcessor):
|
||||
else:
|
||||
# The model does not support thinking, and enable_thinking was not explicitly set to False
|
||||
response_dict.outputs.text = full_text
|
||||
if self.tool_parser_obj:
|
||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||
if tool_call_info.tools_called:
|
||||
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
||||
response_dict.outputs.text = tool_call_info.content
|
||||
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
|
||||
|
||||
return response_dict
|
||||
@@ -376,6 +384,12 @@ class DataProcessor(BaseDataProcessor):
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["text"] = full_text
|
||||
if self.tool_parser_obj:
|
||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||
if tool_call_info.tools_called:
|
||||
response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
|
||||
response_dict["outputs"]["text"] = tool_call_info.content
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
return response_dict
|
||||
@@ -400,8 +414,10 @@ class DataProcessor(BaseDataProcessor):
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
response_dict["outputs"]["raw_prediction"] = delta_text
|
||||
if enable_thinking and self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
@@ -409,13 +425,28 @@ class DataProcessor(BaseDataProcessor):
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
||||
if self.tool_parser_obj:
|
||||
if req_id not in self.tool_parser_dict:
|
||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parser_dict[req_id]
|
||||
tool_call = tool_parser.extract_tool_calls_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
response_dict,
|
||||
)
|
||||
if tool_call is None or tool_call.tool_calls:
|
||||
response_dict["outputs"]["delta_message"] = tool_call
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
if is_end:
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.tool_parser_dict:
|
||||
del self.tool_parser_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def process_response_dict(self, response_dict, **kwargs):
|
||||
|
@@ -16,6 +16,7 @@

from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .ernie_vl_reasoning_parsers import ErnieVLReasoningParser
from .ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
from .qwen3_reasoning_parsers import Qwen3ReasoningParser

__all__ = [
@@ -23,4 +24,5 @@ __all__ = [
    "ReasoningParserManager",
    "ErnieVLReasoningParser",
    "Qwen3ReasoningParser",
    "ErnieX1ReasoningParser",
]
@@ -46,6 +46,9 @@ class ErnieVLReasoningParser(ReasoningParser):
if self.think_end_token_id is None:
raise RuntimeError("Ernie VL reasoning parser could not locate think end " "tokens in the tokenizer!")

def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids

def extract_reasoning_content_streaming(
self,
previous_text: str,
@@ -65,18 +68,16 @@ class ErnieVLReasoningParser(ReasoningParser):
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
return "", ""
return None
if self.think_end_token_id in delta_token_ids:
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.end_token) :]
return DeltaMessage(reasoning_content=reasoning_content, content=content)
elif self.think_end_token_id in previous_token_ids:
reasoning_content = ""
content = delta_text
return DeltaMessage(content=delta_text)
else:
reasoning_content = delta_text
content = ""
return reasoning_content, content
return DeltaMessage(reasoning_content=delta_text)

def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
@@ -95,7 +96,6 @@ class ErnieVLReasoningParser(ReasoningParser):
# Check if the model output contains the </think> tokens.
if self.think_end_token not in model_output:
return "", model_output
# Extract reasoning content from the model output.
reasoning_content, _, content = model_output.partition(self.think_end_token)

final_content = content or ""
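The net effect of this hunk is a return-type change: the streaming extractor now yields DeltaMessage objects (or None for a bare </think> token) instead of (reasoning_content, content) tuples. A small illustration of the shapes a caller now receives; DeltaMessageExample below is a local stand-in for the protocol's DeltaMessage, not the real model:

# Illustration only: the three shapes the streaming parser can now return.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DeltaMessageExample:
    reasoning_content: Optional[str] = None
    content: Optional[str] = None


still_thinking = DeltaMessageExample(reasoning_content="comparing the digits ...")
crossing_think_end = DeltaMessageExample(reasoning_content="done", content="The answer is")
after_thinking = DeltaMessageExample(content=" 9.9.")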
fastdeploy/reasoning/ernie_x1_reasoning_parsers.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence
from typing import Tuple, Union

from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
@ReasoningParserManager.register_module("ernie_x1")
class ErnieX1ReasoningParser(ReasoningParser):
"""
Reasoning parser for the ernie_x1 model, with stricter boundary checking.

The streaming logic works as follows:
1. For thinking content: wait for \n, then check for the </think> tag
2. For response content: check for the <response> tag first, then wait for \n
3. Handle newlines in content more precisely
"""

def __init__(self, tokenizer):
super().__init__(tokenizer)
self.think_end_token = "</think>"
self.response_start_token = "<response>"
self.response_end_token = "</response>"
self.tool_call_start_token = "<tool_call>"
self.tool_call_end_token = "</tool_call>"

if not self.model_tokenizer:
raise ValueError("The model tokenizer must be passed to the ReasoningParser constructor.")

self.think_end_token_id = self.vocab.get("</think>")
if self.think_end_token_id is None:
raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
self.tool_call_start_token_id = self.vocab.get("<tool_call>")

def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.tool_call_start_token_id in input_ids
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""
Streaming parsing, implemented as follows:
1. All initial content is treated as thinking content and emitted as reasoning.
2. When a \n is seen, check whether the following delta is </think>.
3. A </think> seen directly also ends the thinking phase.
4. After thinking ends, check whether the output continues with <response> or <tool_call>.
5. For <response> content, handle the various boundary conditions.
"""
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
return None
# Thinking phase
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
# On a bare \n, emit nothing yet and wait for the next delta_text
if delta_text == "\n":
return None
# If the previous text ends with \n and the current delta starts with </think>, thinking ends
elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
return None
# A </think> seen directly also ends thinking
elif delta_text.startswith(self.think_end_token):
return None
# Otherwise keep emitting thinking content
return DeltaMessage(reasoning_content=delta_text)

# After thinking ends, check whether this is a tool_call or a response
remaining_text = previous_text + delta_text
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
after_think = after_think.lstrip("\n")  # skip the newline after </think>

# tool_call case
if after_think.startswith(self.tool_call_start_token):
return None

# response case
if after_think.startswith(self.response_start_token):
# Do not emit anything for the <response> tag itself
if delta_text == self.response_start_token:
return None
# Do not emit the newline right after <response> either
elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
return None
# Hold back newlines inside the response content
if delta_text == "\n":
return None
# If the previous text ends with \n and the current delta is </response>, the response ends
elif previous_text.endswith("\n") and delta_text == self.response_end_token:
return None
# A </response> seen directly also ends the response
elif delta_text == self.response_end_token:
return None
# Otherwise emit the actual content
else:
return DeltaMessage(content=delta_text)

# By default, emit nothing
return None
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
"""
Batch version of the enhanced parser.
Modified to preserve newlines in both reasoning and response content,
only removing the single newline before closing tags.
"""
reasoning_content = ""
response_content = ""

think_end_pos = model_output.find(self.think_end_token)
if think_end_pos != -1:
# Extract thinking content - only remove the last newline before </think>
reasoning_content = model_output[:think_end_pos]
if think_end_pos > 0 and reasoning_content[-1] == "\n":
reasoning_content = reasoning_content[:-1]

remaining = model_output[think_end_pos + len(self.think_end_token) :]

# Skip newlines after </think>
remaining = remaining.lstrip("\n")

# Check for response or tool_call
if remaining.startswith(self.response_start_token):
response_pos = len(self.response_start_token)
remaining = remaining[response_pos:].lstrip("\n")
response_end_pos = remaining.find(self.response_end_token)
if response_end_pos != -1:
# Only strip the last newline before </response>, not all
if response_end_pos > 0 and remaining[response_end_pos - 1] == "\n":
response_content = remaining[: response_end_pos - 1]
else:
response_content = remaining[:response_end_pos]
else:
# If no </response> found, return the rest as response content
response_content = remaining
elif remaining.startswith(self.tool_call_start_token):
pass  # No response content
else:
# No thinking content found, return the whole input as reasoning
reasoning_content = model_output
response_content = ""
return reasoning_content, response_content
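As a worked illustration of the batch path above (not executed against FastDeploy): for an ernie_x1-style completion that thinks, closes with </think>, and then answers inside <response> tags, extract_reasoning_content is expected to return the reasoning and response texts with only the single newline before each closing tag stripped. The sample strings below are hypothetical:

# Hypothetical input/output pair for extract_reasoning_content, per the rules above.
sample_output = (
    "Compare the fractional parts digit by digit.\n"
    "</think>\n"
    "<response>\n"
    "9.11 is smaller than 9.9.\n"
    "</response>"
)

# Expected split according to the parsing rules implemented above:
expected_reasoning = "Compare the fractional parts digit by digit."
expected_response = "9.11 is smaller than 9.9."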
@@ -48,6 +48,9 @@ class Qwen3ReasoningParser(ReasoningParser):
if self.think_end_token_id is None:
raise RuntimeError("Qwen3 reasoning parser could not locate think end " "tokens in the tokenizer!")

def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids

def extract_reasoning_content_streaming(
self,
previous_text: str,
@@ -66,7 +69,7 @@ class Qwen3ReasoningParser(ReasoningParser):
- 'xyz' goes to content
"""
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]):
return "", ""
return None

# </think> in delta
if self.think_end_token_id in delta_token_ids:
@@ -76,28 +79,28 @@ class Qwen3ReasoningParser(ReasoningParser):
end_index = delta_token_ids.find(self.think_end_token)
reasoning_content = delta_text[start_index + len(self.think_start_token) : end_index]
content = delta_text[end_index + len(self.think_end_token) :]
return reasoning_content, content
return DeltaMessage(reasoning_content=reasoning_content, content=content)
# <think> in previous, </think> in delta,
else:
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token) :]
content = content if content else None
return reasoning_content, content
return DeltaMessage(reasoning_content=reasoning_content, content=content)
# </think> in previous reasoning content continues
elif self.think_end_token_id in previous_token_ids:
return "", delta_text
return DeltaMessage(content=delta_text)
# <think> in previous
elif self.think_start_token_id in previous_token_ids:
return delta_text, ""
return DeltaMessage(reasoning_content=delta_text)
# <think> in delta
elif self.think_start_token_id in delta_token_ids:
start_index = delta_text.find(self.think_start_token)
reasoning_content = delta_text[start_index + len(self.think_start_token) :]
content = ""
return reasoning_content, content
return DeltaMessage(reasoning_content=reasoning_content, content=content)
else:
return delta_text, ""
return DeltaMessage(reasoning_content=delta_text)

def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
@@ -23,6 +23,7 @@ import os
import random
import re
import socket
import sys
import tarfile
import time
from datetime import datetime
@@ -591,6 +592,22 @@ def is_list_of(
assert_never(check)


def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
"""
Import a Python file according to its file path.
"""
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None:
raise ModuleNotFoundError(f"No module named '{module_name}'")

assert spec.loader is not None

module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module


def version():
"""
Prints the contents of the version.txt file located in the parent directory of this script.
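import_from_path is the piece that makes file-based parser plugins possible: executing the module at a user-supplied path lets any registration code in that file run at import time. A self-contained sketch of that flow follows; the temporary plugin file and its contents are made up for illustration, and the helper simply mirrors the function added above:

# Sketch: loading a user-defined parser module by file path.
import importlib.util
import os
import sys
import tempfile
from typing import Union


def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ModuleNotFoundError(f"No module named '{module_name}'")
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # any register_module decorators in the plugin run here
    return module


with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write("PARSER_NAME = 'my_custom_tool_parser'\n")  # stand-in for a real parser definition
    plugin_path = f.name

plugin = import_from_path("my_tool_parser_plugin", plugin_path)
print(plugin.PARSER_NAME)  # -> my_custom_tool_parser
os.remove(plugin_path)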
@@ -37,3 +37,4 @@ opentelemetry-instrumentation-mysql
opentelemetry-distro
opentelemetry-exporter-otlp
opentelemetry-instrumentation-fastapi
partial_json_parser
@@ -523,7 +523,8 @@ def test_chat_with_thinking(openai_client, capsys):
stream=True,
max_tokens=10,
)
completion_tokens = reasoning_tokens = 1
completion_tokens = 1
reasoning_tokens = 0
total_tokens = 0
for chunk_id, chunk in enumerate(response):
if chunk_id == 0:  # the first chunk is an extra chunk
@@ -18,9 +18,9 @@ class TestOpenAIServingCompletion(unittest.TestCase):
# Create an OpenAIServingCompletion instance
serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
# Create a mock output with finish_reason set to "tool_calls"
output = {"finish_reason": "tool_calls"}
output = {"tool_call": True}
# Call the calc_finish_reason method
result = serving_completion.calc_finish_reason(None, 100, output)
result = serving_completion.calc_finish_reason(None, 100, output, False)
# Assert that the result is "tool_calls"
assert result == "tool_calls"

@@ -31,9 +31,9 @@ class TestOpenAIServingCompletion(unittest.TestCase):
# Create an OpenAIServingCompletion instance
serving_completion = OpenAIServingCompletion(engine_client, "pid", "ips", 360)
# Create a mock output with finish_reason set to some other value
output = {"finish_reason": "other_reason"}
output = {"tool_call": False}
# Call the calc_finish_reason method
result = serving_completion.calc_finish_reason(None, 100, output)
result = serving_completion.calc_finish_reason(None, 100, output, False)
# Assert that the result is "stop"
assert result == "stop"

@@ -45,7 +45,7 @@ class TestOpenAIServingCompletion(unittest.TestCase):
# Create a mock output
output = {}
# Call the calc_finish_reason method
result = serving_completion.calc_finish_reason(100, 100, output)
result = serving_completion.calc_finish_reason(100, 100, output, False)
# Assert that the result is "length"
assert result == "length"
@@ -15,7 +15,8 @@ class TestErnieProcessorProcessResponseDictStreaming(unittest.TestCase):
self.processor.tokenizer = MagicMock()
self.processor.tokenizer.eos_token_id = 1
self.processor.decode_status = {}
self.processor.tool_parsers = {}
self.processor.reasoning_end_dict = {}
self.processor.tool_parser_dict = {}

# Mock the ids2tokens method
def mock_ids2tokens(token_ids, task_id):
@@ -31,7 +32,7 @@ class TestErnieProcessorProcessResponseDictStreaming(unittest.TestCase):

# Mock the tool parser
self.mock_tool_parser = MagicMock()
self.mock_tool_parser.extract_tool_calls_streaming.return_value = "tool_call"
self.mock_tool_parser.extract_tool_calls_streaming.return_value = None
self.mock_tool_parser_obj = MagicMock()
self.mock_tool_parser_obj.return_value = self.mock_tool_parser
self.processor.tool_parser_obj = self.mock_tool_parser_obj
tests/utils/test_custom_chat_template.py (new file, 216 lines)
@@ -0,0 +1,216 @@
import os
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, mock_open, patch

from fastdeploy.engine.request import Request
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.chat_utils import load_chat_template
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
from fastdeploy.entrypoints.openai.serving_chat import OpenAIServingChat
from fastdeploy.input.ernie_processor import ErnieProcessor
from fastdeploy.input.ernie_vl_processor import ErnieMoEVLProcessor
from fastdeploy.input.text_processor import DataProcessor


class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):

def setUp(self):
"""
Set up the test environment by creating an instance of the LLM class using Mock.
"""
self.input_chat_template = "unit test \n"
self.mock_engine = MagicMock()
self.tokenizer = MagicMock()

def test_load_chat_template_non(self):
result = load_chat_template(None)
self.assertEqual(None, result)

def test_load_chat_template_str(self):
result = load_chat_template(self.input_chat_template)
self.assertEqual(self.input_chat_template, result)

def test_load_chat_template_path(self):
with open("chat_template", "w", encoding="utf-8") as file:
file.write(self.input_chat_template)
file_path = os.path.join(os.getcwd(), "chat_template")
result = load_chat_template(file_path)
os.remove(file_path)
self.assertEqual(self.input_chat_template, result)

def test_load_chat_template_non_str_and_path(self):
with self.assertRaises(ValueError):
load_chat_template("unit test")

def test_path_with_literal_true(self):
with self.assertRaises(TypeError):
load_chat_template(Path("./chat_template"), is_literal=True)

def test_path_object_file_error(self):
with patch("builtins.open", mock_open()) as mock_file:
mock_file.side_effect = OSError("File error")
with self.assertRaises(OSError):
load_chat_template(Path("./chat_template"))
async def test_serving_chat(self):
request = ChatCompletionRequest(messages=[{"role": "user", "content": "你好"}])
self.chat_completion_handler = OpenAIServingChat(
self.mock_engine,
models=None,
pid=123,
ips=None,
max_waiting_time=-1,
chat_template=self.input_chat_template,
)

async def mock_chat_completion_full_generator(
request, request_id, model_name, prompt_token_ids, text_after_process
):
return prompt_token_ids

def mock_format_and_add_data(current_req_dict):
return current_req_dict

self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
self.chat_completion_handler.engine_client.format_and_add_data = mock_format_and_add_data
self.chat_completion_handler.engine_client.semaphore = AsyncMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=None)
self.chat_completion_handler.engine_client.semaphore.status = MagicMock(return_value="mock_status")
chat_completion = await self.chat_completion_handler.create_chat_completion(request)
self.assertEqual(self.input_chat_template, chat_completion["chat_template"])
async def test_serving_chat_cus(self):
request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], chat_template="hello")
self.chat_completion_handler = OpenAIServingChat(
self.mock_engine,
models=None,
pid=123,
ips=None,
max_waiting_time=10,
chat_template=self.input_chat_template,
)

async def mock_chat_completion_full_generator(
request, request_id, model_name, prompt_token_ids, text_after_process
):
return prompt_token_ids

def mock_format_and_add_data(current_req_dict):
return current_req_dict

self.chat_completion_handler.chat_completion_full_generator = mock_chat_completion_full_generator
self.chat_completion_handler.engine_client.format_and_add_data = mock_format_and_add_data
self.chat_completion_handler.engine_client.semaphore = AsyncMock()
self.chat_completion_handler.engine_client.semaphore.acquire = AsyncMock(return_value=None)
self.chat_completion_handler.engine_client.semaphore.status = MagicMock(return_value="mock_status")
chat_completion = await self.chat_completion_handler.create_chat_completion(request)
self.assertEqual("hello", chat_completion["chat_template"])

@patch("fastdeploy.input.ernie_vl_processor.ErnieMoEVLProcessor.__init__")
def test_vl_processor(self, mock_class):
mock_class.return_value = None
vl_processor = ErnieMoEVLProcessor()
mock_request = Request.from_dict({"request_id": "123"})

def mock_apply_default_parameters(request):
return request

def mock_process_request(request, max_model_len):
return request

vl_processor._apply_default_parameters = mock_apply_default_parameters
vl_processor.process_request_dict = mock_process_request
result = vl_processor.process_request(mock_request, chat_template="hello")
self.assertEqual("hello", result.chat_template)

@patch("fastdeploy.input.text_processor.DataProcessor.__init__")
def test_text_processor_process_request(self, mock_class):
mock_class.return_value = None
text_processor = DataProcessor()
mock_request = Request.from_dict(
{"request_id": "123", "prompt": "hi", "max_tokens": 128, "temperature": 1, "top_p": 1}
)

def mock_apply_default_parameters(request):
return request

def mock_process_request(request, max_model_len):
return request

def mock_text2ids(text, max_model_len):
return [1]

text_processor._apply_default_parameters = mock_apply_default_parameters
text_processor.process_request_dict = mock_process_request
text_processor.text2ids = mock_text2ids
text_processor.eos_token_ids = [1]
result = text_processor.process_request(mock_request, chat_template="hello")
self.assertEqual("hello", result.chat_template)
@patch("fastdeploy.input.ernie_processor.ErnieProcessor.__init__")
|
||||
def test_ernie_processor_process(self, mock_class):
|
||||
mock_class.return_value = None
|
||||
ernie_processor = ErnieProcessor()
|
||||
mock_request = Request.from_dict(
|
||||
{"request_id": "123", "messages": ["hi"], "max_tokens": 128, "temperature": 1, "top_p": 1}
|
||||
)
|
||||
|
||||
def mock_apply_default_parameters(request):
|
||||
return request
|
||||
|
||||
def mock_process_request(request, max_model_len):
|
||||
return request
|
||||
|
||||
def mock_messages2ids(text):
|
||||
return [1]
|
||||
|
||||
ernie_processor._apply_default_parameters = mock_apply_default_parameters
|
||||
ernie_processor.process_request_dict = mock_process_request
|
||||
ernie_processor.messages2ids = mock_messages2ids
|
||||
ernie_processor.eos_token_ids = [1]
|
||||
ernie_processor.reasoning_parser = MagicMock()
|
||||
result = ernie_processor.process_request(mock_request, chat_template="hello")
|
||||
self.assertEqual("hello", result.chat_template)
|
||||
|
||||
@patch("fastdeploy.entrypoints.llm.LLM.__init__")
|
||||
def test_llm_load(self, mock_class):
|
||||
mock_class.return_value = None
|
||||
llm = LLM()
|
||||
llm.llm_engine = MagicMock()
|
||||
llm.default_sampling_params = MagicMock()
|
||||
llm.chat_template = "hello"
|
||||
|
||||
def mock_run_engine(req_ids, **kwargs):
|
||||
return req_ids
|
||||
|
||||
def mock_add_request(**kwargs):
|
||||
return kwargs.get("chat_template")
|
||||
|
||||
llm._run_engine = mock_run_engine
|
||||
llm._add_request = mock_add_request
|
||||
result = llm.chat(["hello"], sampling_params=SamplingParams(1))
|
||||
self.assertEqual("hello", result)
|
||||
|
||||
@patch("fastdeploy.entrypoints.llm.LLM.__init__")
|
||||
def test_llm(self, mock_class):
|
||||
mock_class.return_value = None
|
||||
llm = LLM()
|
||||
llm.llm_engine = MagicMock()
|
||||
llm.default_sampling_params = MagicMock()
|
||||
|
||||
def mock_run_engine(req_ids, **kwargs):
|
||||
return req_ids
|
||||
|
||||
def mock_add_request(**kwargs):
|
||||
return kwargs.get("chat_template")
|
||||
|
||||
llm._run_engine = mock_run_engine
|
||||
llm._add_request = mock_add_request
|
||||
result = llm.chat(["hello"], sampling_params=SamplingParams(1), chat_template="hello")
|
||||
self.assertEqual("hello", result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|