mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-26 20:41:53 +08:00
[Feature] add tool parser (#3483)
* add tool parser * add x1 enable_thinking * restart ci * fix vl reasoning parser * modify call style * modify call style * add offline enablethinking * fix completion * fix * fix unit test * fix unit test * fix unit test * fix vl reasoning parser * fix vl reasoning parser
This commit is contained in:
@@ -49,6 +49,8 @@ When using FastDeploy to deploy models (including offline inference and service
|
||||
| ```served_model_name```| `str`| The model name used in the API. If not specified, the model name will be the same as the --model argument |
|
||||
| ```revision``` | `str` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. |
|
||||
| ```chat_template``` | `str` | Specify the template used for model concatenation, It supports both string input and file path input. The default value is None. If not specified, the model's default template will be used. |
|
||||
| ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. |
|
||||
| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. |
|
||||
|
||||
## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
|
||||
|
||||
|
@@ -47,6 +47,8 @@
|
||||
| ```served_model_name``` | `str` | API 中使用的模型名称,如果未指定,模型名称将与--model参数相同 |
|
||||
| ```revision``` | `str` | 自动下载模型时,用于指定模型的Git版本,分支名或tag |
|
||||
| ```chat_template``` | `str` | 指定模型拼接使用的模板,支持字符串与文件路径,默认为None,如未指定,则使用模型默认模板 |
|
||||
| ```tool_call_parser``` | `str` | 指定要使用的function call解析器,以便从模型输出中抽取 function call内容|
|
||||
| ```tool_parser_plugin``` | `str` | 指定要注册的tool parser文件路径,以便注册不在代码库中的parser,parser中代码格式需遵循代码库中格式|
|
||||
|
||||
## 1. KVCache分配与```num_gpu_blocks_override```、```block_size```的关系?
|
||||
|
||||
|
@@ -279,22 +279,20 @@ class OpenAIServingChat:
|
||||
output_top_logprobs, request.logprobs, request.top_logprobs
|
||||
)
|
||||
|
||||
if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
|
||||
tool_delta_message = output["tool_delta_message"]
|
||||
if tool_delta_message is None:
|
||||
delta_message = DeltaMessage(
|
||||
content=delta_text,
|
||||
reasoning_content="",
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=None,
|
||||
tool_calls=None,
|
||||
)
|
||||
if not res["finished"] and "delta_message" in output:
|
||||
delta_message_output = output["delta_message"]
|
||||
if delta_message_output is None:
|
||||
continue
|
||||
delta_message = tool_delta_message
|
||||
delta_message.reasoning_content = output.get("reasoning_content")
|
||||
if delta_message.tool_calls:
|
||||
tool_called = True
|
||||
else:
|
||||
delta_message = DeltaMessage(
|
||||
content=delta_text,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=None,
|
||||
tool_calls=None,
|
||||
)
|
||||
delta_message.content = delta_message_output.content or ""
|
||||
delta_message.reasoning_content = delta_message_output.reasoning_content or ""
|
||||
delta_message.tool_calls = delta_message_output.tool_calls
|
||||
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
@@ -475,7 +473,7 @@ class OpenAIServingChat:
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
if has_no_token_limit or previous_num_tokens != max_tokens:
|
||||
choice.finish_reason = "stop"
|
||||
if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
|
||||
if output.get("tool_call"):
|
||||
choice.finish_reason = "tool_calls"
|
||||
else:
|
||||
choice.finish_reason = "length"
|
||||
|
@@ -312,7 +312,7 @@ class OpenAIServingCompletion:
|
||||
output_tokens = [0] * num_choices
|
||||
inference_start_time = [0] * num_choices
|
||||
first_iteration = [True] * num_choices
|
||||
tool_called = False
|
||||
tool_called = [False] * num_choices
|
||||
max_streaming_response_tokens = (
|
||||
request.max_streaming_response_tokens
|
||||
if request.max_streaming_response_tokens is not None
|
||||
@@ -386,41 +386,32 @@ class OpenAIServingCompletion:
|
||||
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
|
||||
|
||||
output_tokens[idx] += 1
|
||||
if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
|
||||
tool_delta_message = output["tool_delta_message"]
|
||||
if tool_delta_message is None:
|
||||
delta_message = CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
|
||||
tool_calls=None,
|
||||
raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
|
||||
completion_tokens=output.get("raw_prediction") if request.return_token_ids else None,
|
||||
reasoning_content="",
|
||||
arrival_time=arrival_time,
|
||||
logprobs=logprobs_res,
|
||||
)
|
||||
if not res["finished"] and "delta_message" in output:
|
||||
delta_message_output = output["delta_message"]
|
||||
if delta_message_output is None:
|
||||
continue
|
||||
delta_message = CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
|
||||
tool_calls=tool_delta_message.tool_calls,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
arrival_time=arrival_time,
|
||||
logprobs=logprobs_res,
|
||||
)
|
||||
if tool_delta_message.tool_calls:
|
||||
tool_called = True
|
||||
else:
|
||||
delta_message = CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
|
||||
tool_calls=None,
|
||||
raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
|
||||
completion_tokens=output.get("raw_prediction") if request.return_token_ids else None,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
arrival_time=arrival_time,
|
||||
logprobs=logprobs_res,
|
||||
)
|
||||
delta_message.text = delta_message_output.content or ""
|
||||
delta_message.reasoning_content = delta_message_output.reasoning_content or ""
|
||||
delta_message.tool_calls = delta_message_output.tool_calls
|
||||
|
||||
choices.append(delta_message)
|
||||
output_tokens[idx] += 1
|
||||
|
||||
if res["finished"]:
|
||||
choices[-1].finish_reason = self.calc_finish_reason(
|
||||
request.max_tokens, output_tokens[idx], output, tool_called
|
||||
request.max_tokens, output_tokens[idx], output, tool_called[idx]
|
||||
)
|
||||
send_idx = output.get("send_idx")
|
||||
# 只有当 send_idx 明确为 0 时才记录日志
|
||||
|
@@ -14,7 +14,6 @@
|
||||
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
@@ -58,6 +57,16 @@ class ErnieX1ToolParser(ToolParser):
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
|
||||
self.buffer: str = "" # buffer for accumulating unprocessed streaming content
|
||||
self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!"
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
@@ -163,12 +172,10 @@ class ErnieX1ToolParser(ToolParser):
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
data_processor_logger.error(
|
||||
f"Failed to parse tool call: {str(e)}, {str(traceback.format_exc())}"
|
||||
)
|
||||
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
|
||||
continue
|
||||
except Exception as e:
|
||||
data_processor_logger.error(f"Failed to parse tool call: {str(e)}, {str(traceback.format_exc())}")
|
||||
data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
|
||||
continue
|
||||
|
||||
if not function_call_arr:
|
||||
@@ -214,9 +221,7 @@ class ErnieX1ToolParser(ToolParser):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
data_processor_logger.error(
|
||||
f"Error in extracting tool call from response: {str(e)}, {str(traceback.format_exc())}"
|
||||
)
|
||||
data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
|
||||
return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
@@ -229,6 +234,9 @@ class ErnieX1ToolParser(ToolParser):
|
||||
delta_token_ids: Sequence[int],
|
||||
request: dict,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
return DeltaMessage(content=delta_text)
|
||||
# 忽略空chunk
|
||||
if len(delta_text.strip()) == 0:
|
||||
return None
|
||||
@@ -239,7 +247,7 @@ class ErnieX1ToolParser(ToolParser):
|
||||
self.buffer += delta_text
|
||||
|
||||
# 处理增量中的新tool_call开始
|
||||
if "<tool_call>" in delta_text and "<tool_call>" not in previous_text:
|
||||
if "<tool_call>" in delta_text:
|
||||
self.current_tool_id = (
|
||||
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
|
||||
)
|
||||
@@ -248,8 +256,6 @@ class ErnieX1ToolParser(ToolParser):
|
||||
self.streamed_args_for_tool.append("")
|
||||
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
|
||||
|
||||
# 增量解析逻辑
|
||||
|
||||
# 1. 尝试解析name字段
|
||||
if not self.current_tool_name_sent and '"name"' in self.buffer:
|
||||
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
|
||||
@@ -266,7 +272,6 @@ class ErnieX1ToolParser(ToolParser):
|
||||
)
|
||||
]
|
||||
)
|
||||
print("delta name:", delta)
|
||||
# 删除已处理的name部分
|
||||
self.buffer = self.buffer[name_match.end() :]
|
||||
self.current_tool_name_sent = True
|
||||
@@ -276,54 +281,67 @@ class ErnieX1ToolParser(ToolParser):
|
||||
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
|
||||
if args_match:
|
||||
args_content = args_match.group(1)
|
||||
# 处理多余的大括号
|
||||
open_braces = args_content.count("{")
|
||||
close_braces = args_content.count("}")
|
||||
if close_braces > open_braces:
|
||||
args_content = args_content[: args_content.rfind("}")]
|
||||
try:
|
||||
# 增量解析arguments
|
||||
parsed_args = json.loads(args_content)
|
||||
if isinstance(parsed_args, dict):
|
||||
args_json = json.dumps(parsed_args, ensure_ascii=False)
|
||||
if len(args_json) > len(self.streamed_args_for_tool[self.current_tool_id]):
|
||||
argument_diff = args_json[len(self.streamed_args_for_tool[self.current_tool_id]) :]
|
||||
# 检查是否到达arguments结尾(括号完全匹配)
|
||||
if "}}" in args_content:
|
||||
# 逐个字符检查括号匹配状态
|
||||
matched_pos = -1
|
||||
for i, ch in enumerate(delta_text):
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
|
||||
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配
|
||||
matched_pos = i
|
||||
break
|
||||
|
||||
if matched_pos >= 0:
|
||||
# 找到匹配点,清理buffer并返回
|
||||
truncate_text = delta_text[: matched_pos + 1]
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=argument_diff).model_dump(
|
||||
function=DeltaFunctionCall(arguments=truncate_text).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
print("delta argument:", delta)
|
||||
# 删除已处理部分
|
||||
processed_pos = args_match.start() + len('"arguments":')
|
||||
self.buffer = (
|
||||
self.buffer[:processed_pos] + self.buffer[processed_pos + len(args_json) :]
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] = args_json
|
||||
self.buffer = self.buffer[args_match.end() :]
|
||||
return delta
|
||||
else:
|
||||
# 没有完全匹配,继续累积
|
||||
return None
|
||||
else:
|
||||
# 增量返回当前可解析的部分
|
||||
for ch in delta_text:
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
return delta
|
||||
except Exception as e:
|
||||
data_processor_logger.error(
|
||||
f"Partial arguments parsing: {str(e)}, {str(traceback.format_exc())}"
|
||||
)
|
||||
|
||||
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
|
||||
return None
|
||||
if "</tool_call>" in self.buffer:
|
||||
end_pos = self.buffer.find("</tool_call>")
|
||||
self.buffer = self.buffer[end_pos + len("</tool_call>") :]
|
||||
|
||||
# 完成当前工具调用处理
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
return delta
|
||||
|
||||
except Exception as e:
|
||||
data_processor_logger.error(
|
||||
f"Error in streaming tool call extraction: {str(e)}, {str(traceback.format_exc())}"
|
||||
)
|
||||
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
|
||||
return None
|
||||
|
@@ -58,7 +58,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
self.generation_config = None
|
||||
|
||||
self.decode_status = dict()
|
||||
self.tool_parsers = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.thinking_parser_dict = dict()
|
||||
self._load_tokenizer()
|
||||
data_processor_logger.info(
|
||||
@@ -133,6 +133,8 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
request.set("temperature", 1)
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request.set("top_p", _SAMPLING_EPS)
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request.enable_thinking = True
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
return request
|
||||
|
||||
@@ -194,6 +196,8 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
request["temperature"] = 1
|
||||
if request.get("top_p") < _SAMPLING_EPS:
|
||||
request["top_p"] = _SAMPLING_EPS
|
||||
if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
|
||||
request["enable_thinking"] = True
|
||||
data_processor_logger.info(f"Processed request {request}")
|
||||
|
||||
return request
|
||||
@@ -309,7 +313,7 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
@@ -317,15 +321,12 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
||||
if self.tool_parser_obj:
|
||||
if req_id not in self.tool_parsers:
|
||||
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parsers[req_id]
|
||||
tool_call = tool_parser.extract_tool_calls_streaming(
|
||||
if req_id not in self.tool_parser_dict:
|
||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parser_dict[req_id]
|
||||
tool_call_delta_message = tool_parser.extract_tool_calls_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
@@ -334,12 +335,14 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
token_ids,
|
||||
response_dict,
|
||||
)
|
||||
response_dict["outputs"]["tool_delta_message"] = tool_call
|
||||
if tool_call_delta_message is None or tool_call_delta_message.tool_calls:
|
||||
response_dict["outputs"]["delta_message"] = tool_call_delta_message
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
if is_end:
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.tool_parsers:
|
||||
del self.tool_parsers[req_id]
|
||||
if req_id in self.tool_parser_dict:
|
||||
del self.tool_parser_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def messages2ids(self, request_or_messages):
|
||||
|
@@ -50,7 +50,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
|
||||
self.image_patch_id = self.ernie_processor.image_patch_id
|
||||
self.spatial_conv_size = self.ernie_processor.spatial_conv_size
|
||||
|
||||
self.tool_parsers = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.decode_status = dict()
|
||||
self._load_tokenizer()
|
||||
|
||||
|
@@ -175,7 +175,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
self.generation_config = None
|
||||
|
||||
self.decode_status = dict()
|
||||
self.tool_parsers = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
data_processor_logger.info(
|
||||
f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
|
||||
@@ -398,8 +398,10 @@ class DataProcessor(BaseDataProcessor):
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
response_dict["outputs"]["raw_prediction"] = delta_text
|
||||
if enable_thinking and self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
):
|
||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
@@ -407,14 +409,11 @@ class DataProcessor(BaseDataProcessor):
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
if self.tool_parser_obj and not is_end:
|
||||
if req_id not in self.tool_parsers:
|
||||
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parsers[req_id]
|
||||
response_dict["outputs"]["delta_message"] = reasoning_delta_message
|
||||
if self.tool_parser_obj:
|
||||
if req_id not in self.tool_parser_dict:
|
||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parser_dict[req_id]
|
||||
tool_call = tool_parser.extract_tool_calls_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
@@ -424,12 +423,14 @@ class DataProcessor(BaseDataProcessor):
|
||||
token_ids,
|
||||
response_dict,
|
||||
)
|
||||
response_dict["outputs"]["tool_delta_message"] = tool_call
|
||||
if tool_call is None or tool_call.tool_calls:
|
||||
response_dict["outputs"]["delta_message"] = tool_call
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
if is_end:
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.tool_parsers:
|
||||
del self.tool_parsers[req_id]
|
||||
if req_id in self.tool_parser_dict:
|
||||
del self.tool_parser_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def process_response_dict(self, response_dict, **kwargs):
|
||||
|
@@ -46,6 +46,9 @@ class ErnieVLReasoningParser(ReasoningParser):
|
||||
if self.think_end_token_id is None:
|
||||
raise RuntimeError("Ernie VL reasoning parser could not locate think end " "tokens in the tokenizer!")
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.think_end_token_id in input_ids
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
@@ -65,18 +68,16 @@ class ErnieVLReasoningParser(ReasoningParser):
|
||||
"""
|
||||
# Skip single special tokens
|
||||
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
||||
return "", ""
|
||||
return None
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
end_index = delta_text.find(self.end_token)
|
||||
reasoning_content = delta_text[:end_index]
|
||||
content = delta_text[end_index + len(self.end_token) :]
|
||||
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||
elif self.think_end_token_id in previous_token_ids:
|
||||
reasoning_content = ""
|
||||
content = delta_text
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
reasoning_content = delta_text
|
||||
content = ""
|
||||
return reasoning_content, content
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
@@ -95,7 +96,6 @@ class ErnieVLReasoningParser(ReasoningParser):
|
||||
# Check if the model output contains the </think> tokens.
|
||||
if self.think_end_token not in model_output:
|
||||
return "", model_output
|
||||
# Extract reasoning content from the model output.
|
||||
reasoning_content, _, content = model_output.partition(self.think_end_token)
|
||||
|
||||
final_content = content or ""
|
||||
|
@@ -2,9 +2,9 @@
|
||||
#
|
||||
#
|
||||
from collections.abc import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Union
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
#
|
||||
@@ -47,6 +47,10 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
||||
self.think_end_token_id = self.vocab.get("</think>")
|
||||
if self.think_end_token_id is None:
|
||||
raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
|
||||
self.tool_call_start_token_id = self.vocab.get("<tool_call>")
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.tool_call_start_token_id in input_ids
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
@@ -56,50 +60,63 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> tuple[str, str]:
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
根据用户需求实现的流式解析方法:
|
||||
1. 初始内容都视为思考内容
|
||||
1. 初始内容都视为思考内容,返回delta_text,""
|
||||
2. 当遇到\n时检查后续是否是</think>
|
||||
3. 思考结束后检查是<response>还是<tool_call>
|
||||
4. 对于<response>内容,处理换行和结束标记
|
||||
3. 如果直接遇到</think>也结束思考
|
||||
4. 思考结束后检查是<response>还是<tool_call>
|
||||
5. 对于<response>内容,处理各种边界条件
|
||||
"""
|
||||
# 如果还在思考阶段
|
||||
if not previous_text.endswith(self.think_end_token):
|
||||
# 如果遇到\n后接</think>或直接遇到</think>,思考结束
|
||||
if (previous_text.endswith("\n") and delta_text == self.think_end_token) or (
|
||||
not previous_text.endswith("\n") and delta_text == self.think_end_token
|
||||
):
|
||||
return "", ""
|
||||
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
||||
return None
|
||||
# 思考阶段处理
|
||||
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
|
||||
# 如果遇到\n,暂时不返回,等待下一个delta_text
|
||||
if delta_text == "\n":
|
||||
return None
|
||||
# 如果前一个是\n且当前是</think>,结束思考
|
||||
elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
|
||||
return None
|
||||
# 如果直接遇到</think>也结束思考
|
||||
elif delta_text.startswith(self.think_end_token):
|
||||
return None
|
||||
# 否则继续返回思考内容
|
||||
return delta_text, ""
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
# 思考结束后检查是tool_call还是response
|
||||
remaining_text = previous_text + delta_text
|
||||
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
|
||||
|
||||
# 跳过think后的换行
|
||||
after_think = after_think.lstrip("\n")
|
||||
after_think = after_think.lstrip("\n") # 跳过think后的换行
|
||||
|
||||
# 处理tool_call情况
|
||||
if after_think.startswith(self.tool_call_start_token):
|
||||
return "", ""
|
||||
return None
|
||||
|
||||
# 处理response情况
|
||||
if after_think.startswith(self.response_start_token):
|
||||
response_content = after_think[len(self.response_start_token) :]
|
||||
# 跳过response后的换行
|
||||
response_content = response_content.lstrip("\n")
|
||||
|
||||
# 检查response是否结束
|
||||
if response_content.endswith(self.response_end_token):
|
||||
return "", ""
|
||||
|
||||
# 返回response内容(使用delta_text确保流式输出)
|
||||
return "", delta_text
|
||||
# 遇到<response>标签时不立即返回
|
||||
if delta_text == self.response_start_token:
|
||||
return None
|
||||
# 遇到<response>后的换行符也不立即返回
|
||||
elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
|
||||
return None
|
||||
# 处理回复内容中的换行符
|
||||
if delta_text == "\n":
|
||||
return None
|
||||
# 如果前一个是\n且当前是</response>,结束回复
|
||||
elif previous_text.endswith("\n") and delta_text == self.response_end_token:
|
||||
return None
|
||||
# 如果直接遇到</response>也结束回复
|
||||
elif delta_text == self.response_end_token:
|
||||
return None
|
||||
# 其他情况返回实际内容
|
||||
else:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# 默认情况不返回内容
|
||||
return "", ""
|
||||
return None
|
||||
|
||||
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
|
||||
"""
|
||||
@@ -143,66 +160,3 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
||||
reasoning_content = model_output
|
||||
response_content = ""
|
||||
return reasoning_content, response_content
|
||||
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tokenizer = MagicMock()
|
||||
self.tokenizer.vocab = {
|
||||
"\n</think>\n\n": 1001,
|
||||
"<response>\n": 1002,
|
||||
"\n</response>\n": 1003,
|
||||
"<tool_call>\n": 1004,
|
||||
"\n</tool_call>\n": 1005,
|
||||
}
|
||||
self.parser = ErnieX1ReasoningParser(self.tokenizer)
|
||||
|
||||
def test_streaming_with_think_and_response(self):
|
||||
# 测试标准情况:\n</think>\n\n<response>\ncontent\n</response>\n
|
||||
prev_text = "thinking"
|
||||
delta_text = "\n</think>\n\n<response>\nanswer\n</response>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
|
||||
self.assertEqual(result, ("thinking", "answer"))
|
||||
|
||||
def test_streaming_with_think_and_tool_call(self):
|
||||
# 测试tool_call情况
|
||||
prev_text = "thinking"
|
||||
delta_text = "\n</think>\n\n<tool_call>\ndetails\n</tool_call>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
|
||||
self.assertEqual(result, ("thinking", ""))
|
||||
|
||||
def test_streaming_with_think_no_newline(self):
|
||||
# 测试没有前置换行的情况
|
||||
prev_text = "thinking"
|
||||
delta_text = "</think>\n\n<response>answer</response>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
|
||||
self.assertEqual(result, ("thinking", "answer"))
|
||||
|
||||
def test_streaming_response_without_leading_newline(self):
|
||||
# 测试response内容没有前置换行
|
||||
prev_text = "thinking\n</think>\n\n"
|
||||
delta_text = "<response>answer\n</response>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
|
||||
self.assertEqual(result, ("thinking", "answer"))
|
||||
|
||||
def test_streaming_response_with_middle_newline(self):
|
||||
# 测试response内容中间的换行符
|
||||
prev_text = "thinking\n</think>\n\n<response>\n"
|
||||
delta_text = "line1\nline2\n</response>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
|
||||
self.assertEqual(result, ("thinking", "line1\nline2"))
|
||||
|
||||
def test_streaming_partial_response(self):
|
||||
# 测试不完整的response流式输出
|
||||
prev_text = "thinking\n</think>\n\n<response>\n"
|
||||
delta_text = "partial answer"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
|
||||
self.assertEqual(result, ("thinking", "partial answer"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@@ -48,6 +48,9 @@ class Qwen3ReasoningParser(ReasoningParser):
|
||||
if self.think_end_token_id is None:
|
||||
raise RuntimeError("Qwen3 reasoning parser could not locate think end " "tokens in the tokenizer!")
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.think_end_token_id in input_ids
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
@@ -66,7 +69,7 @@ class Qwen3ReasoningParser(ReasoningParser):
|
||||
- 'xyz' goes to content
|
||||
"""
|
||||
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]):
|
||||
return "", ""
|
||||
return None
|
||||
|
||||
# </think> in delta
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
@@ -76,28 +79,28 @@ class Qwen3ReasoningParser(ReasoningParser):
|
||||
end_index = delta_token_ids.find(self.think_end_token)
|
||||
reasoning_content = delta_text[start_index + len(self.think_start_token) : end_index]
|
||||
content = delta_text[end_index + len(self.think_end_token) :]
|
||||
return reasoning_content, content
|
||||
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||
# <think> in previous, </think> in delta,
|
||||
else:
|
||||
end_index = delta_text.find(self.think_end_token)
|
||||
reasoning_content = delta_text[:end_index]
|
||||
content = delta_text[end_index + len(self.think_end_token) :]
|
||||
content = content if content else None
|
||||
return reasoning_content, content
|
||||
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||
# </think> in previous reasoning content continues
|
||||
elif self.think_end_token_id in previous_token_ids:
|
||||
return "", delta_text
|
||||
return DeltaMessage(content=delta_text)
|
||||
# <think> in previous
|
||||
elif self.think_start_token_id in previous_token_ids:
|
||||
return delta_text, ""
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
# <think> in delta
|
||||
elif self.think_start_token_id in delta_token_ids:
|
||||
start_index = delta_text.find(self.think_start_token)
|
||||
reasoning_content = delta_text[start_index + len(self.think_start_token) :]
|
||||
content = ""
|
||||
return reasoning_content, content
|
||||
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||
else:
|
||||
return delta_text, ""
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
|
@@ -524,7 +524,8 @@ def test_chat_with_thinking(openai_client, capsys):
|
||||
stream=True,
|
||||
max_tokens=10,
|
||||
)
|
||||
completion_tokens = reasoning_tokens = 1
|
||||
completion_tokens = 1
|
||||
reasoning_tokens = 0
|
||||
total_tokens = 0
|
||||
for chunk_id, chunk in enumerate(response):
|
||||
if chunk_id == 0: # the first chunk is an extra chunk
|
||||
|
@@ -15,7 +15,8 @@ class TestErnieProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
self.processor.tokenizer = MagicMock()
|
||||
self.processor.tokenizer.eos_token_id = 1
|
||||
self.processor.decode_status = {}
|
||||
self.processor.tool_parsers = {}
|
||||
self.processor.reasoning_end_dict = {}
|
||||
self.processor.tool_parser_dict = {}
|
||||
|
||||
# 模拟 ids2tokens 方法
|
||||
def mock_ids2tokens(token_ids, task_id):
|
||||
@@ -31,7 +32,7 @@ class TestErnieProcessorProcessResponseDictStreaming(unittest.TestCase):
|
||||
|
||||
# 模拟工具解析器
|
||||
self.mock_tool_parser = MagicMock()
|
||||
self.mock_tool_parser.extract_tool_calls_streaming.return_value = "tool_call"
|
||||
self.mock_tool_parser.extract_tool_calls_streaming.return_value = None
|
||||
self.mock_tool_parser_obj = MagicMock()
|
||||
self.mock_tool_parser_obj.return_value = self.mock_tool_parser
|
||||
self.processor.tool_parser_obj = self.mock_tool_parser_obj
|
||||
|
@@ -170,6 +170,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
|
||||
ernie_processor.process_request_dict = mock_process_request
|
||||
ernie_processor.messages2ids = mock_messages2ids
|
||||
ernie_processor.eos_token_ids = [1]
|
||||
ernie_processor.reasoning_parser = MagicMock()
|
||||
result = ernie_processor.process_request(mock_request, chat_template="hello")
|
||||
self.assertEqual("hello", result.chat_template)
|
||||
|
||||
|
Reference in New Issue
Block a user