[Feature] add tool parser (#3483)

* add tool parser

* add x1 enable_thinking

* restart ci

* fix vl reasoning parser

* modify call style

* modify call style

* add offline enable_thinking

* fix completion

* fix

* fix unit test

* fix unit test

* fix unit test

* fix vl reasoning parser

* fix vl reasoning parser
This commit is contained in:
luukunn
2025-08-21 17:25:44 +08:00
committed by GitHub
parent 466cbb5a99
commit 371fb3f853
14 changed files with 197 additions and 222 deletions

View File

@@ -49,6 +49,8 @@ When using FastDeploy to deploy models (including offline inference and service
| ```served_model_name``` | `str` | The model name used in the API. If not specified, the model name will be the same as the --model argument |
| ```revision``` | `str` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, the default version will be used. |
| ```chat_template``` | `str` | Specify the template used for model concatenation. It supports both string input and file path input. The default value is None; if not specified, the model's default template will be used. |
+| ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output. |
+| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository. The code format within these parsers must adhere to the format used in the code repository. |

## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?
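The two new rows describe the parser plugin hook only at a high level. As a rough sketch of the shape such a plugin file is expected to have (a hedged reconstruction from the `ErnieX1ToolParser` changes later in this commit; the actual base class, registration hook, and return types live in the FastDeploy repository):

```python
# Hypothetical plugin skeleton; method names and signatures mirror the
# ErnieX1ToolParser diff below, but the real ToolParser base class and the
# registration mechanism come from the FastDeploy code base.
from typing import Optional, Sequence


class MyToolParser:  # in the repository this would subclass ToolParser
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def extract_tool_calls(self, model_output: str, request: dict):
        """Parse a complete (non-streaming) output into tool calls."""
        raise NotImplementedError

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: dict,
    ) -> Optional[object]:  # DeltaMessage or None in the repository
        """Return a DeltaMessage for this chunk, or None to keep buffering."""
        raise NotImplementedError
```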

View File

@@ -47,6 +47,8 @@
| ```served_model_name``` | `str` | The model name used in the API; if not specified, it will be the same as the --model argument |
| ```revision``` | `str` | The Git branch name or tag used to pin the model version when downloading automatically |
| ```chat_template``` | `str` | Specify the template used for model concatenation. It supports both string input and file path input. The default value is None; if not specified, the model's default template will be used |
+| ```tool_call_parser``` | `str` | Specify the function call parser to be used for extracting function call content from the model's output |
+| ```tool_parser_plugin``` | `str` | Specify the file path of the tool parser to be registered, so as to register parsers that are not in the code repository; the code format within these parsers must follow the format used in the code repository |

## 1. Relationship between KVCache allocation, ```num_gpu_blocks_override``` and ```block_size```?

View File

@@ -279,22 +279,20 @@ class OpenAIServingChat:
                        output_top_logprobs, request.logprobs, request.top_logprobs
                    )
-                if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
-                    tool_delta_message = output["tool_delta_message"]
-                    if tool_delta_message is None:
-                        continue
-                    delta_message = tool_delta_message
-                    delta_message.reasoning_content = output.get("reasoning_content")
-                    if delta_message.tool_calls:
-                        tool_called = True
-                else:
-                    delta_message = DeltaMessage(
-                        content=delta_text,
-                        reasoning_content=output.get("reasoning_content"),
-                        prompt_token_ids=None,
-                        completion_token_ids=None,
-                        tool_calls=None,
-                    )
+                delta_message = DeltaMessage(
+                    content=delta_text,
+                    reasoning_content="",
+                    prompt_token_ids=None,
+                    completion_token_ids=None,
+                    tool_calls=None,
+                )
+                if not res["finished"] and "delta_message" in output:
+                    delta_message_output = output["delta_message"]
+                    if delta_message_output is None:
+                        continue
+                    delta_message.content = delta_message_output.content or ""
+                    delta_message.reasoning_content = delta_message_output.reasoning_content or ""
+                    delta_message.tool_calls = delta_message_output.tool_calls

                choice = ChatCompletionResponseStreamChoice(
                    index=0,
@@ -475,7 +473,7 @@ class OpenAIServingChat:
                    max_tokens = request.max_completion_tokens or request.max_tokens
                    if has_no_token_limit or previous_num_tokens != max_tokens:
                        choice.finish_reason = "stop"
-                        if self.engine_client.reasoning_parser == "ernie_x1" and output.get("finish_reason", "") == "tool_calls":
+                        if output.get("tool_call"):
                            choice.finish_reason = "tool_calls"
                    else:
                        choice.finish_reason = "length"
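The rewritten branch folds the old tool-parser special case into a single optional `delta_message` published by the data processor. A minimal stand-alone sketch of that merge rule, with a simplified stand-in for the repo's `DeltaMessage` (field names follow the diff):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Delta:  # simplified stand-in for DeltaMessage
    content: str = ""
    reasoning_content: str = ""
    tool_calls: Optional[list] = None


def merge_delta(output: dict, delta_text: str, finished: bool) -> Optional[Delta]:
    # Default delta built from the decoded text.
    delta = Delta(content=delta_text)
    if not finished and "delta_message" in output:
        parsed = output["delta_message"]
        if parsed is None:
            return None  # parser is still buffering: emit no choice this chunk
        delta.content = parsed.content or ""
        delta.reasoning_content = parsed.reasoning_content or ""
        delta.tool_calls = parsed.tool_calls
    return delta


# A buffered chunk produces no streamed choice at all:
assert merge_delta({"delta_message": None}, "{", finished=False) is None
```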

View File

@@ -312,7 +312,7 @@ class OpenAIServingCompletion:
        output_tokens = [0] * num_choices
        inference_start_time = [0] * num_choices
        first_iteration = [True] * num_choices
-        tool_called = False
+        tool_called = [False] * num_choices
        max_streaming_response_tokens = (
            request.max_streaming_response_tokens
            if request.max_streaming_response_tokens is not None
@@ -386,22 +386,6 @@ class OpenAIServingCompletion:
                            logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
                        output_tokens[idx] += 1
-                        if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
-                            tool_delta_message = output["tool_delta_message"]
-                            if tool_delta_message is None:
-                                continue
-                            delta_message = CompletionResponseStreamChoice(
-                                index=idx,
-                                text=output["text"],
-                                completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
-                                tool_calls=tool_delta_message.tool_calls,
-                                reasoning_content=output.get("reasoning_content"),
-                                arrival_time=arrival_time,
-                                logprobs=logprobs_res,
-                            )
-                            if tool_delta_message.tool_calls:
-                                tool_called = True
-                        else:
                        delta_message = CompletionResponseStreamChoice(
                            index=idx,
                            text=output["text"],
@@ -410,17 +394,24 @@ class OpenAIServingCompletion:
                            tool_calls=None,
                            raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
                            completion_tokens=output.get("raw_prediction") if request.return_token_ids else None,
-                            reasoning_content=output.get("reasoning_content"),
+                            reasoning_content="",
                            arrival_time=arrival_time,
                            logprobs=logprobs_res,
                        )
+                        if not res["finished"] and "delta_message" in output:
+                            delta_message_output = output["delta_message"]
+                            if delta_message_output is None:
+                                continue
+                            delta_message.text = delta_message_output.content or ""
+                            delta_message.reasoning_content = delta_message_output.reasoning_content or ""
+                            delta_message.tool_calls = delta_message_output.tool_calls
                        choices.append(delta_message)
                        output_tokens[idx] += 1
                        if res["finished"]:
                            choices[-1].finish_reason = self.calc_finish_reason(
-                                request.max_tokens, output_tokens[idx], output, tool_called
+                                request.max_tokens, output_tokens[idx], output, tool_called[idx]
                            )
                        send_idx = output.get("send_idx")
                        # Log only when send_idx is explicitly 0
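`tool_called` turning into a list reflects that a completion request can carry several parallel choices: one choice may finish with a tool call while another ends with plain text. A small illustration (simplified stand-in for `calc_finish_reason`):

```python
# With n > 1 choices, a single boolean would mislabel finish_reason for
# choices that never emitted a tool call; per-index flags keep them apart.
num_choices = 2
tool_called = [False] * num_choices


def finish_reason(idx: int) -> str:
    return "tool_calls" if tool_called[idx] else "stop"


tool_called[0] = True
assert [finish_reason(i) for i in range(num_choices)] == ["tool_calls", "stop"]
```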

View File

@@ -14,7 +14,6 @@
import json
import re
-import traceback
import uuid
from collections.abc import Sequence
from typing import Union
@@ -58,6 +57,16 @@ class ErnieX1ToolParser(ToolParser):
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: list[str] = []  # map what has been streamed for each tool so far to a list
        self.buffer: str = ""  # buffer for accumulating unprocessed streaming content
+        self.bracket_counts: dict = {"total_l": 0, "total_r": 0}  # track bracket counts in streamed deltas
+        self.tool_call_start_token: str = "<tool_call>"
+        self.tool_call_end_token: str = "</tool_call>"
+        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
+        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+        if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
+            raise RuntimeError(
+                "Ernie X1 tool parser could not locate tool call start/end tokens in the tokenizer!"
+            )

        if not self.model_tokenizer:
            raise ValueError(
@@ -163,12 +172,10 @@ class ErnieX1ToolParser(ToolParser):
                                }
                            )
                        except Exception as e:
-                            data_processor_logger.error(
-                                f"Failed to parse tool call: {str(e)}, {str(traceback.format_exc())}"
-                            )
+                            data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
                            continue
                except Exception as e:
-                    data_processor_logger.error(f"Failed to parse tool call: {str(e)}, {str(traceback.format_exc())}")
+                    data_processor_logger.debug(f"Failed to parse tool call: {str(e)}")
                    continue

            if not function_call_arr:
@@ -214,9 +221,7 @@ class ErnieX1ToolParser(ToolParser):
            )
        except Exception as e:
-            data_processor_logger.error(
-                f"Error in extracting tool call from response: {str(e)}, {str(traceback.format_exc())}"
-            )
+            data_processor_logger.error(f"Error in extracting tool call from response: {str(e)}")
            return ExtractedToolCallInformation(tools_called=False, tool_calls=None, content=model_output)

    def extract_tool_calls_streaming(
@@ -229,6 +234,9 @@ class ErnieX1ToolParser(ToolParser):
        delta_token_ids: Sequence[int],
        request: dict,
    ) -> Union[DeltaMessage, None]:
+        if self.tool_call_start_token_id not in current_token_ids:
+            return DeltaMessage(content=delta_text)
        # Ignore empty chunks
        if len(delta_text.strip()) == 0:
            return None
@@ -239,7 +247,7 @@ class ErnieX1ToolParser(ToolParser):
            self.buffer += delta_text
            # Handle a new tool_call starting within this delta
-            if "<tool_call>" in delta_text and "<tool_call>" not in previous_text:
+            if "<tool_call>" in delta_text:
                self.current_tool_id = (
                    max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
                )
@@ -248,8 +256,6 @@
                self.streamed_args_for_tool.append("")
                data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
-            # Incremental parsing logic
            # 1. Try to parse the name field
            if not self.current_tool_name_sent and '"name"' in self.buffer:
                name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
@@ -266,7 +272,6 @@ class ErnieX1ToolParser(ToolParser):
                        )
                    ]
                )
-                print("delta name:", delta)
                # Remove the processed name part
                self.buffer = self.buffer[name_match.end() :]
                self.current_tool_name_sent = True
@@ -276,54 +281,67 @@ class ErnieX1ToolParser(ToolParser):
            args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
            if args_match:
                args_content = args_match.group(1)
-                # Handle extra closing braces
-                open_braces = args_content.count("{")
-                close_braces = args_content.count("}")
-                if close_braces > open_braces:
-                    args_content = args_content[: args_content.rfind("}")]
                try:
-                    # Incrementally parse the arguments
-                    parsed_args = json.loads(args_content)
-                    if isinstance(parsed_args, dict):
-                        args_json = json.dumps(parsed_args, ensure_ascii=False)
-                        if len(args_json) > len(self.streamed_args_for_tool[self.current_tool_id]):
-                            argument_diff = args_json[len(self.streamed_args_for_tool[self.current_tool_id]) :]
-                            delta = DeltaMessage(
-                                tool_calls=[
-                                    DeltaToolCall(
-                                        index=self.current_tool_id,
-                                        function=DeltaFunctionCall(arguments=argument_diff).model_dump(
-                                            exclude_none=True
-                                        ),
-                                    )
-                                ]
-                            )
-                            print("delta argument:", delta)
-                            # Remove the processed part
-                            processed_pos = args_match.start() + len('"arguments":')
-                            self.buffer = (
-                                self.buffer[:processed_pos] + self.buffer[processed_pos + len(args_json) :]
-                            )
-                            self.streamed_args_for_tool[self.current_tool_id] = args_json
-                            return delta
-                except Exception as e:
-                    data_processor_logger.error(
-                        f"Partial arguments parsing: {str(e)}, {str(traceback.format_exc())}"
-                    )
+                    # Check whether the end of arguments has been reached (brackets fully matched)
+                    if "}}" in args_content:
+                        # Check bracket matching character by character
+                        matched_pos = -1
+                        for i, ch in enumerate(delta_text):
+                            if ch == "{":
+                                self.bracket_counts["total_l"] += 1
+                            elif ch == "}":
+                                self.bracket_counts["total_r"] += 1
+                            if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]:  # brackets fully matched
+                                matched_pos = i
+                                break
+                        if matched_pos >= 0:
+                            # Match found: clean up the buffer and return
+                            truncate_text = delta_text[: matched_pos + 1]
+                            delta = DeltaMessage(
+                                tool_calls=[
+                                    DeltaToolCall(
+                                        index=self.current_tool_id,
+                                        function=DeltaFunctionCall(arguments=truncate_text).model_dump(
+                                            exclude_none=True
+                                        ),
+                                    )
+                                ]
+                            )
+                            self.buffer = self.buffer[args_match.end() :]
+                            return delta
+                        else:
+                            # Not fully matched yet; keep accumulating
+                            return None
+                    else:
+                        # Incrementally return the currently parsable part
+                        for ch in delta_text:
+                            if ch == "{":
+                                self.bracket_counts["total_l"] += 1
+                            elif ch == "}":
+                                self.bracket_counts["total_r"] += 1
+                        delta = DeltaMessage(
+                            tool_calls=[
+                                DeltaToolCall(
+                                    index=self.current_tool_id,
+                                    function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
+                                )
+                            ]
+                        )
+                        return delta
+                except Exception as e:
+                    data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
+                    return None
            if "</tool_call>" in self.buffer:
                end_pos = self.buffer.find("</tool_call>")
                self.buffer = self.buffer[end_pos + len("</tool_call>") :]
                # Finish handling the current tool call
+                self.current_tool_id += 1
+                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                return delta
        except Exception as e:
-            data_processor_logger.error(
-                f"Error in streaming tool call extraction: {str(e)}, {str(traceback.format_exc())}"
-            )
+            data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
            return None
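The new streaming path replaces repeated `json.loads` attempts on the growing buffer with running brace counts. A self-contained sketch of the same bracket-balance idea (the helper name is illustrative):

```python
# Track "{" and "}" counts across streamed deltas; the position where the
# counts first balance marks the end of the JSON "arguments" object, with
# no need to re-parse the buffer on every chunk.
def find_arguments_close(deltas):
    left = right = 0
    for chunk_no, delta in enumerate(deltas):
        for i, ch in enumerate(delta):
            if ch == "{":
                left += 1
            elif ch == "}":
                right += 1
            if left and left == right:  # brackets fully matched
                return chunk_no, i
    return None  # still open: keep accumulating


# '{"a": {"b": 1}}' split across three deltas closes in chunk 2 at index 1:
assert find_arguments_close(['{"a": ', '{"b": 1', "}}"]) == (2, 1)
```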

View File

@@ -58,7 +58,7 @@ class ErnieProcessor(BaseDataProcessor):
        self.generation_config = None
        self.decode_status = dict()
-        self.tool_parsers = dict()
+        self.tool_parser_dict = dict()
        self.thinking_parser_dict = dict()
        self._load_tokenizer()
        data_processor_logger.info(
@@ -133,6 +133,8 @@ class ErnieProcessor(BaseDataProcessor):
            request.set("temperature", 1)
        if request.get("top_p") < _SAMPLING_EPS:
            request.set("top_p", _SAMPLING_EPS)
+        if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
+            request.enable_thinking = True
        data_processor_logger.info(f"Processed request {request}")
        return request
@@ -194,6 +196,8 @@ class ErnieProcessor(BaseDataProcessor):
            request["temperature"] = 1
        if request.get("top_p") < _SAMPLING_EPS:
            request["top_p"] = _SAMPLING_EPS
+        if self.reasoning_parser and self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser":
+            request["enable_thinking"] = True
        data_processor_logger.info(f"Processed request {request}")
        return request
@@ -309,7 +313,7 @@ class ErnieProcessor(BaseDataProcessor):
        if self.reasoning_parser and (
            enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
        ):
-            reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
+            reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
                previous_texts,
                previous_texts + delta_text,
                delta_text,
@@ -317,15 +321,12 @@ class ErnieProcessor(BaseDataProcessor):
                previous_token_ids + token_ids,
                token_ids,
            )
-            response_dict["outputs"]["text"] = text
-            response_dict["outputs"]["reasoning_content"] = reasoning_content
-        else:
-            response_dict["outputs"]["text"] = delta_text
+            response_dict["outputs"]["delta_message"] = reasoning_delta_message
        if self.tool_parser_obj:
-            if req_id not in self.tool_parsers:
-                self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
-            tool_parser = self.tool_parsers[req_id]
-            tool_call = tool_parser.extract_tool_calls_streaming(
+            if req_id not in self.tool_parser_dict:
+                self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
+            tool_parser = self.tool_parser_dict[req_id]
+            tool_call_delta_message = tool_parser.extract_tool_calls_streaming(
                previous_texts,
                previous_texts + delta_text,
                delta_text,
@@ -334,12 +335,14 @@ class ErnieProcessor(BaseDataProcessor):
                token_ids,
                response_dict,
            )
-            response_dict["outputs"]["tool_delta_message"] = tool_call
+            if tool_call_delta_message is None or tool_call_delta_message.tool_calls:
+                response_dict["outputs"]["delta_message"] = tool_call_delta_message
+        response_dict["outputs"]["text"] = delta_text
        if is_end:
            data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
            del self.decode_status[req_id]
-            if req_id in self.tool_parsers:
-                del self.tool_parsers[req_id]
+            if req_id in self.tool_parser_dict:
+                del self.tool_parser_dict[req_id]
        return response_dict

    def messages2ids(self, request_or_messages):
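The `tool_parsers` to `tool_parser_dict` rename goes together with a per-request lifecycle: one stateful parser per `req_id`, created lazily on the first streamed chunk and dropped when the request ends, so buffered state never leaks across requests. A minimal sketch of the pattern (class and method names are illustrative):

```python
class ParserPool:
    """Lazily creates one stateful parser per request id."""

    def __init__(self, parser_factory):
        self.parser_factory = parser_factory
        self.tool_parser_dict = {}

    def get(self, req_id):
        if req_id not in self.tool_parser_dict:
            self.tool_parser_dict[req_id] = self.parser_factory()
        return self.tool_parser_dict[req_id]

    def finish(self, req_id):
        # mirrors the cleanup in the is_end branch above
        self.tool_parser_dict.pop(req_id, None)


pool = ParserPool(parser_factory=dict)  # any stateful object works here
assert pool.get("req-1") is pool.get("req-1")  # reused across chunks
pool.finish("req-1")
assert "req-1" not in pool.tool_parser_dict
```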

View File

@@ -50,7 +50,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
        self.image_patch_id = self.ernie_processor.image_patch_id
        self.spatial_conv_size = self.ernie_processor.spatial_conv_size
-        self.tool_parsers = dict()
+        self.tool_parser_dict = dict()
        self.decode_status = dict()
        self._load_tokenizer()

View File

@@ -175,7 +175,7 @@ class DataProcessor(BaseDataProcessor):
        self.generation_config = None
        self.decode_status = dict()
-        self.tool_parsers = dict()
+        self.tool_parser_dict = dict()
        self.tokenizer = self._load_tokenizer()
        data_processor_logger.info(
            f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
@@ -398,8 +398,10 @@ class DataProcessor(BaseDataProcessor):
            token_ids = token_ids[:-1]
        delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
        response_dict["outputs"]["raw_prediction"] = delta_text
-        if enable_thinking and self.reasoning_parser:
-            reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
+        if self.reasoning_parser and (
+            enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
+        ):
+            reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
                previous_texts,
                previous_texts + delta_text,
                delta_text,
@@ -407,14 +409,11 @@ class DataProcessor(BaseDataProcessor):
                previous_token_ids + token_ids,
                token_ids,
            )
-            response_dict["outputs"]["text"] = text
-            response_dict["outputs"]["reasoning_content"] = reasoning_content
-        else:
-            response_dict["outputs"]["text"] = delta_text
-        if self.tool_parser_obj and not is_end:
-            if req_id not in self.tool_parsers:
-                self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
-            tool_parser = self.tool_parsers[req_id]
+            response_dict["outputs"]["delta_message"] = reasoning_delta_message
+        if self.tool_parser_obj:
+            if req_id not in self.tool_parser_dict:
+                self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
+            tool_parser = self.tool_parser_dict[req_id]
            tool_call = tool_parser.extract_tool_calls_streaming(
                previous_texts,
                previous_texts + delta_text,
@@ -424,12 +423,14 @@ class DataProcessor(BaseDataProcessor):
                token_ids,
                response_dict,
            )
-            response_dict["outputs"]["tool_delta_message"] = tool_call
+            if tool_call is None or tool_call.tool_calls:
+                response_dict["outputs"]["delta_message"] = tool_call
+        response_dict["outputs"]["text"] = delta_text
        if is_end:
            data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
            del self.decode_status[req_id]
-            if req_id in self.tool_parsers:
-                del self.tool_parsers[req_id]
+            if req_id in self.tool_parser_dict:
+                del self.tool_parser_dict[req_id]
        return response_dict

    def process_response_dict(self, response_dict, **kwargs):
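The gating added here means the tool parser's streaming result only replaces the reasoning parser's delta when it is `None` (still buffering, so the chunk is suppressed upstream) or when it actually carries `tool_calls`; a plain-content result leaves the reasoning delta in place. A one-function sketch of that override rule (names are illustrative):

```python
def pick_delta(reasoning_delta, tool_delta):
    # None -> suppress the chunk; tool_calls present -> tool delta wins;
    # otherwise keep the reasoning parser's delta.
    if tool_delta is None or getattr(tool_delta, "tool_calls", None):
        return tool_delta
    return reasoning_delta
```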

View File

@@ -46,6 +46,9 @@ class ErnieVLReasoningParser(ReasoningParser):
        if self.think_end_token_id is None:
            raise RuntimeError("Ernie VL reasoning parser could not locate think end " "tokens in the tokenizer!")

+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.think_end_token_id in input_ids

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
@@ -65,18 +68,16 @@
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
-            return "", ""
+            return None
        if self.think_end_token_id in delta_token_ids:
            end_index = delta_text.find(self.end_token)
            reasoning_content = delta_text[:end_index]
            content = delta_text[end_index + len(self.end_token) :]
+            return DeltaMessage(reasoning_content=reasoning_content, content=content)
        elif self.think_end_token_id in previous_token_ids:
-            reasoning_content = ""
-            content = delta_text
+            return DeltaMessage(content=delta_text)
        else:
-            reasoning_content = delta_text
-            content = ""
-        return reasoning_content, content
+            return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
@@ -95,7 +96,6 @@ class ErnieVLReasoningParser(ReasoningParser):
        # Check if the model output contains the </think> tokens.
        if self.think_end_token not in model_output:
            return "", model_output
-        # Extract reasoning content from the model output.
        reasoning_content, _, content = model_output.partition(self.think_end_token)
        final_content = content or ""
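The three branches now return a `DeltaMessage` (or suppress the chunk) instead of a `(reasoning, content)` tuple. A simplified text-based model of the same split; note the real parser keys off `think_end_token_id` in the token-id sequences rather than substring search:

```python
from typing import Tuple

END = "</think>"


def split_delta(previous_text: str, delta_text: str) -> Tuple[str, str]:
    # Returns (reasoning_delta, content_delta) for one streamed chunk.
    if END in delta_text:
        i = delta_text.find(END)
        return delta_text[:i], delta_text[i + len(END):]  # split on the marker
    if END in previous_text:
        return "", delta_text  # reasoning already closed earlier
    return delta_text, ""      # still thinking


assert split_delta("abc", "def") == ("def", "")
assert split_delta("abc", f"d{END}e") == ("d", "e")
assert split_delta(f"x{END}", "rest") == ("", "rest")
```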

View File

@@ -2,9 +2,9 @@
#
#
from collections.abc import Sequence
-from typing import Tuple
+from typing import Tuple, Union

-from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
+from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager

#
@@ -47,6 +47,10 @@ class ErnieX1ReasoningParser(ReasoningParser):
        self.think_end_token_id = self.vocab.get("</think>")
        if self.think_end_token_id is None:
            raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
+        self.tool_call_start_token_id = self.vocab.get("<tool_call>")

+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.tool_call_start_token_id in input_ids

    def extract_reasoning_content_streaming(
        self,
@@ -56,50 +60,63 @@ class ErnieX1ReasoningParser(ReasoningParser):
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
-    ) -> tuple[str, str]:
+    ) -> Union[DeltaMessage, None]:
        """
        Streaming parsing method implemented according to user requirements:
-        1. All initial content is treated as reasoning content
+        1. All initial content is treated as reasoning content, returned as delta_text, ""
        2. When \n is encountered, check whether </think> follows
-        3. After thinking ends, check for <response> or <tool_call>
-        4. For <response> content, handle newlines and the end marker
+        3. Encountering </think> directly also ends thinking
+        4. After thinking ends, check for <response> or <tool_call>
+        5. For <response> content, handle various boundary conditions
        """
+        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
+            return None
-        # Still in the thinking phase
-        if not previous_text.endswith(self.think_end_token):
-            # Thinking ends when \n is followed by </think>, or </think> is encountered directly
-            if (previous_text.endswith("\n") and delta_text == self.think_end_token) or (
-                not previous_text.endswith("\n") and delta_text == self.think_end_token
-            ):
-                return "", ""
+        # Handle the thinking phase
+        if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
+            # On \n, do not return yet; wait for the next delta_text
+            if delta_text == "\n":
+                return None
+            # If the previous text ends with \n and the current delta starts with </think>, thinking ends
+            elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
+                return None
+            # Encountering </think> directly also ends thinking
+            elif delta_text.startswith(self.think_end_token):
+                return None
            # Otherwise keep returning reasoning content
-            return delta_text, ""
+            return DeltaMessage(reasoning_content=delta_text)

        # After thinking ends, check for tool_call or response
        remaining_text = previous_text + delta_text
        after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
-        # Skip the newline after think
-        after_think = after_think.lstrip("\n")
+        after_think = after_think.lstrip("\n")  # skip the newline after think

        # Handle the tool_call case
        if after_think.startswith(self.tool_call_start_token):
-            return "", ""
+            return None

        # Handle the response case
        if after_think.startswith(self.response_start_token):
-            response_content = after_think[len(self.response_start_token) :]
-            # Skip the newline after response
-            response_content = response_content.lstrip("\n")
-            # Check whether the response has ended
-            if response_content.endswith(self.response_end_token):
-                return "", ""
-            # Return the response content (delta_text keeps the output streaming)
-            return "", delta_text
+            # Do not return immediately on the <response> tag itself
+            if delta_text == self.response_start_token:
+                return None
+            # Do not return immediately on the newline right after <response> either
+            elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
+                return None
+            # Handle newlines inside the response content
+            if delta_text == "\n":
+                return None
+            # If the previous text ends with \n and the current delta is </response>, the response ends
+            elif previous_text.endswith("\n") and delta_text == self.response_end_token:
+                return None
+            # Encountering </response> directly also ends the response
+            elif delta_text == self.response_end_token:
+                return None
+            # Otherwise return the actual content
+            else:
+                return DeltaMessage(content=delta_text)

        # Default: return nothing
-        return "", ""
+        return None

    def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
        """
@@ -143,66 +160,3 @@ class ErnieX1ReasoningParser(ReasoningParser):
            reasoning_content = model_output
            response_content = ""
        return reasoning_content, response_content
-import unittest
-from unittest.mock import MagicMock
-
-class TestErnieX1ReasoningParser(unittest.TestCase):
-    def setUp(self):
-        self.tokenizer = MagicMock()
-        self.tokenizer.vocab = {
-            "\n</think>\n\n": 1001,
-            "<response>\n": 1002,
-            "\n</response>\n": 1003,
-            "<tool_call>\n": 1004,
-            "\n</tool_call>\n": 1005,
-        }
-        self.parser = ErnieX1ReasoningParser(self.tokenizer)
-
-    def test_streaming_with_think_and_response(self):
-        # Standard case: \n</think>\n\n<response>\ncontent\n</response>\n
-        prev_text = "thinking"
-        delta_text = "\n</think>\n\n<response>\nanswer\n</response>\n"
-        result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
-        self.assertEqual(result, ("thinking", "answer"))
-
-    def test_streaming_with_think_and_tool_call(self):
-        # The tool_call case
-        prev_text = "thinking"
-        delta_text = "\n</think>\n\n<tool_call>\ndetails\n</tool_call>\n"
-        result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
-        self.assertEqual(result, ("thinking", ""))
-
-    def test_streaming_with_think_no_newline(self):
-        # No leading newline before </think>
-        prev_text = "thinking"
-        delta_text = "</think>\n\n<response>answer</response>\n"
-        result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
-        self.assertEqual(result, ("thinking", "answer"))
-
-    def test_streaming_response_without_leading_newline(self):
-        # Response content without a leading newline
-        prev_text = "thinking\n</think>\n\n"
-        delta_text = "<response>answer\n</response>\n"
-        result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
-        self.assertEqual(result, ("thinking", "answer"))
-
-    def test_streaming_response_with_middle_newline(self):
-        # Newlines in the middle of response content
-        prev_text = "thinking\n</think>\n\n<response>\n"
-        delta_text = "line1\nline2\n</response>\n"
-        result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
-        self.assertEqual(result, ("thinking", "line1\nline2"))
-
-    def test_streaming_partial_response(self):
-        # Incomplete streaming response output
-        prev_text = "thinking\n</think>\n\n<response>\n"
-        delta_text = "partial answer"
-        result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
-        self.assertEqual(result, ("thinking", "partial answer"))
-
-if __name__ == "__main__":
-    unittest.main()
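The deleted tests asserted the old tuple contract, so they had to go with it. Equivalent coverage against the new optional-`DeltaMessage` contract might look like the following self-contained sketch (a condensed model of the branch logic above, not the repository's actual class):

```python
import unittest
from dataclasses import dataclass
from typing import Optional

THINK_END = "</think>"


@dataclass
class Delta:  # stand-in for the repo's DeltaMessage
    reasoning_content: Optional[str] = None
    content: Optional[str] = None


def stream_step(previous_text: str, delta_text: str) -> Optional[Delta]:
    # Condensed model: suppress marker chunks, route text before </think>
    # to reasoning_content and text after it to content.
    if THINK_END not in previous_text:
        if delta_text == "\n" or delta_text.startswith(THINK_END):
            return None  # suppressed chunk
        return Delta(reasoning_content=delta_text)
    return Delta(content=delta_text)


class TestStreamingDeltaContract(unittest.TestCase):
    def test_thinking_chunk(self):
        self.assertEqual(stream_step("thinking", "more").reasoning_content, "more")

    def test_think_end_is_suppressed(self):
        self.assertIsNone(stream_step("thinking\n", "</think>"))

    def test_content_after_think(self):
        self.assertEqual(stream_step("t\n</think>\n\n<response>\n", "answer").content, "answer")


if __name__ == "__main__":
    unittest.main()
```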

View File

@@ -48,6 +48,9 @@ class Qwen3ReasoningParser(ReasoningParser):
        if self.think_end_token_id is None:
            raise RuntimeError("Qwen3 reasoning parser could not locate think end " "tokens in the tokenizer!")

+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return self.think_end_token_id in input_ids

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
@@ -66,7 +69,7 @@ class Qwen3ReasoningParser(ReasoningParser):
        - 'xyz' goes to content
        """
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [self.think_start_token_id, self.think_end_token_id]):
-            return "", ""
+            return None

        # </think> in delta
        if self.think_end_token_id in delta_token_ids:
@@ -76,28 +79,28 @@ class Qwen3ReasoningParser(ReasoningParser):
                end_index = delta_token_ids.find(self.think_end_token)
                reasoning_content = delta_text[start_index + len(self.think_start_token) : end_index]
                content = delta_text[end_index + len(self.think_end_token) :]
-                return reasoning_content, content
+                return DeltaMessage(reasoning_content=reasoning_content, content=content)
            # <think> in previous, </think> in delta,
            else:
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token) :]
                content = content if content else None
-                return reasoning_content, content
+                return DeltaMessage(reasoning_content=reasoning_content, content=content)
        # </think> in previous, reasoning content continues
        elif self.think_end_token_id in previous_token_ids:
-            return "", delta_text
+            return DeltaMessage(content=delta_text)
        # <think> in previous
        elif self.think_start_token_id in previous_token_ids:
-            return delta_text, ""
+            return DeltaMessage(reasoning_content=delta_text)
        # <think> in delta
        elif self.think_start_token_id in delta_token_ids:
            start_index = delta_text.find(self.think_start_token)
            reasoning_content = delta_text[start_index + len(self.think_start_token) :]
            content = ""
-            return reasoning_content, content
+            return DeltaMessage(reasoning_content=reasoning_content, content=content)
        else:
-            return delta_text, ""
+            return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
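Every caller of these parsers adapts from tuple unpacking to handling one optional `DeltaMessage`. A small sketch of the consuming side (function name illustrative; `None` means "emit nothing for this chunk"):

```python
def consume(parser_result) -> dict:
    # Map an optional DeltaMessage-like object onto response_dict fields.
    if parser_result is None:
        return {}
    out = {}
    if getattr(parser_result, "reasoning_content", None):
        out["reasoning_content"] = parser_result.reasoning_content
    if getattr(parser_result, "content", None):
        out["content"] = parser_result.content
    return out


class _Msg:  # hypothetical parser output
    reasoning_content = "because..."
    content = None


assert consume(None) == {}
assert consume(_Msg()) == {"reasoning_content": "because..."}
```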

View File

@@ -524,7 +524,8 @@ def test_chat_with_thinking(openai_client, capsys):
        stream=True,
        max_tokens=10,
    )
-    completion_tokens = reasoning_tokens = 1
+    completion_tokens = 1
+    reasoning_tokens = 0
    total_tokens = 0
    for chunk_id, chunk in enumerate(response):
        if chunk_id == 0:  # the first chunk is an extra chunk

View File

@@ -15,7 +15,8 @@ class TestErnieProcessorProcessResponseDictStreaming(unittest.TestCase):
        self.processor.tokenizer = MagicMock()
        self.processor.tokenizer.eos_token_id = 1
        self.processor.decode_status = {}
-        self.processor.tool_parsers = {}
+        self.processor.reasoning_end_dict = {}
+        self.processor.tool_parser_dict = {}

        # Mock the ids2tokens method
        def mock_ids2tokens(token_ids, task_id):
@@ -31,7 +32,7 @@ class TestErnieProcessorProcessResponseDictStreaming(unittest.TestCase):
        # Mock the tool parser
        self.mock_tool_parser = MagicMock()
-        self.mock_tool_parser.extract_tool_calls_streaming.return_value = "tool_call"
+        self.mock_tool_parser.extract_tool_calls_streaming.return_value = None
        self.mock_tool_parser_obj = MagicMock()
        self.mock_tool_parser_obj.return_value = self.mock_tool_parser
        self.processor.tool_parser_obj = self.mock_tool_parser_obj

View File

@@ -170,6 +170,7 @@ class TestLodChatTemplate(unittest.IsolatedAsyncioTestCase):
        ernie_processor.process_request_dict = mock_process_request
        ernie_processor.messages2ids = mock_messages2ids
        ernie_processor.eos_token_ids = [1]
+        ernie_processor.reasoning_parser = MagicMock()

        result = ernie_processor.process_request(mock_request, chat_template="hello")
        self.assertEqual("hello", result.chat_template)