mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-10 19:10:20 +08:00
Feature/online/vs think 20250813 (#3440)
* add stream * fix ernie_vl_reasoning_parsers * fix bug
This commit is contained in:
@@ -252,22 +252,20 @@ class OpenAIServingChat:
|
||||
logprobs_res = self._create_chat_logprobs(
|
||||
output_top_logprobs, request.logprobs, request.top_logprobs
|
||||
)
|
||||
if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
|
||||
tool_delta_message = output["tool_delta_message"]
|
||||
if tool_delta_message is None:
|
||||
continue
|
||||
delta_message = tool_delta_message
|
||||
delta_message.reasoning_content = output.get("reasoning_content")
|
||||
if delta_message.tool_calls:
|
||||
tool_called = True
|
||||
if not res["finished"]:
|
||||
if "reasoning_delta_message" in output:
|
||||
delta_message = output["reasoning_delta_message"]
|
||||
elif "tool_delta_message" in output:
|
||||
delta_message = output["tool_delta_message"]
|
||||
if delta_message is not None and delta_message.tool_calls:
|
||||
tool_called = True
|
||||
else:
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
else:
|
||||
delta_message = DeltaMessage(
|
||||
content=delta_text,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=None,
|
||||
tool_calls=None,
|
||||
)
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
|
||||
if delta_message is None:
|
||||
continue
|
||||
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
|
@@ -344,33 +344,46 @@ class OpenAIServingCompletion:
|
||||
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
|
||||
|
||||
output_tokens[idx] += 1
|
||||
if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
|
||||
tool_delta_message = output["tool_delta_message"]
|
||||
if tool_delta_message is None:
|
||||
continue
|
||||
delta_message = CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
|
||||
tool_calls=tool_delta_message.tool_calls,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
arrival_time=arrival_time,
|
||||
logprobs=logprobs_res,
|
||||
)
|
||||
if tool_delta_message.tool_calls:
|
||||
tool_called = True
|
||||
base_kwargs = {
|
||||
"index": idx,
|
||||
"completion_token_ids": output.get("token_ids") if request.return_token_ids else None,
|
||||
"arrival_time": arrival_time,
|
||||
"logprobs": logprobs_res,
|
||||
}
|
||||
delta_message_kwargs = None
|
||||
if not res["finished"]:
|
||||
if "reasoning_delta_message" in output:
|
||||
reasoning_delta_message = output["reasoning_delta_message"]
|
||||
if reasoning_delta_message is not None:
|
||||
delta_message_kwargs = {
|
||||
**base_kwargs,
|
||||
"text": reasoning_delta_message.content or "",
|
||||
"reasoning_content": reasoning_delta_message.reasoning_content,
|
||||
}
|
||||
elif "tool_delta_message" in output:
|
||||
tool_delta_message = output["tool_delta_message"]
|
||||
if tool_delta_message is not None:
|
||||
delta_message_kwargs = {
|
||||
**base_kwargs,
|
||||
"text": tool_delta_message.content or "",
|
||||
"tool_calls": tool_delta_message.tool_calls,
|
||||
}
|
||||
if tool_delta_message.tool_calls:
|
||||
tool_called = True
|
||||
else:
|
||||
delta_message_kwargs = {
|
||||
**base_kwargs,
|
||||
"text": output["text"],
|
||||
}
|
||||
else:
|
||||
delta_message = CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
|
||||
tool_calls=None,
|
||||
raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
arrival_time=arrival_time,
|
||||
logprobs=logprobs_res,
|
||||
)
|
||||
delta_message_kwargs = {
|
||||
**base_kwargs,
|
||||
"text": output["text"],
|
||||
}
|
||||
|
||||
if delta_message_kwargs is None:
|
||||
continue
|
||||
delta_message = CompletionResponseStreamChoice(**delta_message_kwargs)
|
||||
|
||||
choices.append(delta_message)
|
||||
output_tokens[idx] += 1
|
||||
|
@@ -57,6 +57,16 @@ class ErnieX1ToolParser(ToolParser):
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
|
||||
self.buffer: str = "" # buffer for accumulating unprocessed streaming content
|
||||
self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!"
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
@@ -224,6 +234,9 @@ class ErnieX1ToolParser(ToolParser):
|
||||
delta_token_ids: Sequence[int],
|
||||
request: dict,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
return DeltaMessage(content=delta_text)
|
||||
# 忽略空chunk
|
||||
if len(delta_text.strip()) == 0:
|
||||
return None
|
||||
@@ -234,7 +247,7 @@ class ErnieX1ToolParser(ToolParser):
|
||||
self.buffer += delta_text
|
||||
|
||||
# 处理增量中的新tool_call开始
|
||||
if "<tool_call>" in delta_text and "<tool_call>" not in previous_text:
|
||||
if "<tool_call>" in delta_text:
|
||||
self.current_tool_id = (
|
||||
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
|
||||
)
|
||||
@@ -243,8 +256,6 @@ class ErnieX1ToolParser(ToolParser):
|
||||
self.streamed_args_for_tool.append("")
|
||||
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
|
||||
|
||||
# 增量解析逻辑
|
||||
|
||||
# 1. 尝试解析name字段
|
||||
if not self.current_tool_name_sent and '"name"' in self.buffer:
|
||||
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
|
||||
@@ -271,46 +282,71 @@ class ErnieX1ToolParser(ToolParser):
|
||||
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
|
||||
if args_match:
|
||||
args_content = args_match.group(1)
|
||||
# 处理多余的大括号
|
||||
open_braces = args_content.count("{")
|
||||
close_braces = args_content.count("}")
|
||||
if close_braces > open_braces:
|
||||
args_content = args_content[: args_content.rfind("}")]
|
||||
print("args_content:", args_content)
|
||||
try:
|
||||
# 增量解析arguments
|
||||
parsed_args = json.loads(args_content)
|
||||
if isinstance(parsed_args, dict):
|
||||
args_json = json.dumps(parsed_args, ensure_ascii=False)
|
||||
if len(args_json) > len(self.streamed_args_for_tool[self.current_tool_id]):
|
||||
argument_diff = args_json[len(self.streamed_args_for_tool[self.current_tool_id]) :]
|
||||
# 检查是否到达arguments结尾(括号完全匹配)
|
||||
if "}}" in args_content:
|
||||
print("delta_text (partial):", delta_text)
|
||||
# 逐个字符检查括号匹配状态
|
||||
matched_pos = -1
|
||||
for i, ch in enumerate(delta_text):
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
|
||||
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配
|
||||
matched_pos = i
|
||||
break
|
||||
|
||||
if matched_pos >= 0:
|
||||
# 找到匹配点,清理buffer并返回
|
||||
truncate_text = delta_text[: matched_pos + 1]
|
||||
print("truncate_text:", truncate_text)
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=argument_diff).model_dump(
|
||||
function=DeltaFunctionCall(arguments=truncate_text).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
print("delta argument:", delta)
|
||||
# 删除已处理部分
|
||||
processed_pos = args_match.start() + len('"arguments":')
|
||||
self.buffer = (
|
||||
self.buffer[:processed_pos] + self.buffer[processed_pos + len(args_json) :]
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] = args_json
|
||||
self.buffer = self.buffer[args_match.end() :]
|
||||
print(delta)
|
||||
return delta
|
||||
else:
|
||||
# 没有完全匹配,继续累积
|
||||
return None
|
||||
else:
|
||||
# 增量返回当前可解析的部分
|
||||
for ch in delta_text:
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
print("delta argument (partial):", delta)
|
||||
print(
|
||||
f"Current bracket counts - left: {self.bracket_counts['total_l']}, right: {self.bracket_counts['total_r']}"
|
||||
)
|
||||
return delta
|
||||
except Exception as e:
|
||||
data_processor_logger.debug(f"Partial arguments parsing: {str(e)}")
|
||||
|
||||
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
|
||||
return None
|
||||
if "</tool_call>" in self.buffer:
|
||||
end_pos = self.buffer.find("</tool_call>")
|
||||
self.buffer = self.buffer[end_pos + len("</tool_call>") :]
|
||||
|
||||
# 完成当前工具调用处理
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
return delta
|
||||
|
Reference in New Issue
Block a user