add tool parser

This commit is contained in:
luukunn
2025-08-13 01:06:55 +08:00
committed by luukunn
parent ad816f20f4
commit 81092c0fe3
23 changed files with 1056 additions and 41 deletions

View File

@@ -43,7 +43,7 @@ class ErnieProcessor(BaseDataProcessor):
pad_token_id (int): 存储填充符号的token ID。
"""
def __init__(self, model_name_or_path, reasoning_parser_obj=None):
def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
self.model_name_or_path = model_name_or_path
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
@@ -63,6 +63,7 @@ class ErnieProcessor(BaseDataProcessor):
self.reasoning_parser = None
if reasoning_parser_obj:
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tool_parser_obj = tool_parser_obj
def _init_config(self):
self.use_hf_tokenizer = int(envs.FD_USE_HF_TOKENIZER) == 1
@@ -204,6 +205,12 @@ class ErnieProcessor(BaseDataProcessor):
response_dict.outputs.reasoning_content = reasoning_content
else:
response_dict.outputs.text = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict.outputs.tool_calls = tool_call_info.tool_calls
response_dict.outputs.text = tool_call_info.content
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
return None
@@ -244,12 +251,20 @@ class ErnieProcessor(BaseDataProcessor):
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
if is_end:
full_text = previous_texts + delta_text
if enable_thinking and self.reasoning_parser:
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
response_dict["outputs"]["text"] = tool_call_info.content
response_dict["outputs"]["raw_prediction"] = full_text
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
@@ -274,7 +289,9 @@ class ErnieProcessor(BaseDataProcessor):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
if enable_thinking and self.reasoning_parser:
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
@@ -287,10 +304,25 @@ class ErnieProcessor(BaseDataProcessor):
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = delta_text
response_dict["outputs"]["raw_prediction"] = delta_text
if self.tool_parser_obj:
if req_id not in self.tool_parsers:
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parsers[req_id]
tool_call = tool_parser.extract_tool_calls_streaming(
previous_texts,
previous_texts + delta_text,
delta_text,
previous_token_ids,
previous_token_ids + token_ids,
token_ids,
response_dict,
)
response_dict["outputs"]["tool_delta_message"] = tool_call
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.tool_parsers:
del self.tool_parsers[req_id]
return response_dict
def messages2ids(self, request_or_messages):

View File

@@ -34,6 +34,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
limit_mm_per_prompt=None,
mm_processor_kwargs=None,
reasoning_parser_obj=None,
tool_parser_obj=None,
):
self.use_hf_tokenizer = False
@@ -53,6 +54,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
self.image_patch_id = self.ernie_processor.image_patch_id
self.spatial_conv_size = self.ernie_processor.spatial_conv_size
self.tool_parsers = dict()
self.decode_status = dict()
self._load_tokenizer()
self.eos_token_ids = [self.tokenizer.eos_token_id]
@@ -62,6 +64,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
self.reasoning_parser = None
if reasoning_parser_obj:
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tool_parser_obj = tool_parser_obj
# Generation config
try:

View File

@@ -18,6 +18,7 @@ from typing import Any, Dict, Optional
from fastdeploy.config import ErnieArchitectures
from fastdeploy.engine.config import ModelConfig
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.reasoning import ReasoningParserManager
@@ -48,6 +49,7 @@ class InputPreprocessor:
limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
enable_mm: bool = False,
tool_parser: str = None,
) -> None:
self.model_name_or_path = model_name_or_path
@@ -55,6 +57,7 @@ class InputPreprocessor:
self.enable_mm = enable_mm
self.limit_mm_per_prompt = limit_mm_per_prompt
self.mm_processor_kwargs = mm_processor_kwargs
self.tool_parser = tool_parser
def create_processor(self):
"""
@@ -68,8 +71,11 @@ class InputPreprocessor:
DataProcessor or MultiModalRegistry.Processor (Union[DataProcessor, MultiModalRegistry.Processor]): 数据处理器。
"""
reasoning_parser_obj = None
tool_parser_obj = None
if self.reasoning_parser:
reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(self.reasoning_parser)
if self.tool_parser:
tool_parser_obj = ToolParserManager.get_tool_parser(self.tool_parser)
architectures = ModelConfig({"model": self.model_name_or_path}).architectures[0]
if not self.enable_mm:
if not ErnieArchitectures.contains_ernie_arch(architectures):
@@ -78,6 +84,7 @@ class InputPreprocessor:
self.processor = DataProcessor(
model_name_or_path=self.model_name_or_path,
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
)
else:
from fastdeploy.input.ernie_processor import ErnieProcessor
@@ -85,6 +92,7 @@ class InputPreprocessor:
self.processor = ErnieProcessor(
model_name_or_path=self.model_name_or_path,
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
)
else:
if not architectures.startswith("Ernie4_5_VLMoeForConditionalGeneration"):
@@ -97,5 +105,6 @@ class InputPreprocessor:
limit_mm_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
reasoning_parser_obj=reasoning_parser_obj,
tool_parser_obj=tool_parser_obj,
)
return self.processor

View File

@@ -148,7 +148,7 @@ class BaseDataProcessor(ABC):
class DataProcessor(BaseDataProcessor):
def __init__(self, model_name_or_path, reasoning_parser_obj=None):
def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
"""
Initializes the DecodeStatus object.
@@ -168,6 +168,7 @@ class DataProcessor(BaseDataProcessor):
self._init_config()
self.decode_status = dict()
self.tool_parsers = dict()
self.tokenizer = self._load_tokenizer()
data_processor_logger.info(
f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
@@ -180,6 +181,7 @@ class DataProcessor(BaseDataProcessor):
self.eos_token_id_len = len(self.eos_token_ids)
self.pad_token_id = self.get_pad_id()
self.reasoning_parser = None
self.tool_parser_obj = tool_parser_obj
if reasoning_parser_obj:
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tokenizer.pad_token_id = self.pad_token_id
@@ -329,6 +331,12 @@ class DataProcessor(BaseDataProcessor):
else:
# 模型不支持思考,并且没单独设置enable_thinking为false
response_dict.outputs.text = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict.outputs.tool_calls = tool_call_info.tool_calls
response_dict.outputs.text = tool_call_info.content
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
return response_dict
@@ -360,6 +368,12 @@ class DataProcessor(BaseDataProcessor):
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
response_dict["outputs"]["text"] = tool_call_info.content
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
return response_dict
@@ -397,9 +411,25 @@ class DataProcessor(BaseDataProcessor):
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = delta_text
if self.tool_parser_obj and not is_end:
if req_id not in self.tool_parsers:
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parsers[req_id]
tool_call = tool_parser.extract_tool_calls_streaming(
previous_texts,
previous_texts + delta_text,
delta_text,
previous_token_ids,
previous_token_ids + token_ids,
token_ids,
response_dict,
)
response_dict["outputs"]["tool_delta_message"] = tool_call
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.tool_parsers:
del self.tool_parsers[req_id]
return response_dict
def process_response_dict(self, response_dict, **kwargs):