From dc600010de3579dd9087ef8b30c92e9fecb197f7 Mon Sep 17 00:00:00 2001 From: zhuzixuan Date: Wed, 24 Sep 2025 17:04:59 +0800 Subject: [PATCH] =?UTF-8?q?[Fix]=20X1=20reasoning=20parser=20=EF=BC=8C=20s?= =?UTF-8?q?kip=20parsing=20of=20\n=20around=20special=20tokens=20(#4241)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tool_parsers/ernie_x1_tool_parser.py | 60 ++-- .../reasoning/ernie_x1_reasoning_parsers.py | 114 ++------ .../tool_parsers/test_ernie_x1_tool_parser.py | 141 ++++++++++ tests/reasoning/test_reasoning_parser.py | 265 ++++++++++++++++++ 4 files changed, 466 insertions(+), 114 deletions(-) create mode 100644 tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py create mode 100644 tests/reasoning/test_reasoning_parser.py diff --git a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py index 9b0c7b9cb..14a784f17 100644 --- a/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py +++ b/fastdeploy/entrypoints/openai/tool_parsers/ernie_x1_tool_parser.py @@ -1,3 +1,4 @@ +""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" @@ -11,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" import json import re @@ -97,29 +99,29 @@ class ErnieX1ToolParser(ToolParser): remaining_text = model_output while True: - # 查找下一个tool_call块 + # Find the next tool_call_pos = remaining_text.find("") if tool_call_pos == -1: break - # 提取tool_call开始位置后的内容 + # Extract content after tool_content_start = tool_call_pos + len("") tool_content_end = remaining_text.find("", tool_content_start) tool_json = "" if tool_content_end == -1: - # 处理未闭合的tool_call块(截断情况) + # Processing unclosed tool_call block (truncated case) tool_json = remaining_text[tool_content_start:].strip() - remaining_text = "" # 没有更多内容需要处理 + remaining_text = "" # No more content to process else: - # 处理完整的tool_call块 + # Processing closed block tool_json = remaining_text[tool_content_start:tool_content_end].strip() remaining_text = remaining_text[tool_content_end + len("") :] if not tool_json: continue - # 处理JSON内容 + # Process tool_json tool_json = tool_json.strip() if not tool_json.startswith("{"): tool_json = "{" + tool_json @@ -127,7 +129,7 @@ class ErnieX1ToolParser(ToolParser): tool_json = tool_json + "}" try: - # 首先尝试标准JSON解析 + # Parsing strategy: First try standard json.loads try: tool_data = json.loads(tool_json) @@ -136,26 +138,26 @@ class ErnieX1ToolParser(ToolParser): { "name": tool_data["name"], "arguments": tool_data["arguments"], - "_is_complete": True, # 明确标记为完整解析 + "_is_complete": True, # Mark as complete } ) continue except json.JSONDecodeError: pass - # 标准解析失败时尝试partial_json_parser + # Try partial_json_parser when standard parsing fails from partial_json_parser.core.options import Allow try: tool_data = {} flags = Allow.ALL & ~Allow.STR - # 解析name字段 + # Parse the name field name_match = re.search(r'"name"\s*:\s*"([^"]*)"', tool_json) if name_match: tool_data["name"] = name_match.group(1) - # 解析arguments字段 + # Parse the arguments field args_match = re.search(r'"arguments"\s*:\s*(\{.*)', tool_json) if args_match: try: @@ -168,7 +170,7 @@ class ErnieX1ToolParser(ToolParser): { "name": tool_data.get("name", ""), "arguments": tool_data.get("arguments", {}), - "_is_partial": True, # 标记为部分解析 + "_is_partial": True, # Mark as partial } ) except Exception as e: @@ -183,18 +185,18 @@ class ErnieX1ToolParser(ToolParser): return ExtractedToolCallInformation(tools_called=False, content=model_output) tool_calls = [] - all_complete = True # 初始设为True,只要有一个不完整就变为False + all_complete = True # Initialize as all complete for tool_call in function_call_arr: - # 记录工具调用解析状态 + # Set flags is_complete = tool_call.get("_is_complete", False) is_partial = tool_call.get("_is_partial", False) - # 只要有一个不完整就认为整体不完整 + # If any tool call is incomplete or partial, mark all_complete as False if not is_complete or is_partial: all_complete = False - # 处理参数序列化 + # Process arguments tool_args = tool_call.get("arguments", {}) if not isinstance(tool_args, dict): tool_args = {} @@ -215,7 +217,7 @@ class ErnieX1ToolParser(ToolParser): ) ) - # 只有当所有工具调用都明确标记为complete时才返回tools_called=True + # Only return tools_called=True if all tool calls are complete return ExtractedToolCallInformation( tools_called=all_complete, tool_calls=tool_calls if tool_calls else None, content="" ) @@ -237,16 +239,16 @@ class ErnieX1ToolParser(ToolParser): if self.tool_call_start_token_id not in current_token_ids: return DeltaMessage(content=delta_text) - # 忽略空chunk + # Skip empty chunks if len(delta_text.strip()) == 0: return None try: delta = None - # 使用buffer累积delta_text内容 + # Use buffer to accumulate delta_text content self.buffer += delta_text - # 处理增量中的新tool_call开始 + # Process the buffer content if "" in delta_text: self.current_tool_id = ( max(self.current_tool_id, 0) if self.current_tool_id == -1 else self.current_tool_id + 1 @@ -256,7 +258,7 @@ class ErnieX1ToolParser(ToolParser): self.streamed_args_for_tool.append("") data_processor_logger.debug(f"New tool call started with ID: {self.current_tool_id}") - # 1. 尝试解析name字段 + # 1. Try to parse the name field if not self.current_tool_name_sent and '"name"' in self.buffer: name_match = re.search(r'"name"\s*:\s*"([^"]*)"', self.buffer) if name_match: @@ -272,19 +274,18 @@ class ErnieX1ToolParser(ToolParser): ) ] ) - # 删除已处理的name部分 + # Delete the processed name part from the buffer self.buffer = self.buffer[name_match.end() :] self.current_tool_name_sent = True return delta - # 2. 尝试解析arguments字段 + # 2. Processing arguments field if '"arguments"' in self.buffer: args_match = re.search(r'"arguments"\s*:\s*(\{.*)', self.buffer) if args_match: args_content = args_match.group(1) try: - # 检查是否到达arguments结尾(括号完全匹配) + # Check if arguments field is complete by bracket matching if "}}" in args_content: - # 逐个字符检查括号匹配状态 matched_pos = -1 for i, ch in enumerate(delta_text): if ch == "{": @@ -292,12 +293,12 @@ class ErnieX1ToolParser(ToolParser): elif ch == "}": self.bracket_counts["total_r"] += 1 - if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: # 括号完全匹配 + if self.bracket_counts["total_l"] == self.bracket_counts["total_r"]: matched_pos = i break if matched_pos >= 0: - # 找到匹配点,清理buffer并返回 + # Clean up bracket counts for next tool call truncate_text = delta_text[: matched_pos + 1] delta = DeltaMessage( tool_calls=[ @@ -312,10 +313,10 @@ class ErnieX1ToolParser(ToolParser): self.buffer = self.buffer[args_match.end() :] return delta else: - # 没有完全匹配,继续累积 + # No complete match yet return None else: - # 增量返回当前可解析的部分 + # Return partial arguments for ch in delta_text: if ch == "{": self.bracket_counts["total_l"] += 1 @@ -337,7 +338,6 @@ class ErnieX1ToolParser(ToolParser): end_pos = self.buffer.find("") self.buffer = self.buffer[end_pos + len("") :] - # 完成当前工具调用处理 self.streamed_args_for_tool.append("") return delta diff --git a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py index c75182b01..3aa23aee9 100644 --- a/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py +++ b/fastdeploy/reasoning/ernie_x1_reasoning_parsers.py @@ -1,36 +1,19 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# from collections.abc import Sequence from typing import Tuple, Union from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager -# -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - @ReasoningParserManager.register_module("ernie_x1") class ErnieX1ReasoningParser(ReasoningParser): """ Reasoning parser for ernie_x1 model with stricter boundary checking. - This implementation follows the user's proposed approach: - 1. For thinking content: waits for \n then checks for tag - 2. For response content: checks for tag first, then waits for \n - 3. Handles newlines in content more precisely + Unified rules: + - Do not strip newline before + - Do not strip newline after + - Do not strip newline before """ def __init__(self, tokenizer): @@ -49,9 +32,6 @@ class ErnieX1ReasoningParser(ReasoningParser): raise RuntimeError("Could not find think end token id in tokenizer vocabulary") self.tool_call_start_token_id = self.vocab.get("") - def is_reasoning_end(self, input_ids: list[int]) -> bool: - return self.tool_call_start_token_id in input_ids - def extract_reasoning_content_streaming( self, previous_text: str, @@ -61,102 +41,68 @@ class ErnieX1ReasoningParser(ReasoningParser): current_token_ids: Sequence[int], delta_token_ids: Sequence[int], ) -> Union[DeltaMessage, None]: - """ - 根据用户需求实现的流式解析方法: - 1. 初始内容都视为思考内容,返回delta_text,"" - 2. 当遇到\n时检查后续是否是 - 3. 如果直接遇到也结束思考 - 4. 思考结束后检查是还是 - 5. 对于内容,处理各种边界条件 - """ + # Ignore the single token if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id: return None - # 思考阶段处理 + + # --- Thinking stage handling --- if not previous_text.endswith(self.think_end_token) and self.think_end_token not in previous_text: - # 如果遇到\n,暂时不返回,等待下一个delta_text - if delta_text == "\n": + # If delta is , stop thinking, do not return + if delta_text.startswith(self.think_end_token): return None - # 如果前一个是\n且当前是,结束思考 - elif previous_text.endswith("\n") and delta_text.startswith(self.think_end_token): - return None - # 如果直接遇到也结束思考 - elif delta_text.startswith(self.think_end_token): - return None - # 否则继续返回思考内容 + # Otherwise, return thinking content (keep \n as-is) return DeltaMessage(reasoning_content=delta_text) - # 思考结束后检查是tool_call还是response + # --- After thinking ends, check tool_call or response --- remaining_text = previous_text + delta_text after_think = remaining_text[remaining_text.find(self.think_end_token) + len(self.think_end_token) :] - after_think = after_think.lstrip("\n") # 跳过think后的换行 + after_think = after_think.lstrip("\n") - # 处理tool_call情况 + # Handle tool_call case: skip it if after_think.startswith(self.tool_call_start_token): return None - # 处理response情况 + # Handle response case if after_think.startswith(self.response_start_token): - # 遇到标签时不立即返回 + # Do not return when tag itself appears if delta_text == self.response_start_token: return None - # 遇到后的换行符也不立即返回 - elif delta_text == "\n" and previous_text.endswith(self.response_start_token): - return None - # 处理回复内容中的换行符 - if delta_text == "\n": - return None - # 如果前一个是\n且当前是,结束回复 - elif previous_text.endswith("\n") and delta_text == self.response_end_token: - return None - # 如果直接遇到也结束回复 + # Do not return itself elif delta_text == self.response_end_token: return None - # 其他情况返回实际内容 + # Otherwise, return response content (keep \n as-is) else: return DeltaMessage(content=delta_text) - # 默认情况不返回内容 + # Default case: return nothing return None def extract_reasoning_content(self, model_output: str, request: ChatCompletionRequest) -> Tuple[str, str]: - """ - Batch version of the enhanced parser. - Modified to preserve newlines in both reasoning and response content, - only removing the single newline before closing tags. - """ reasoning_content = "" response_content = "" think_end_pos = model_output.find(self.think_end_token) if think_end_pos != -1: - # Extract thinking content - only remove the last newline before reasoning_content = model_output[:think_end_pos] - if think_end_pos > 0 and reasoning_content[-1] == "\n": - reasoning_content = reasoning_content[:-1] remaining = model_output[think_end_pos + len(self.think_end_token) :] - # Skip newlines after - remaining = remaining.lstrip("\n") + # find or + response_pos = remaining.find(self.response_start_token) + tool_pos = remaining.find(self.tool_call_start_token) - # Check for response or tool_call - if remaining.startswith(self.response_start_token): - response_pos = len(self.response_start_token) - remaining = remaining[response_pos:].lstrip("\n") - response_end_pos = remaining.find(self.response_end_token) + # first + if response_pos != -1 and (tool_pos == -1 or response_pos < tool_pos): + # The content after the response_start position + remaining_response = remaining[response_pos + len(self.response_start_token) :] + response_end_pos = remaining_response.find(self.response_end_token) if response_end_pos != -1: - # Only strip the last newline before , not all - if response_end_pos > 0 and remaining[response_end_pos - 1] == "\n": - response_content = remaining[: response_end_pos - 1] - else: - response_content = remaining[:response_end_pos] + response_content = remaining_response[:response_end_pos] else: - # If no found, return the rest as response content - response_content = remaining - elif remaining.startswith(self.tool_call_start_token): - pass # No response content + response_content = remaining_response + # The content after the response_start position is tool_call else: - # No thinking content found, return the whole input as reasoning reasoning_content = model_output response_content = "" + return reasoning_content, response_content diff --git a/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py new file mode 100644 index 000000000..e818801d9 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_ernie_x1_tool_parser.py @@ -0,0 +1,141 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest +from unittest.mock import patch + +from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser import ( + ErnieX1ToolParser, +) + + +class DummyTokenizer: + """Dummy tokenizer with minimal vocab for testing""" + + def __init__(self): + self.vocab = {"": 1, "": 2} + + +class TestErnieX1ToolParser(unittest.TestCase): + def setUp(self): + class DummyTokenizer: + def __init__(self): + self.vocab = {"": 1, "": 2} + + def get_vocab(self): + return self.vocab + + self.tokenizer = DummyTokenizer() + self.parser = ErnieX1ToolParser(tokenizer=self.tokenizer) + self.dummy_request = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}]) + + # ---------------- Batch extraction tests ---------------- + + def test_extract_tool_calls_complete(self): + """Test normal extraction of complete tool_call JSON""" + output = '{"name": "get_weather", "arguments": {"location": "北京"}}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertTrue(result.tools_called) + self.assertEqual(result.tool_calls[0].function.name, "get_weather") + + def test_extract_tool_calls_partial_arguments(self): + """Test partial extraction when arguments incomplete""" + output = '{"name": "get_weather", "arguments": {"location": "北"' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertFalse(result.tools_called) + self.assertEqual(result.tool_calls[0].function.name, "get_weather") + + def test_extract_tool_calls_invalid_response_before_toolcall(self): + """Test case where before is invalid""" + output = 'hello{"name": "get_weather", "arguments": {}}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertFalse(result.tools_called) + self.assertIn("", result.content) + + def test_extract_tool_calls_no_toolcall(self): + """Test when no tool_call tags are present""" + output = "no tool call here" + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertFalse(result.tools_called) + + def test_extract_tool_calls_invalid_json(self): + """Test tool_call with badly formatted JSON triggers fallback parser""" + output = '"name": "get_weather", "arguments": {' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertFalse(result.tools_called) + self.assertEqual(result.tool_calls[0].function.name, "get_weather") + + def test_extract_tool_calls_exception(self): + """Force exception to cover error branch""" + with patch( + "fastdeploy.entrypoints.openai.tool_parsers.ernie_x1_tool_parser.json.loads", side_effect=Exception("boom") + ): + output = '{"name": "get_weather", "arguments": {}}' + result = self.parser.extract_tool_calls(output, self.dummy_request) + self.assertFalse(result.tools_called) + + # ---------------- Streaming extraction tests ---------------- + + def test_streaming_no_toolcall(self): + """Streaming extraction returns normal DeltaMessage when no """ + result = self.parser.extract_tool_calls_streaming( + "", "abc", "abc", [], [], [], self.dummy_request.model_dump() + ) + self.assertIsInstance(result, DeltaMessage) + self.assertEqual(result.content, "abc") + + def test_streaming_skip_empty_chunk(self): + """Streaming extraction skips empty chunks""" + result = self.parser.extract_tool_calls_streaming( + "", "", " ", [], [1], [1], self.dummy_request.model_dump() + ) + self.assertIsNone(result) + + def test_streaming_new_toolcall_and_name(self): + """Streaming extraction detects new toolcall and extracts name""" + delta = self.parser.extract_tool_calls_streaming( + "", "", '{"name": "get_weather"', [], [1], [1], self.dummy_request.model_dump() + ) + self.assertIsNotNone(delta) + self.assertEqual(delta.tool_calls[0].function.name, "get_weather") + + def test_streaming_partial_arguments(self): + """Streaming extraction yields partial arguments deltas""" + text = '"arguments": {"location":' + delta = self.parser.extract_tool_calls_streaming( + "", "" + text, text, [], [1], [1], self.dummy_request.model_dump() + ) + self.assertIsInstance(delta, DeltaMessage) + self.assertIn("arguments", delta.tool_calls[0].function.arguments) + + def test_streaming_complete_arguments_and_end(self): + """Streaming extraction completes arguments with brackets matched and closes tool_call""" + text = '"arguments": {"location": "北京"}}' + delta = self.parser.extract_tool_calls_streaming( + "", "" + text, text, [], [1], [1], self.dummy_request.model_dump() + ) + self.assertIsInstance(delta, DeltaMessage) + # Also simulate closing tag + end_delta = self.parser.extract_tool_calls_streaming( + "", "", "", [], [2], [2], self.dummy_request.model_dump() + ) + self.assertIsNotNone(end_delta) + self.assertEqual(end_delta.content, "") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/reasoning/test_reasoning_parser.py b/tests/reasoning/test_reasoning_parser.py new file mode 100644 index 000000000..90a48c899 --- /dev/null +++ b/tests/reasoning/test_reasoning_parser.py @@ -0,0 +1,265 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import unittest + +from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage +from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager +from fastdeploy.reasoning.ernie_x1_reasoning_parsers import ErnieX1ReasoningParser + + +class DummyTokenizer: + """Minimal tokenizer with vocab for testing.""" + + def __init__(self): + self.vocab = { + "": 100, + "": 101, + "": 102, + "": 103, + "": 104, + } + + def get_vocab(self): + """Return vocab dict for testing.""" + return self.vocab + + +class TestReasoningParser(ReasoningParser): + def is_reasoning_end(self, input_ids): + """ + Return True to simulate end of reasoning content. + """ + return True + + def extract_content_ids(self, input_ids): + """ + Return input_ids directly for testing. + """ + return input_ids + + def extract_reasoning_content(self, model_output, request): + """ + Used for testing non-streaming extraction. + """ + return model_output, model_output + + def extract_reasoning_content_streaming( + self, previous_text, current_text, delta_text, previous_token_ids, current_token_ids, delta_token_ids + ): + """ + Return None for streaming extraction; minimal implementation for testing. + """ + return None + + +class TestReasoningParserManager(unittest.TestCase): + """ + Unit tests for ReasoningParserManager functionality. + """ + + def setUp(self): + """ + Save original registry to restore after each test. + """ + self.original_parsers = ReasoningParserManager.reasoning_parsers.copy() + + def tearDown(self): + """ + Restore original registry to avoid test pollution. + """ + ReasoningParserManager.reasoning_parsers = self.original_parsers.copy() + + def test_register_and_get_parser(self): + """ + Test that a parser can be registered and retrieved successfully. + Verifies normal registration and retrieval functionality. + """ + ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser", force=True) + parser_cls = ReasoningParserManager.get_reasoning_parser("test_parser") + self.assertIs(parser_cls, TestReasoningParser) + + def test_register_duplicate_without_force_raises(self): + """ + Test that registering a parser with an existing name without force raises KeyError. + Ensures duplicate registrations are handled correctly. + """ + ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=True) + with self.assertRaises(KeyError): + ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=False) + + def test_register_non_subclass_raises(self): + """ + Test that registering a class not inheriting from ReasoningParser raises TypeError. + Ensures type safety for registered modules. + """ + + class NotParser: + pass + + with self.assertRaises(TypeError): + ReasoningParserManager.register_module(module=NotParser, name="not_parser") + + def test_get_unregistered_parser_raises(self): + """ + Test that retrieving a parser that was not registered raises KeyError. + Ensures get_reasoning_parser handles unknown names correctly. + """ + with self.assertRaises(KeyError): + ReasoningParserManager.get_reasoning_parser("nonexistent_parser") + + +class TestErnieX1ReasoningParser(unittest.TestCase): + def setUp(self): + self.parser = ErnieX1ReasoningParser(DummyTokenizer()) + self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}]) + self.tokenizer = DummyTokenizer() + + # ---- Streaming parsing ---- + def test_streaming_thinking_content(self): + msg = self.parser.extract_reasoning_content_streaming( + previous_text="", + current_text="a", + delta_text="a", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[200], + ) + self.assertEqual(msg.reasoning_content, "a") + + def test_streaming_thinking_newline_preserved(self): + msg = self.parser.extract_reasoning_content_streaming( + previous_text="abc", + current_text="abc\n", + delta_text="\n", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[201], + ) + self.assertEqual(msg.reasoning_content, "\n") + + def test_streaming_thinking_end_tag(self): + msg = self.parser.extract_reasoning_content_streaming( + previous_text="abc", + current_text="abc", + delta_text="", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[self.parser.think_end_token_id], + ) + self.assertIsNone(msg) + + def test_streaming_response_content(self): + msg = self.parser.extract_reasoning_content_streaming( + previous_text="", + current_text="h", + delta_text="h", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[202], + ) + self.assertEqual(msg.content, "h") + + def test_streaming_response_newline_preserved(self): + msg = self.parser.extract_reasoning_content_streaming( + previous_text="hi", + current_text="hi\n", + delta_text="\n", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[203], + ) + self.assertEqual(msg.content, "\n") + + def test_streaming_response_ignore_tags(self): + self.assertIsNone( + self.parser.extract_reasoning_content_streaming( + previous_text="", + current_text="", + delta_text="", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[self.parser.vocab[""]], + ) + ) + + msg = self.parser.extract_reasoning_content_streaming( + previous_text="", + current_text="\n", + delta_text="\n", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[204], + ) + self.assertIsInstance(msg, DeltaMessage) + self.assertEqual(msg.content, "\n") + + self.assertIsNone( + self.parser.extract_reasoning_content_streaming( + previous_text="\n", + current_text="\n", + delta_text="", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[self.parser.vocab[""]], + ) + ) + + def test_streaming_tool_call(self): + msg = self.parser.extract_reasoning_content_streaming( + previous_text="", + current_text="", + delta_text="", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[self.parser.vocab[""]], + ) + self.assertIsNone(msg) + + # ---- Batch parsing ---- + def test_batch_reasoning_and_response(self): + text = "abc\n\nhello\nworld" + reasoning, response = self.parser.extract_reasoning_content(text, self.request) + self.assertEqual(reasoning, "abc\n") + self.assertEqual(response, "hello\nworld") + + def test_batch_reasoning_and_tool_call(self): + text = "abccall_here" + reasoning, response = self.parser.extract_reasoning_content(text, self.request) + self.assertEqual(reasoning, "abc") + self.assertEqual(response, "") + + def test_batch_no_thinking_tag(self): + text = "no_thinking_here" + reasoning, response = self.parser.extract_reasoning_content(text, self.request) + self.assertEqual(reasoning, "no_thinking_here") + self.assertEqual(response, "") + + def test_batch_response_without_end_tag(self): + text = "abcpartial response" + reasoning, response = self.parser.extract_reasoning_content(text, self.request) + self.assertEqual(reasoning, "abc") + self.assertEqual(response, "partial response") + + def test_batch_preserve_all_newlines(self): + text = "abc\n\nline1\nline2\n" + reasoning, response = self.parser.extract_reasoning_content(text, self.request) + self.assertEqual(reasoning, "abc\n") + self.assertEqual(response, "line1\nline2\n") + + +if __name__ == "__main__": + unittest.main()