add tool parser

luukunn
2025-08-13 01:06:55 +08:00
committed by luukunn
parent 132a8ef425
commit bbd50c6717
23 changed files with 1050 additions and 32 deletions


@@ -148,7 +148,7 @@ class BaseDataProcessor(ABC):
 class DataProcessor(BaseDataProcessor):
-    def __init__(self, model_name_or_path, reasoning_parser_obj=None):
+    def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
"""
Initializes the DecodeStatus object.
@@ -168,6 +168,7 @@ class DataProcessor(BaseDataProcessor):
         self._init_config()
         self.decode_status = dict()
+        self.tool_parsers = dict()
         self.tokenizer = self._load_tokenizer()
         data_processor_logger.info(
             f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
@@ -180,6 +181,7 @@ class DataProcessor(BaseDataProcessor):
         self.eos_token_id_len = len(self.eos_token_ids)
         self.pad_token_id = self.get_pad_id()
         self.reasoning_parser = None
+        self.tool_parser_obj = tool_parser_obj
         if reasoning_parser_obj:
             self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
         self.tokenizer.pad_token_id = self.pad_token_id
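
The constructor now takes an optional tool parser class next to the reasoning parser, and prepares a per-request parser cache (self.tool_parsers). The hunks here do not define the parser's interface, but the call sites below imply roughly the following shape. A minimal sketch, with ToolCallInfo and all member names inferred from usage in this diff rather than taken from the project:

    from dataclasses import dataclass, field

    @dataclass
    class ToolCallInfo:
        # Result container the processor reads after extract_tool_calls().
        tools_called: bool = False
        tool_calls: list = field(default_factory=list)
        content: str = ""

    class ToolParser:
        # Instantiated with the processor's tokenizer, per call or per request.
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer

        def extract_tool_calls(self, full_text, response_dict) -> ToolCallInfo:
            # Non-streaming: parse the finished generation in one pass.
            raise NotImplementedError

        def extract_tool_calls_streaming(
            self,
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            response_dict,
        ):
            # Streaming: parse only the newest delta against accumulated state.
            raise NotImplementedError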
@@ -329,6 +331,12 @@ class DataProcessor(BaseDataProcessor):
         else:
             # The model does not support thinking, and enable_thinking was not separately set to false
             response_dict.outputs.text = full_text
+        if self.tool_parser_obj:
+            tool_parser = self.tool_parser_obj(self.tokenizer)
+            tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+            if tool_call_info.tools_called:
+                response_dict.outputs.tool_calls = tool_call_info.tool_calls
+                response_dict.outputs.text = tool_call_info.content
         data_processor_logger.info(f"req_id:{req_id}, token_ids: {token_ids}")
         return response_dict
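
On the non-streaming path the parser sees the full generated text, and only when tools_called is set does the processor move the tool calls onto the response and replace the text with the remaining content. Purely as an illustration of that contract, a parser for a hypothetical <tool_call>{...}</tool_call> wire format might look like this (the formats actually supported by this commit live in the parser files elsewhere in the diff):

    import json
    import re

    class JsonTagToolParser(ToolParser):
        # Hypothetical format: JSON objects wrapped in <tool_call> ... </tool_call>.
        TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)

        def extract_tool_calls(self, full_text, response_dict) -> ToolCallInfo:
            matches = self.TOOL_CALL_RE.findall(full_text)
            if not matches:
                # No tool calls: leave the generated text untouched.
                return ToolCallInfo(tools_called=False, content=full_text)
            tool_calls = [json.loads(m) for m in matches]
            # Strip the markup so only the plain reply remains as content.
            content = self.TOOL_CALL_RE.sub("", full_text).strip()
            return ToolCallInfo(tools_called=True, tool_calls=tool_calls, content=content)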
@@ -360,6 +368,12 @@ class DataProcessor(BaseDataProcessor):
             response_dict["outputs"]["reasoning_content"] = reasoning_content
         else:
             response_dict["outputs"]["text"] = full_text
+        if self.tool_parser_obj:
+            tool_parser = self.tool_parser_obj(self.tokenizer)
+            tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
+            if tool_call_info.tools_called:
+                response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
+                response_dict["outputs"]["text"] = tool_call_info.content
         data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
         del self.decode_status[req_id]
         return response_dict
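
Note that tool_parser_obj is stored and passed around as a class, not an instance: the non-streaming paths build a fresh parser per response, and the streaming path below caches one per req_id, so parser state cannot leak across requests. Hypothetical wiring, reusing the sketch classes above:

    # Pass the parser class itself; the processor instantiates it as needed.
    processor = DataProcessor(
        "path/to/model",
        reasoning_parser_obj=None,
        tool_parser_obj=JsonTagToolParser,
    )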
@@ -397,9 +411,25 @@ class DataProcessor(BaseDataProcessor):
             response_dict["outputs"]["reasoning_content"] = reasoning_content
         else:
             response_dict["outputs"]["text"] = delta_text
+        if self.tool_parser_obj and not is_end:
+            if req_id not in self.tool_parsers:
+                self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
+            tool_parser = self.tool_parsers[req_id]
+            tool_call = tool_parser.extract_tool_calls_streaming(
+                previous_texts,
+                previous_texts + delta_text,
+                delta_text,
+                previous_token_ids,
+                previous_token_ids + token_ids,
+                token_ids,
+                response_dict,
+            )
+            response_dict["outputs"]["tool_delta_message"] = tool_call
         if is_end:
             data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
             del self.decode_status[req_id]
+            if req_id in self.tool_parsers:
+                del self.tool_parsers[req_id]
         return response_dict
 
     def process_response_dict(self, response_dict, **kwargs):
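
Streaming extraction is stateful, which is why a parser instance is kept in self.tool_parsers under its req_id and deleted together with the decode status once is_end fires. Each call receives the previous text and token ids, the accumulated ones, and the latest delta, and returns a partial tool-call message (or nothing). A rough sketch of how such a hook could behave, continuing the hypothetical JsonTagToolParser from above:

    import json

    class JsonTagToolParser(ToolParser):
        def extract_tool_calls_streaming(
            self,
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            response_dict,
        ):
            # Stay silent until an opening tag has fully streamed in.
            if "<tool_call>" not in current_text:
                return None
            if "</tool_call>" in previous_text:
                # The call was already closed and emitted on an earlier delta.
                return None
            start = current_text.index("<tool_call>") + len("<tool_call>")
            end = current_text.find("</tool_call>", start)
            if end == -1:
                # Call still in flight: forward only the newest fragment.
                return {"arguments_delta": delta_text}
            # Closing tag just arrived: the buffered JSON is now complete.
            return {"tool_call": json.loads(current_text[start:end])}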