add tool parser

This commit is contained in:
luukunn
2025-08-13 01:06:55 +08:00
committed by luukunn
parent 132a8ef425
commit bbd50c6717
23 changed files with 1050 additions and 32 deletions

View File

@@ -43,13 +43,14 @@ class ErnieProcessor(BaseDataProcessor):
pad_token_id (int): 存储填充符号的token ID。
"""
def __init__(self, model_name_or_path, reasoning_parser_obj=None):
def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
self.model_name_or_path = model_name_or_path
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
self._init_config()
self.decode_status = dict()
self.tool_parsers = dict()
self.thinking_parser_dict = dict()
self._load_tokenizer()
data_processor_logger.info(
@@ -63,6 +64,7 @@ class ErnieProcessor(BaseDataProcessor):
self.reasoning_parser = None
if reasoning_parser_obj:
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tool_parser_obj = tool_parser_obj
def _init_config(self):
self.use_hf_tokenizer = int(envs.FD_USE_HF_TOKENIZER) == 1
@@ -204,6 +206,12 @@ class ErnieProcessor(BaseDataProcessor):
response_dict.outputs.reasoning_content = reasoning_content
else:
response_dict.outputs.text = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict.outputs.tool_calls = tool_call_info.tool_calls
response_dict.outputs.text = tool_call_info.content
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
return None
@@ -244,12 +252,20 @@ class ErnieProcessor(BaseDataProcessor):
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
if is_end:
full_text = previous_texts + delta_text
if enable_thinking and self.reasoning_parser:
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
response_dict["outputs"]["text"] = tool_call_info.content
response_dict["outputs"]["raw_prediction"] = full_text
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
@@ -275,7 +291,9 @@ class ErnieProcessor(BaseDataProcessor):
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
response_dict["outputs"]["raw_prediction"] = delta_text
if enable_thinking and self.reasoning_parser:
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
@@ -288,10 +306,25 @@ class ErnieProcessor(BaseDataProcessor):
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = delta_text
response_dict["outputs"]["raw_prediction"] = delta_text
if self.tool_parser_obj:
if req_id not in self.tool_parsers:
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parsers[req_id]
tool_call = tool_parser.extract_tool_calls_streaming(
previous_texts,
previous_texts + delta_text,
delta_text,
previous_token_ids,
previous_token_ids + token_ids,
token_ids,
response_dict,
)
response_dict["outputs"]["tool_delta_message"] = tool_call
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.tool_parsers:
del self.tool_parsers[req_id]
return response_dict
def messages2ids(self, request_or_messages):