add Tool Parser (#3272)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* add tool-parser

* add tool-parser

* add tool parser

* add tool parser

* fix

* add offline

* add offline

* fix

* parsers:tool&reasoning

* 修改tool parser名称·

* update

* fix reasoning-parser

* add requirements

* fix finish reason

* fix

* fix reasoning-parser

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: zhuzixuan <zhuzixuan@baidu.com>
This commit is contained in:
luukunn
2025-08-13 01:06:55 +08:00
committed by GitHub
parent 2d1a4cacdf
commit eda83ca672
23 changed files with 1056 additions and 38 deletions

View File

@@ -42,7 +42,7 @@ class ErnieProcessor(BaseDataProcessor):
pad_token_id (int): 存储填充符号的token ID。
"""
def __init__(self, model_name_or_path, reasoning_parser_obj=None):
def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
self.model_name_or_path = model_name_or_path
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
@@ -58,6 +58,7 @@ class ErnieProcessor(BaseDataProcessor):
self.generation_config = None
self.decode_status = dict()
self.tool_parsers = dict()
self.thinking_parser_dict = dict()
self._load_tokenizer()
data_processor_logger.info(
@@ -71,6 +72,7 @@ class ErnieProcessor(BaseDataProcessor):
self.eos_token_id_len = len(self.eos_token_ids)
self.pad_token_id = self.get_pad_id()
self.reasoning_parser = None
self.tool_parser_obj = tool_parser_obj
if reasoning_parser_obj:
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
@@ -201,6 +203,12 @@ class ErnieProcessor(BaseDataProcessor):
response_dict.outputs.reasoning_content = reasoning_content
else:
response_dict.outputs.text = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict.outputs.tool_calls = tool_call_info.tool_calls
response_dict.outputs.text = tool_call_info.content
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
return None
@@ -241,12 +249,20 @@ class ErnieProcessor(BaseDataProcessor):
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
if is_end:
full_text = previous_texts + delta_text
if enable_thinking and self.reasoning_parser:
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = full_text
if self.tool_parser_obj:
tool_parser = self.tool_parser_obj(self.tokenizer)
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
if tool_call_info.tools_called:
response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
response_dict["outputs"]["text"] = tool_call_info.content
response_dict["outputs"]["raw_prediction"] = full_text
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
@@ -271,7 +287,9 @@ class ErnieProcessor(BaseDataProcessor):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
if enable_thinking and self.reasoning_parser:
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
@@ -284,10 +302,25 @@ class ErnieProcessor(BaseDataProcessor):
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = delta_text
response_dict["outputs"]["raw_prediction"] = delta_text
if self.tool_parser_obj:
if req_id not in self.tool_parsers:
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parsers[req_id]
tool_call = tool_parser.extract_tool_calls_streaming(
previous_texts,
previous_texts + delta_text,
delta_text,
previous_token_ids,
previous_token_ids + token_ids,
token_ids,
response_dict,
)
response_dict["outputs"]["tool_delta_message"] = tool_call
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.tool_parsers:
del self.tool_parsers[req_id]
return response_dict
def messages2ids(self, request_or_messages):