mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Feature/online/vs think 20250813 (#3440)
* add stream * fix ernie_vl_reasoning_parsers * fix bug
This commit is contained in:
@@ -252,22 +252,20 @@ class OpenAIServingChat:
|
||||
logprobs_res = self._create_chat_logprobs(
|
||||
output_top_logprobs, request.logprobs, request.top_logprobs
|
||||
)
|
||||
if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
|
||||
tool_delta_message = output["tool_delta_message"]
|
||||
if tool_delta_message is None:
|
||||
continue
|
||||
delta_message = tool_delta_message
|
||||
delta_message.reasoning_content = output.get("reasoning_content")
|
||||
if delta_message.tool_calls:
|
||||
tool_called = True
|
||||
if not res["finished"]:
|
||||
if "reasoning_delta_message" in output:
|
||||
delta_message = output["reasoning_delta_message"]
|
||||
elif "tool_delta_message" in output:
|
||||
delta_message = output["tool_delta_message"]
|
||||
if delta_message is not None and delta_message.tool_calls:
|
||||
tool_called = True
|
||||
else:
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
else:
|
||||
delta_message = DeltaMessage(
|
||||
content=delta_text,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=None,
|
||||
tool_calls=None,
|
||||
)
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
|
||||
if delta_message is None:
|
||||
continue
|
||||
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
|
@@ -344,33 +344,46 @@ class OpenAIServingCompletion:
|
||||
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
|
||||
|
||||
output_tokens[idx] += 1
|
||||
if self.engine_client.data_processor.tool_parser_obj and not res["finished"]:
|
||||
tool_delta_message = output["tool_delta_message"]
|
||||
if tool_delta_message is None:
|
||||
continue
|
||||
delta_message = CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
|
||||
tool_calls=tool_delta_message.tool_calls,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
arrival_time=arrival_time,
|
||||
logprobs=logprobs_res,
|
||||
)
|
||||
if tool_delta_message.tool_calls:
|
||||
tool_called = True
|
||||
base_kwargs = {
|
||||
"index": idx,
|
||||
"completion_token_ids": output.get("token_ids") if request.return_token_ids else None,
|
||||
"arrival_time": arrival_time,
|
||||
"logprobs": logprobs_res,
|
||||
}
|
||||
delta_message_kwargs = None
|
||||
if not res["finished"]:
|
||||
if "reasoning_delta_message" in output:
|
||||
reasoning_delta_message = output["reasoning_delta_message"]
|
||||
if reasoning_delta_message is not None:
|
||||
delta_message_kwargs = {
|
||||
**base_kwargs,
|
||||
"text": reasoning_delta_message.content or "",
|
||||
"reasoning_content": reasoning_delta_message.reasoning_content,
|
||||
}
|
||||
elif "tool_delta_message" in output:
|
||||
tool_delta_message = output["tool_delta_message"]
|
||||
if tool_delta_message is not None:
|
||||
delta_message_kwargs = {
|
||||
**base_kwargs,
|
||||
"text": tool_delta_message.content or "",
|
||||
"tool_calls": tool_delta_message.tool_calls,
|
||||
}
|
||||
if tool_delta_message.tool_calls:
|
||||
tool_called = True
|
||||
else:
|
||||
delta_message_kwargs = {
|
||||
**base_kwargs,
|
||||
"text": output["text"],
|
||||
}
|
||||
else:
|
||||
delta_message = CompletionResponseStreamChoice(
|
||||
index=idx,
|
||||
text=output["text"],
|
||||
prompt_token_ids=None,
|
||||
completion_token_ids=output.get("token_ids") if request.return_token_ids else None,
|
||||
tool_calls=None,
|
||||
raw_prediction=output.get("raw_prediction") if request.return_token_ids else None,
|
||||
reasoning_content=output.get("reasoning_content"),
|
||||
arrival_time=arrival_time,
|
||||
logprobs=logprobs_res,
|
||||
)
|
||||
delta_message_kwargs = {
|
||||
**base_kwargs,
|
||||
"text": output["text"],
|
||||
}
|
||||
|
||||
if delta_message_kwargs is None:
|
||||
continue
|
||||
delta_message = CompletionResponseStreamChoice(**delta_message_kwargs)
|
||||
|
||||
choices.append(delta_message)
|
||||
output_tokens[idx] += 1
|
||||
|
@@ -57,6 +57,16 @@ class ErnieX1ToolParser(ToolParser):
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: list[str] = [] # map what has been streamed for each tool so far to a list
|
||||
self.buffer: str = "" # buffer for accumulating unprocessed streaming content
|
||||
self.bracket_counts: dict = {"total_l": 0, "total_r": 0} # track bracket counts in streamed deltas
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!"
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
@@ -224,6 +234,9 @@ class ErnieX1ToolParser(ToolParser):
|
||||
delta_token_ids: Sequence[int],
|
||||
request: dict,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
return DeltaMessage(content=delta_text)
|
||||
# 忽略空chunk
|
||||
if len(delta_text.strip()) == 0:
|
||||
return None
|
||||
@@ -234,7 +247,7 @@ class ErnieX1ToolParser(ToolParser):
|
||||
self.buffer += delta_text
|
||||
|
||||
# 处理增量中的新tool_call开始
|
||||
if "<tool_call>" in delta_text and "<tool_call>" not in previous_text:
|
||||
if "<tool_call>" in delta_text:
|
||||
self.current_tool_id = (
|
||||
max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1
|
||||
)
|
||||
@@ -243,8 +256,6 @@ class ErnieX1ToolParser(ToolParser):
|
||||
self.streamed_args_for_tool.append("")
|
||||
data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}")
|
||||
|
||||
# 增量解析逻辑
|
||||
|
||||
# 1. 尝试解析name字段
|
||||
if not self.current_tool_name_sent and '"name"' in self.buffer:
|
||||
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer)
|
||||
@@ -271,46 +282,71 @@ class ErnieX1ToolParser(ToolParser):
|
||||
args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer)
|
||||
if args_match:
|
||||
args_content = args_match.group(1)
|
||||
# 处理多余的大括号
|
||||
open_braces = args_content.count("{")
|
||||
close_braces = args_content.count("}")
|
||||
if close_braces > open_braces:
|
||||
args_content = args_content[: args_content.rfind("}")]
|
||||
print("args_content:", args_content)
|
||||
try:
|
||||
# 增量解析arguments
|
||||
parsed_args = json.loads(args_content)
|
||||
if isinstance(parsed_args, dict):
|
||||
args_json = json.dumps(parsed_args, ensure_ascii=False)
|
||||
if len(args_json) > len(self.streamed_args_for_tool[self.current_tool_id]):
|
||||
argument_diff = args_json[len(self.streamed_args_for_tool[self.current_tool_id]) :]
|
||||
# 检查是否到达arguments结尾(括号完全匹配)
|
||||
if "}}" in args_content:
|
||||
print("delta_text (partial):", delta_text)
|
||||
# 逐个字符检查括号匹配状态
|
||||
matched_pos = -1
|
||||
for i, ch in enumerate(delta_text):
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
|
||||
if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配
|
||||
matched_pos = i
|
||||
break
|
||||
|
||||
if matched_pos >= 0:
|
||||
# 找到匹配点,清理buffer并返回
|
||||
truncate_text = delta_text[: matched_pos + 1]
|
||||
print("truncate_text:", truncate_text)
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=argument_diff).model_dump(
|
||||
function=DeltaFunctionCall(arguments=truncate_text).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
print("delta argument:", delta)
|
||||
# 删除已处理部分
|
||||
processed_pos = args_match.start() + len('"arguments":')
|
||||
self.buffer = (
|
||||
self.buffer[:processed_pos] + self.buffer[processed_pos + len(args_json) :]
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] = args_json
|
||||
self.buffer = self.buffer[args_match.end() :]
|
||||
print(delta)
|
||||
return delta
|
||||
else:
|
||||
# 没有完全匹配,继续累积
|
||||
return None
|
||||
else:
|
||||
# 增量返回当前可解析的部分
|
||||
for ch in delta_text:
|
||||
if ch == "{":
|
||||
self.bracket_counts["total_l"] += 1
|
||||
elif ch == "}":
|
||||
self.bracket_counts["total_r"] += 1
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
print("delta argument (partial):", delta)
|
||||
print(
|
||||
f"Current bracket counts - left: {self.bracket_counts['total_l']}, right: {self.bracket_counts['total_r']}"
|
||||
)
|
||||
return delta
|
||||
except Exception as e:
|
||||
data_processor_logger.debug(f"Partial arguments parsing: {str(e)}")
|
||||
|
||||
data_processor_logger.error(f"Error in streaming tool call extraction: {str(e)}")
|
||||
return None
|
||||
if "</tool_call>" in self.buffer:
|
||||
end_pos = self.buffer.find("</tool_call>")
|
||||
self.buffer = self.buffer[end_pos + len("</tool_call>") :]
|
||||
|
||||
# 完成当前工具调用处理
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
return delta
|
||||
|
@@ -50,8 +50,9 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
self._init_config()
|
||||
|
||||
self.decode_status = dict()
|
||||
self.tool_parsers = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.thinking_parser_dict = dict()
|
||||
self.reasoning_end_dict = dict()
|
||||
self._load_tokenizer()
|
||||
data_processor_logger.info(
|
||||
f"tokenizer information: bos_token is {self.tokenizer.bos_token} \
|
||||
@@ -291,10 +292,12 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
response_dict["outputs"]["raw_prediction"] = delta_text
|
||||
if self.reasoning_parser and (
|
||||
enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser"
|
||||
if (
|
||||
self.reasoning_parser
|
||||
and req_id not in self.reasoning_end_dict
|
||||
and (enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser")
|
||||
):
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
@@ -302,14 +305,13 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
if self.tool_parser_obj:
|
||||
if req_id not in self.tool_parsers:
|
||||
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parsers[req_id]
|
||||
response_dict["outputs"]["reasoning_delta_message"] = reasoning_delta_message
|
||||
if self.reasoning_parser.is_reasoning_end(previous_token_ids + token_ids):
|
||||
self.reasoning_end_dict[req_id] = True
|
||||
if self.tool_parser_obj and req_id in self.reasoning_end_dict:
|
||||
if req_id not in self.tool_parser_dict:
|
||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parser_dict[req_id]
|
||||
tool_call = tool_parser.extract_tool_calls_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
@@ -320,11 +322,14 @@ class ErnieProcessor(BaseDataProcessor):
|
||||
response_dict,
|
||||
)
|
||||
response_dict["outputs"]["tool_delta_message"] = tool_call
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
if is_end:
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.tool_parsers:
|
||||
del self.tool_parsers[req_id]
|
||||
if req_id in self.tool_parser_dict:
|
||||
del self.tool_parser_dict[req_id]
|
||||
if req_id in self.reasoning_end_dict:
|
||||
del self.reasoning_end_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def messages2ids(self, request_or_messages):
|
||||
|
@@ -54,8 +54,9 @@ class ErnieMoEVLProcessor(ErnieProcessor):
|
||||
self.image_patch_id = self.ernie_processor.image_patch_id
|
||||
self.spatial_conv_size = self.ernie_processor.spatial_conv_size
|
||||
|
||||
self.tool_parsers = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.decode_status = dict()
|
||||
self.reasoning_end_dict = dict()
|
||||
self._load_tokenizer()
|
||||
self.eos_token_ids = [self.tokenizer.eos_token_id]
|
||||
self.eos_token_id_len = len(self.eos_token_ids)
|
||||
|
@@ -168,7 +168,8 @@ class DataProcessor(BaseDataProcessor):
|
||||
self._init_config()
|
||||
|
||||
self.decode_status = dict()
|
||||
self.tool_parsers = dict()
|
||||
self.tool_parser_dict = dict()
|
||||
self.reasoning_end_dict = dict()
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
data_processor_logger.info(
|
||||
f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
|
||||
@@ -398,8 +399,12 @@ class DataProcessor(BaseDataProcessor):
|
||||
token_ids = token_ids[:-1]
|
||||
delta_text, previous_token_ids, previous_texts = self.ids2tokens(token_ids, req_id)
|
||||
response_dict["outputs"]["raw_prediction"] = delta_text
|
||||
if enable_thinking and self.reasoning_parser:
|
||||
reasoning_content, text = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
if (
|
||||
self.reasoning_parser
|
||||
and req_id not in self.reasoning_end_dict
|
||||
and (enable_thinking or self.reasoning_parser.__class__.__name__ == "ErnieX1ReasoningParser")
|
||||
):
|
||||
reasoning_delta_message = self.reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
delta_text,
|
||||
@@ -407,14 +412,13 @@ class DataProcessor(BaseDataProcessor):
|
||||
previous_token_ids + token_ids,
|
||||
token_ids,
|
||||
)
|
||||
response_dict["outputs"]["text"] = text
|
||||
response_dict["outputs"]["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
if self.tool_parser_obj and not is_end:
|
||||
if req_id not in self.tool_parsers:
|
||||
self.tool_parsers[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parsers[req_id]
|
||||
response_dict["outputs"]["reasoning_delta_message"] = reasoning_delta_message
|
||||
if self.reasoning_parser.is_reasoning_end(previous_token_ids + token_ids):
|
||||
self.reasoning_end_dict[req_id] = True
|
||||
if self.tool_parser_obj and req_id in self.reasoning_end_dict:
|
||||
if req_id not in self.tool_parser_dict:
|
||||
self.tool_parser_dict[req_id] = self.tool_parser_obj(self.tokenizer)
|
||||
tool_parser = self.tool_parser_dict[req_id]
|
||||
tool_call = tool_parser.extract_tool_calls_streaming(
|
||||
previous_texts,
|
||||
previous_texts + delta_text,
|
||||
@@ -425,11 +429,14 @@ class DataProcessor(BaseDataProcessor):
|
||||
response_dict,
|
||||
)
|
||||
response_dict["outputs"]["tool_delta_message"] = tool_call
|
||||
response_dict["outputs"]["text"] = delta_text
|
||||
if is_end:
|
||||
data_processor_logger.info(f"req_id:{req_id}, decode_status: {self.decode_status[req_id]}")
|
||||
del self.decode_status[req_id]
|
||||
if req_id in self.tool_parsers:
|
||||
del self.tool_parsers[req_id]
|
||||
if req_id in self.tool_parser_dict:
|
||||
del self.tool_parser_dict[req_id]
|
||||
if req_id in self.reasoning_end_dict:
|
||||
del self.reasoning_end_dict[req_id]
|
||||
return response_dict
|
||||
|
||||
def process_response_dict(self, response_dict, **kwargs):
|
||||
|
@@ -65,18 +65,16 @@ class ErnieVLReasoningParser(ReasoningParser):
|
||||
"""
|
||||
# Skip single special tokens
|
||||
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
||||
return "", ""
|
||||
return None
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
end_index = delta_text.find(self.end_token)
|
||||
reasoning_content = delta_text[:end_index]
|
||||
content = delta_text[end_index + len(self.end_token) :]
|
||||
content = delta_text[end_index + len(self.end_token)]
|
||||
return DeltaMessage(reasoning_content=reasoning_content, content=content)
|
||||
elif self.think_end_token_id in previous_token_ids:
|
||||
reasoning_content = ""
|
||||
content = delta_text
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
reasoning_content = delta_text
|
||||
content = ""
|
||||
return reasoning_content, content
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
|
@@ -2,9 +2,9 @@
|
||||
#
|
||||
#
|
||||
from collections.abc import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Union
|
||||
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
|
||||
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
#
|
||||
@@ -47,6 +47,10 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
||||
self.think_end_token_id = self.vocab.get("</think>")
|
||||
if self.think_end_token_id is None:
|
||||
raise RuntimeError("Could not find think end token id in tokenizer vocabulary")
|
||||
self.tool_call_start_token_id = self.vocab.get("<tool_call>")
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.tool_call_start_token_id in input_ids
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
@@ -56,50 +60,63 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> tuple[str, str]:
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
根据用户需求实现的流式解析方法:
|
||||
1. 初始内容都视为思考内容
|
||||
1. 初始内容都视为思考内容,返回delta_text,""
|
||||
2. 当遇到\n时检查后续是否是</think>
|
||||
3. 思考结束后检查是<response>还是<tool_call>
|
||||
4. 对于<response>内容,处理换行和结束标记
|
||||
3. 如果直接遇到</think>也结束思考
|
||||
4. 思考结束后检查是<response>还是<tool_call>
|
||||
5. 对于<response>内容,处理各种边界条件
|
||||
"""
|
||||
# 如果还在思考阶段
|
||||
if not previous_text.endswith(self.think_end_token):
|
||||
# 如果遇到\n后接</think>或直接遇到</think>,思考结束
|
||||
if (previous_text.endswith("\n") and delta_text == self.think_end_token) or (
|
||||
not previous_text.endswith("\n") and delta_text == self.think_end_token
|
||||
):
|
||||
return "", ""
|
||||
if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
|
||||
return None
|
||||
# 思考阶段处理
|
||||
if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text:
|
||||
# 如果遇到\n,暂时不返回,等待下一个delta_text
|
||||
if delta_text == "\n":
|
||||
return None
|
||||
# 如果前一个是\n且当前是</think>,结束思考
|
||||
elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token):
|
||||
return None
|
||||
# 如果直接遇到</think>也结束思考
|
||||
elif delta_text.startswith(self.think_end_token):
|
||||
return None
|
||||
# 否则继续返回思考内容
|
||||
return delta_text, ""
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
# 思考结束后检查是tool_call还是response
|
||||
remaining_text = previous_text + delta_text
|
||||
after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :]
|
||||
|
||||
# 跳过think后的换行
|
||||
after_think = after_think.lstrip("\n")
|
||||
after_think = after_think.lstrip("\n") # 跳过think后的换行
|
||||
|
||||
# 处理tool_call情况
|
||||
if after_think.startswith(self.tool_call_start_token):
|
||||
return "", ""
|
||||
return None
|
||||
|
||||
# 处理response情况
|
||||
if after_think.startswith(self.response_start_token):
|
||||
response_content = after_think[len(self.response_start_token) :]
|
||||
# 跳过response后的换行
|
||||
response_content = response_content.lstrip("\n")
|
||||
|
||||
# 检查response是否结束
|
||||
if response_content.endswith(self.response_end_token):
|
||||
return "", ""
|
||||
|
||||
# 返回response内容(使用delta_text确保流式输出)
|
||||
return "", delta_text
|
||||
# 遇到<response>标签时不立即返回
|
||||
if delta_text == self.response_start_token:
|
||||
return None
|
||||
# 遇到<response>后的换行符也不立即返回
|
||||
elif delta_text == "\n" and previous_text.endswith(self.response_start_token):
|
||||
return None
|
||||
# 处理回复内容中的换行符
|
||||
if delta_text == "\n":
|
||||
return None
|
||||
# 如果前一个是\n且当前是</response>,结束回复
|
||||
elif previous_text.endswith("\n") and delta_text == self.response_end_token:
|
||||
return None
|
||||
# 如果直接遇到</response>也结束回复
|
||||
elif delta_text == self.response_end_token:
|
||||
return None
|
||||
# 其他情况返回实际内容
|
||||
else:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
# 默认情况不返回内容
|
||||
return "", ""
|
||||
return None
|
||||
|
||||
def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]:
|
||||
"""
|
||||
@@ -143,66 +160,3 @@ class ErnieX1ReasoningParser(ReasoningParser):
|
||||
reasoning_content = model_output
|
||||
response_content = ""
|
||||
return reasoning_content, response_content
|
||||
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class TestErnieX1ReasoningParser(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.tokenizer = MagicMock()
|
||||
self.tokenizer.vocab = {
|
||||
"\n</think>\n\n": 1001,
|
||||
"<response>\n": 1002,
|
||||
"\n</response>\n": 1003,
|
||||
"<tool_call>\n": 1004,
|
||||
"\n</tool_call>\n": 1005,
|
||||
}
|
||||
self.parser = ErnieX1ReasoningParser(self.tokenizer)
|
||||
|
||||
def test_streaming_with_think_and_response(self):
|
||||
# 测试标准情况:\n</think>\n\n<response>\ncontent\n</response>\n
|
||||
prev_text = "thinking"
|
||||
delta_text = "\n</think>\n\n<response>\nanswer\n</response>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
|
||||
self.assertEqual(result, ("thinking", "answer"))
|
||||
|
||||
def test_streaming_with_think_and_tool_call(self):
|
||||
# 测试tool_call情况
|
||||
prev_text = "thinking"
|
||||
delta_text = "\n</think>\n\n<tool_call>\ndetails\n</tool_call>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
|
||||
self.assertEqual(result, ("thinking", ""))
|
||||
|
||||
def test_streaming_with_think_no_newline(self):
|
||||
# 测试没有前置换行的情况
|
||||
prev_text = "thinking"
|
||||
delta_text = "</think>\n\n<response>answer</response>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [], [], [])
|
||||
self.assertEqual(result, ("thinking", "answer"))
|
||||
|
||||
def test_streaming_response_without_leading_newline(self):
|
||||
# 测试response内容没有前置换行
|
||||
prev_text = "thinking\n</think>\n\n"
|
||||
delta_text = "<response>answer\n</response>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
|
||||
self.assertEqual(result, ("thinking", "answer"))
|
||||
|
||||
def test_streaming_response_with_middle_newline(self):
|
||||
# 测试response内容中间的换行符
|
||||
prev_text = "thinking\n</think>\n\n<response>\n"
|
||||
delta_text = "line1\nline2\n</response>\n"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
|
||||
self.assertEqual(result, ("thinking", "line1\nline2"))
|
||||
|
||||
def test_streaming_partial_response(self):
|
||||
# 测试不完整的response流式输出
|
||||
prev_text = "thinking\n</think>\n\n<response>\n"
|
||||
delta_text = "partial answer"
|
||||
result = self.parser.extract_reasoning_content_streaming(prev_text, "", delta_text, [1001], [], [])
|
||||
self.assertEqual(result, ("thinking", "partial answer"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user