mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] add tool parser (#3518)
* [Feature] Pass through the `chat_template_kwargs` to the data processing module (#3421)
* fix chat_template_args
* fix args
* add offline
* add offline
* fix
* fix
* fix default enable_thinking value
* fix default enable_thinking value
* modify condition
* Revert "modify condition"
This reverts commit 26430bdeb1
.
* fix unit test
* add Tool Parser (#3272)
* add tool-parser
* add tool-parser
* add tool parser
* add tool parser
* fix
* add offline
* add offline
* fix
* parsers:tool&reasoning
* 修改tool parser名称·
* update
* fix reasoning-parser
* add requirements
* fix finish reason
* fix
* fix reasoning-parser
* fix
* fix
* fix
* fix
* fix
---------
Co-authored-by: zhuzixuan <zhuzixuan@baidu.com>
* [Feature] add tool parser (#3483)
* add tool parser
* add x1 enable_thinking
* restart ci
* fix vl reasoning parser
* modify call style
* modify call style
* add offline enablethinking
* fix completion
* fix
* fix unit test
* fix unit test
* fix unit test
* fix vl reasoning parser
* fix vl reasoning parser
* fix unit test
---------
Co-authored-by: zhuzixuan <zhuzixuan@baidu.com>
This commit is contained in:
@@ -43,13 +43,14 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
pad_token_id (int): 存储填充符号的token ID。
|
||||
"""
|
||||
|
||||
def __init__(self, model_name_or_path, reasoning_parser_obj=None):
|
||||
def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_obj=None):
|
||||
|
||||
self.model_name_or_path = model_name_or_path
|
||||
data_processor_logger.info(f"model_name_or_path: {model_name_or_path}")
|
||||
self._init_config()
|
||||
|
||||
self.decode_status = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.thinking_parser_dict = dict()
|
||||
self._load_tokenizer()
|
||||
data_processor_logger.info(
|
||||
@@ -61,6 +62,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
self.eos_token_id_len = len(self.eos_token_ids)
|
||||
self.pad_token_id = self.get_pad_id()
|
||||
self.reasoning_parser = None
|
||||
self.tool_parser_obj = tool_parser_obj
|
||||
if reasoning_parser_obj:
|
||||
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
|
||||
|
||||
@@ -133,6 +135,8 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
request.set("temperature", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request.enable_thinking = True
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
return request
|
||||
|
||||
@@ -194,6 +198,8 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
request["temperature"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request["enable_thinking"] = True
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
return request
|
||||
@@ -221,6 +227,12 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
response_dict.outputs.reasoning_content = reasoning_content
|
||||
else:
|
||||
response_dict.outputs.text = full_text
|
||||
if self.tool_parser_obj:
|
||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||
if tool_call_info.tools_called:
|
||||
response_dict.outputs.tool_calls = tool_call_info.tool_calls
|
||||
response_dict.outputs.text = tool_call_info.content
|
||||
data_processor_logger.info(f"req_id:{req_id}, token)ids: {token_ids}")
|
||||
if response_dict.outputs.text == "" and response_dict.outputs.reasoning_content == "":
|
||||
return None
|
||||
@@ -261,12 +273,20 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
if is_end:
|
||||
full_text = previous_texts + delta_text
|
||||
if enable_thinking and self.reasoning_parser:
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content(full_text, response_dict)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["text"] = full_text
|
||||
if self.tool_parser_obj:
|
||||
tool_parser = self.tool_parser_obj(self.tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(full_text, response_dict)
|
||||
if tool_call_info.tools_called:
|
||||
response_dict["outputs"]["tool_call"] = tool_call_info.tool_calls
|
||||
response_dict["outputs"]["text"] = tool_call_info.content
|
||||
response_dict["outputs"]["raw_prediction"] = full_text
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
@@ -292,8 +312,10 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
response_dict["outputs"]["raw_prediction"] = delta_text
|
||||
if enable_thinking and self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
@@ -301,14 +323,28 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
response_dict["outputs"]["raw_prediction"] = delta_text
|
||||
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
||||
if self.tool_parser_obj:
|
||||
if req_id not in self.tool_parser_dict:
|
||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parser_dict[req_id]
|
||||
tool_call_delta_message = tool_parser.extract_tool_calls_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
response_dict,
|
||||
)
|
||||
if tool_call_delta_message is None or tool_call_delta_message.tool_calls:
|
||||
response_dict["outputs"]["delta_message"] = tool_call_delta_message
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
if is_end:
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.tool_parser_dict:
|
||||
del self.tool_parser_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def messages2ids(self, request_or_messages):
|
||||
|
Reference in New Issue
Block a user