Feature/online/vs think 20250813 (#3440)

* add stream

* fix ernie_vl_reasoning_parsers

* fix bug
This commit is contained in:
luukunn
2025-08-15 18:33:58 +08:00
committed by GitHub
parent 33abfddd9b
commit edf1ca07af
8 changed files with 206 additions and 194 deletions

View File

@@ -252,22 +252,20 @@ class OpenAIServingChat:
logprobs_res = self._create_chat_logprobs(
output_top_logprobs, request.logprobs, request.top_logprobs
)
if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
tool_delta_message = output["tool_delta_message"]
if tool_delta_message is None:
continue
delta_message = tool_delta_message
delta_message.reasoning_content = output.get("reasoning_content")
if delta_message.tool_calls:
tool_called = True
if not res["finished"]:
if "reasoning_delta_message" in output:
delta_message = output["reasoning_delta_message"]
elif "tool_delta_message" in output:
delta_message = output["tool_delta_message"]
if delta_message is not None and delta_message.tool_calls:
tool_called = True
else:
delta_message = DeltaMessage(content=delta_text)
else:
delta_message = DeltaMessage(
content=delta_text,
reasoning_content=output.get("reasoning_content"),
prompt_token_ids=None,
completion_token_ids=None,
tool_calls=None,
)
delta_message = DeltaMessage(content=delta_text)
if delta_message is None:
continue
choice = ChatCompletionResponseStreamChoice(
index=0,

View File

@@ -344,33 +344,46 @@ class OpenAIServingCompletion:
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
output_tokens[idx] += 1
if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
tool_delta_message = output["tool_delta_message"]
if tool_delta_message is None:
continue
delta_message = CompletionResponseStreamChoice(
index=idx,
text=output["text"],
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
tool_calls=tool_delta_message.tool_calls,
reasoning_content=output.get("reasoning_content"),
arrival_time=arrival_time,
logprobs=logprobs_res,
)
if tool_delta_message.tool_calls:
tool_called = True
base_kwargs = {
"index": idx,
"completion_token_ids": output.get("token_ids") if request.return_token_ids else None,
"arrival_time": arrival_time,
"logprobs": logprobs_res,
}
delta_message_kwargs = None
if not res["finished"]:
if "reasoning_delta_message" in output:
reasoning_delta_message = output["reasoning_delta_message"]
if reasoning_delta_message is not None:
delta_message_kwargs = {
**base_kwargs,
"text": reasoning_delta_message.content or "",
"reasoning_content": reasoning_delta_message.reasoning_content,
}
elif "tool_delta_message" in output:
tool_delta_message = output["tool_delta_message"]
if tool_delta_message is not None:
delta_message_kwargs = {
**base_kwargs,
"text": tool_delta_message.content or "",
"tool_calls": tool_delta_message.tool_calls,
}
if tool_delta_message.tool_calls:
tool_called = True
else:
delta_message_kwargs = {
**base_kwargs,
"text": output["text"],
}
else:
delta_message = CompletionResponseStreamChoice(
index=idx,
text=output["text"],
prompt_token_ids=None,
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
tool_calls=None,
raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
reasoning_content=output.get("reasoning_content"),
arrival_time=arrival_time,
logprobs=logprobs_res,
)
delta_message_kwargs = {
**base_kwargs,
"text": output["text"],
}
if delta_message_kwargs is None:
continue
delta_message = CompletionResponseStreamChoice(**delta_message_kwargs)
choices.append(delta_message)
output_tokens[idx] += 1

View File

@@ -57,6 +57,16 @@ class ErnieX1ToolParser(ToolParser):
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
self.buffer: str = "" # buffer for accumulating unprocessed streaming content
self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
raise RuntimeError(
"Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!"
)
if not self.model_tokenizer:
raise ValueError(
@@ -224,6 +234,9 @@ class ErnieX1ToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: dict,
) -> Union[DeltaMessage, None]:
if self.tool_call_start_token_id not in current_token_ids:
return DeltaMessage(content=delta_text)
# 忽略空chunk
if len(delta_text.strip()) == 0:
return None
@@ -234,7 +247,7 @@ class ErnieX1ToolParser(ToolParser):
self.buffer += delta_text
# 处理增量中的新tool_call开始
if "<tool_call>" in delta_text and "<tool_call>" not in previous_text:
if "<tool_call>" in delta_text:
self.current_tool_id = (
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
)
@@ -243,8 +256,6 @@ class ErnieX1ToolParser(ToolParser):
self.streamed_args_for_tool.append("")
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
# 增量解析逻辑
# 1. 尝试解析name字段
if not self.current_tool_name_sent and '"name"' in self.buffer:
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
@@ -271,46 +282,71 @@ class ErnieX1ToolParser(ToolParser):
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
if args_match:
args_content = args_match.group(1)
# 处理多余的大括号
open_braces = args_content.count("{")
close_braces = args_content.count("}")
if close_braces > open_braces:
args_content = args_content[: args_content.rfind("}")]
print("args_content:", args_content)
try:
# 增量解析arguments
parsed_args = json.loads(args_content)
if isinstance(parsed_args, dict):
args_json = json.dumps(parsed_args, ensure_ascii=False)
if len(args_json) > len(self.streamed_args_for_tool[self.current_tool_id]):
argument_diff = args_json[len(self.streamed_args_for_tool[self.current_tool_id]) :]
# 检查是否到达arguments结尾(括号完全匹配)
if "}}" in args_content:
print("delta_text (partial):", delta_text)
# 逐个字符检查括号匹配状态
matched_pos = -1
for i, ch in enumerate(delta_text):
if ch == "{":
self.bracket_counts["total_l"] += 1
elif ch == "}":
self.bracket_counts["total_r"] += 1
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配
matched_pos = i
break
if matched_pos >= 0:
# 找到匹配点清理buffer并返回
truncate_text = delta_text[: matched_pos + 1]
print("truncate_text:", truncate_text)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=argument_diff).model_dump(
function=DeltaFunctionCall(arguments=truncate_text).model_dump(
exclude_none=True
),
)
]
)
print("delta argument:", delta)
# 删除已处理部分
processed_pos = args_match.start() + len('"arguments":')
self.buffer = (
self.buffer[:processed_pos] + self.buffer[processed_pos + len(args_json) :]
)
self.streamed_args_for_tool[self.current_tool_id] = args_json
self.buffer = self.buffer[args_match.end() :]
print(delta)
return delta
else:
# 没有完全匹配,继续累积
return None
else:
# 增量返回当前可解析的部分
for ch in delta_text:
if ch == "{":
self.bracket_counts["total_l"] += 1
elif ch == "}":
self.bracket_counts["total_r"] += 1
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
)
]
)
print("delta argument (partial):", delta)
print(
f"Current bracket counts - left: {self.bracket_counts['total_l']}, right: {self.bracket_counts['total_r']}"
)
return delta
except Exception as e:
data_processor_logger.debug(f"Partial arguments parsing: {str(e)}")
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
return None
if "</tool_call>" in self.buffer:
end_pos = self.buffer.find("</tool_call>")
self.buffer = self.buffer[end_pos + len("</tool_call>") :]
# 完成当前工具调用处理
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
return delta

View File

@@ -50,8 +50,9 @@ class ErnieProcessor(BaseDataProcessor):
self._init_config()
self.decode_status = dict()
self.tool_parsers = dict()
self.tool_parser_dict = dict()
self.thinking_parser_dict = dict()
self.reasoning_end_dict = dict()
self._load_tokenizer()
data_processor_logger.info(
f"tokenizer information: bos_token is {self.tokenizer.bos_token} \
@@ -291,10 +292,12 @@ class ErnieProcessor(BaseDataProcessor):
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
response_dict["outputs"]["raw_prediction"] = delta_text
if self.reasoning_parser and (
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
if (
self.reasoning_parser
and req_id not in self.reasoning_end_dict
and (enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser")
):
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
delta_text,
@@ -302,14 +305,13 @@ class ErnieProcessor(BaseDataProcessor):
previous_token_ids + token_ids,
token_ids,
)
response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = delta_text
if self.tool_parser_obj:
if req_id not in self.tool_parsers:
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parsers[req_id]
response_dict["outputs"]["reasoning_delta_message"] = reasoning_delta_message
if self.reasoning_parser.is_reasoning_end(previous_token_ids + token_ids):
self.reasoning_end_dict[req_id] = True
if self.tool_parser_obj and req_id in self.reasoning_end_dict:
if req_id not in self.tool_parser_dict:
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parser_dict[req_id]
tool_call = tool_parser.extract_tool_calls_streaming(
previous_texts,
previous_texts + delta_text,
@@ -320,11 +322,14 @@ class ErnieProcessor(BaseDataProcessor):
response_dict,
)
response_dict["outputs"]["tool_delta_message"] = tool_call
response_dict["outputs"]["text"] = delta_text
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.tool_parsers:
del self.tool_parsers[req_id]
if req_id in self.tool_parser_dict:
del self.tool_parser_dict[req_id]
if req_id in self.reasoning_end_dict:
del self.reasoning_end_dict[req_id]
return response_dict
def messages2ids(self, request_or_messages):

View File

@@ -54,8 +54,9 @@ class ErnieMoEVLProcessor(ErnieProcessor):
self.image_patch_id = self.ernie_processor.image_patch_id
self.spatial_conv_size = self.ernie_processor.spatial_conv_size
self.tool_parsers = dict()
self.tool_parser_dict = dict()
self.decode_status = dict()
self.reasoning_end_dict = dict()
self._load_tokenizer()
self.eos_token_ids = [self.tokenizer.eos_token_id]
self.eos_token_id_len = len(self.eos_token_ids)

View File

@@ -168,7 +168,8 @@ class DataProcessor(BaseDataProcessor):
self._init_config()
self.decode_status = dict()
self.tool_parsers = dict()
self.tool_parser_dict = dict()
self.reasoning_end_dict = dict()
self.tokenizer = self._load_tokenizer()
data_processor_logger.info(
f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
@@ -398,8 +399,12 @@ class DataProcessor(BaseDataProcessor):
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
response_dict["outputs"]["raw_prediction"] = delta_text
if enable_thinking and self.reasoning_parser:
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
if (
self.reasoning_parser
and req_id not in self.reasoning_end_dict
and (enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser")
):
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
previous_texts,
previous_texts + delta_text,
delta_text,
@@ -407,14 +412,13 @@ class DataProcessor(BaseDataProcessor):
previous_token_ids + token_ids,
token_ids,
)
response_dict["outputs"]["text"] = text
response_dict["outputs"]["reasoning_content"] = reasoning_content
else:
response_dict["outputs"]["text"] = delta_text
if self.tool_parser_obj and not is_end:
if req_id not in self.tool_parsers:
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parsers[req_id]
response_dict["outputs"]["reasoning_delta_message"] = reasoning_delta_message
if self.reasoning_parser.is_reasoning_end(previous_token_ids + token_ids):
self.reasoning_end_dict[req_id] = True
if self.tool_parser_obj and req_id in self.reasoning_end_dict:
if req_id not in self.tool_parser_dict:
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
tool_parser = self.tool_parser_dict[req_id]
tool_call = tool_parser.extract_tool_calls_streaming(
previous_texts,
previous_texts + delta_text,
@@ -425,11 +429,14 @@ class DataProcessor(BaseDataProcessor):
response_dict,
)
response_dict["outputs"]["tool_delta_message"] = tool_call
response_dict["outputs"]["text"] = delta_text
if is_end:
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
del self.decode_status[req_id]
if req_id in self.tool_parsers:
del self.tool_parsers[req_id]
if req_id in self.tool_parser_dict:
del self.tool_parser_dict[req_id]
if req_id in self.reasoning_end_dict:
del self.reasoning_end_dict[req_id]
return response_dict
def process_response_dict(self, response_dict, **kwargs):

View File

@@ -65,18 +65,16 @@ class ErnieVLReasoningParser(ReasoningParser):
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
return "", ""
return None
if self.think_end_token_id in delta_token_ids:
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.end_token) :]
content = delta_text[end_index + len(self.end_token)]
return DeltaMessage(reasoning_content=reasoning_content, content=content)
elif self.think_end_token_id in previous_token_ids:
reasoning_content = ""
content = delta_text
return DeltaMessage(content=delta_text)
else:
reasoning_content = delta_text
content = ""
return reasoning_content, content
return DeltaMessage(reasoning_content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest

View File

@@ -2,9 +2,9 @@
#
#
from collections.abc import Sequence
from typing import Tuple
from typing import Tuple, Union
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
#
@@ -47,6 +47,10 @@ class ErnieX1ReasoningParser(ReasoningParser):
self.think_end_token_id = self.vocab.get("</think>")
if self.think_end_token_id is None:
raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
self.tool_call_start_token_id = self.vocab.get("<tool_call>")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.tool_call_start_token_id in input_ids
def extract_reasoning_content_streaming(
self,
@@ -56,50 +60,63 @@ class ErnieX1ReasoningParser(ReasoningParser):
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> tuple[str, str]:
) -> Union[DeltaMessage, None]:
"""
根据用户需求实现的流式解析方法:
1. 初始内容都视为思考内容
1. 初始内容都视为思考内容返回delta_text,""
2. 当遇到\n时检查后续是否是</think>
3. 思考结束后检查是<response>还是<tool_call>
4. 对于<response>内容,处理换行和结束标记
3. 如果直接遇到</think>也结束思考
4. 思考结束后检查是<response>还是<tool_call>
5. 对于<response>内容,处理各种边界条件
"""
# 如果还在思考阶段
if not previous_text.endswith(self.think_end_token):
# 如果遇到\n后接</think>或直接遇到</think>,思考结束
if (previous_text.endswith("\n") and delta_text == self.think_end_token) or (
not previous_text.endswith("\n") and delta_text == self.think_end_token
):
return "", ""
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
return None
# 思考阶段处理
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
# 如果遇到\n暂时不返回等待下一个delta_text
if delta_text == "\n":
return None
# 如果前一个是\n且当前是</think>,结束思考
elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
return None
# 如果直接遇到</think>也结束思考
elif delta_text.startswith(self.think_end_token):
return None
# 否则继续返回思考内容
return delta_text, ""
return DeltaMessage(reasoning_content=delta_text)
# 思考结束后检查是tool_call还是response
remaining_text = previous_text + delta_text
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
# 跳过think后的换行
after_think = after_think.lstrip("\n")
after_think = after_think.lstrip("\n") # 跳过think后的换行
# 处理tool_call情况
if after_think.startswith(self.tool_call_start_token):
return "", ""
return None
# 处理response情况
if after_think.startswith(self.response_start_token):
response_content = after_think[len(self.response_start_token) :]
# 跳过response后的换行
response_content = response_content.lstrip("\n")
# 检查response是否结束
if response_content.endswith(self.response_end_token):
return "", ""
# 返回response内容(使用delta_text确保流式输出)
return "", delta_text
# 遇到<response>标签时不立即返回
if delta_text == self.response_start_token:
return None
# 遇到<response>后的换行符也不立即返回
elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
return None
# 处理回复内容中的换行符
if delta_text == "\n":
return None
# 如果前一个是\n且当前是</response>,结束回复
elif previous_text.endswith("\n") and delta_text == self.response_end_token:
return None
# 如果直接遇到</response>也结束回复
elif delta_text == self.response_end_token:
return None
# 其他情况返回实际内容
else:
return DeltaMessage(content=delta_text)
# 默认情况不返回内容
return "", ""
return None
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
"""
@@ -143,66 +160,3 @@ class ErnieX1ReasoningParser(ReasoningParser):
reasoning_content = model_output
response_content = ""
return reasoning_content, response_content
import unittest
from unittest.mock import MagicMock
class TestErnieX1ReasoningParser(unittest.TestCase):
def setUp(self):
self.tokenizer = MagicMock()
self.tokenizer.vocab = {
"\n</think>\n\n": 1001,
"<response>\n": 1002,
"\n</response>\n": 1003,
"<tool_call>\n": 1004,
"\n</tool_call>\n": 1005,
}
self.parser = ErnieX1ReasoningParser(self.tokenizer)
def test_streaming_with_think_and_response(self):
# 测试标准情况:\n</think>\n\n<response>\ncontent\n</response>\n
prev_text = "thinking"
delta_text = "\n</think>\n\n<response>\nanswer\n</response>\n"
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
self.assertEqual(result, ("thinking", "answer"))
def test_streaming_with_think_and_tool_call(self):
# 测试tool_call情况
prev_text = "thinking"
delta_text = "\n</think>\n\n<tool_call>\ndetails\n</tool_call>\n"
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
self.assertEqual(result, ("thinking", ""))
def test_streaming_with_think_no_newline(self):
# 测试没有前置换行的情况
prev_text = "thinking"
delta_text = "</think>\n\n<response>answer</response>\n"
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
self.assertEqual(result, ("thinking", "answer"))
def test_streaming_response_without_leading_newline(self):
# 测试response内容没有前置换行
prev_text = "thinking\n</think>\n\n"
delta_text = "<response>answer\n</response>\n"
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
self.assertEqual(result, ("thinking", "answer"))
def test_streaming_response_with_middle_newline(self):
# 测试response内容中间的换行符
prev_text = "thinking\n</think>\n\n<response>\n"
delta_text = "line1\nline2\n</response>\n"
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
self.assertEqual(result, ("thinking", "line1\nline2"))
def test_streaming_partial_response(self):
# 测试不完整的response流式输出
prev_text = "thinking\n</think>\n\n<response>\n"
delta_text = "partial answer"
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
self.assertEqual(result, ("thinking", "partial answer"))
if __name__ == "__main__":
unittest.main()