"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import unittest
from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from fastdeploy.reasoning import ReasoningParser, ReasoningParserManager
from fastdeploy.reasoning.ernie_45_vl_thinking_reasoning_parser import (
Ernie45VLThinkingReasoningParser,
)
from fastdeploy.reasoning.ernie_vl_reasoning_parsers import ErnieVLReasoningParser
from fastdeploy.reasoning.ernie_x1_reasoning_parsers import ErnieX1ReasoningParser
class DummyTokenizer:
    """Minimal tokenizer with vocab for testing.

    Exposes a fixed ``vocab`` dict through ``get_vocab``; the parsers under
    test are constructed with an instance of this class.
    """

    def __init__(self):
        # NOTE(review): every key in this literal is the empty string "", so
        # Python keeps only the last entry and ``self.vocab`` ends up as
        # {"": 105}. The keys look like special-token strings (e.g. tag
        # markers) that may have been stripped from the source -- TODO confirm
        # and restore them before relying on this vocab.
        self.vocab = {
            "": 100,
            "": 101,
            "": 102,
            "": 103,
            "": 104,
            "": 105,
        }

    def get_vocab(self):
        """Return vocab dict for testing."""
        return self.vocab
class MissingTokenTokenizer:
    """Tokenizer stub with a reduced vocab.

    Used to check that ``ErnieX1ReasoningParser`` raises ``RuntimeError`` when
    the token ids it needs are not present in the vocab.
    """

    def __init__(self):
        # NOTE(review): all keys are the empty string "", so the dict collapses
        # to {"": 103}. The original keys look like special-token strings that
        # may have been stripped -- TODO confirm.
        self.vocab = {
            "": 100,
            "": 101,
            "": 102,
            "": 103,
        }

    def get_vocab(self):
        """Return vocab dict for testing."""
        return self.vocab
class TestReasoningParser(ReasoningParser):
    """Minimal concrete ReasoningParser used as a registry fixture.

    This is a parser stub, not a test case. Its ``Test`` name prefix would make
    pytest try to collect it, so ``__test__ = False`` opts it out.
    """

    # pytest collects classes named ``Test*`` unless they set ``__test__ = False``.
    __test__ = False

    def is_reasoning_end(self, input_ids):
        """
        Return True to simulate end of reasoning content.
        """
        return True

    def extract_content_ids(self, input_ids):
        """
        Return input_ids directly for testing.
        """
        return input_ids

    def extract_reasoning_content(self, model_output, request):
        """
        Used for testing non-streaming extraction.

        Returns ``model_output`` as both the reasoning part and the content part.
        """
        return model_output, model_output

    def extract_reasoning_content_streaming(
        self, previous_text, current_text, delta_text, previous_token_ids, current_token_ids, delta_token_ids
    ):
        """
        Return None for streaming extraction; minimal implementation for testing.
        """
        return None
class TestReasoningParserManager(unittest.TestCase):
    """
    Unit tests for ReasoningParserManager functionality.
    """

    def setUp(self):
        """
        Save original registry to restore after each test.
        """
        self.original_parsers = ReasoningParserManager.reasoning_parsers.copy()

    def tearDown(self):
        """
        Restore original registry to avoid test pollution.
        """
        ReasoningParserManager.reasoning_parsers = self.original_parsers.copy()

    def test_register_and_get_parser(self):
        """
        Test that a parser can be registered and retrieved successfully.
        Verifies normal registration and retrieval functionality.
        """
        # Use one name for both calls so the lookup targets the entry just registered.
        parser_name = "test_parser"
        ReasoningParserManager.register_module(module=TestReasoningParser, name=parser_name, force=True)
        parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
        self.assertIs(parser_cls, TestReasoningParser)

    def test_register_duplicate_without_force_raises(self):
        """
        Test that registering a parser with an existing name without force raises KeyError.
        Ensures duplicate registrations are handled correctly.
        """
        ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=True)
        with self.assertRaises(KeyError):
            ReasoningParserManager.register_module(module=TestReasoningParser, name="test_parser2", force=False)

    def test_register_non_subclass_raises(self):
        """
        Test that registering a class not inheriting from ReasoningParser raises TypeError.
        Ensures type safety for registered modules.
        """

        class NotParser:
            pass

        with self.assertRaises(TypeError):
            ReasoningParserManager.register_module(module=NotParser, name="not_parser")

    def test_get_unregistered_parser_raises(self):
        """
        Test that retrieving a parser that was not registered raises KeyError.
        Ensures get_reasoning_parser handles unknown names correctly.
        """
        with self.assertRaises(KeyError):
            ReasoningParserManager.get_reasoning_parser("nonexistent_parser")
class TestErnieX1ReasoningParser(unittest.TestCase):
    """Tests for ErnieX1ReasoningParser streaming and batch extraction.

    NOTE(review): several expectations below cannot be derived from the
    visible text alone. For example, "hellohi" is expected to split into
    reasoning "hello" and content "hi" with no delimiter in the string. The
    special-token text (think/response/tool-call markers) looks like it may
    have been stripped from these literals and from DummyTokenizer's keys.
    TODO: confirm and restore.
    """

    def setUp(self):
        # Build a single tokenizer and hand it to the parser, so that
        # ``self.tokenizer`` is the same instance the parser was built with.
        self.tokenizer = DummyTokenizer()
        self.parser = ErnieX1ReasoningParser(self.tokenizer)
        self.request = ChatCompletionRequest(model="test", messages=[{"role": "user", "content": "test message"}])

    def test_missing_token(self):
        with self.assertRaises(RuntimeError) as context:
            ErnieX1ReasoningParser(MissingTokenTokenizer())
        exception_message = str(context.exception)
        expected_message_part = "ernie x1 reasoning parser could not find the following token ids"
        self.assertIn(expected_message_part, exception_message)

    def test_get_model_status(self):
        model_status = self.parser.get_model_status([88, 99, 104])
        self.assertEqual(model_status, "response_start")

    # ---- Streaming parsing ----
    def test_streaming_thinking_content(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="a",
            delta_text="a",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[200],
            model_status="think_start",
        )
        self.assertEqual(msg.reasoning_content, "a")

    def test_streaming_thinking_newline_preserved(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="abc",
            current_text="abc\n",
            delta_text="\n",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[201],
            model_status="think_start",
        )
        self.assertEqual(msg.reasoning_content, "\n")

    def test_streaming_thinking_end_tag(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="abc",
            current_text="abc",
            delta_text="",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[self.parser.think_end_token_id],
            model_status="think_start",
        )
        self.assertIsNone(msg)

    def test_streaming_response_content(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="h",
            delta_text="h",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[202],
            model_status="think_start",
        )
        self.assertEqual(msg.content, "h")

    def test_streaming_response_newline_preserved(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="hi",
            current_text="hi\n",
            delta_text="\n",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[203],
            model_status="think_start",
        )
        self.assertEqual(msg.content, "\n")

    def test_streaming_response_ignore_tags(self):
        self.assertIsNone(
            self.parser.extract_reasoning_content_streaming(
                previous_text="",
                current_text="",
                delta_text="",
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[self.parser.vocab[""]],
                model_status="think_start",
            )
        )
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="\n",
            delta_text="\n",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[204],
            model_status="think_start",
        )
        self.assertIsInstance(msg, DeltaMessage)
        self.assertEqual(msg.content, "\n")
        self.assertIsNone(
            self.parser.extract_reasoning_content_streaming(
                previous_text="\n",
                current_text="\n",
                delta_text="",
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[self.parser.vocab[""]],
                model_status="think_start",
            )
        )

    def test_extract_reasoning_content_streaming(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="hello",
            current_text="hello",
            delta_text="",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[100, 200],
            model_status="think_start",
        )
        self.assertEqual(msg.content, "")
        self.assertEqual(msg.reasoning_content, "")
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="hello",
            current_text="hellohi",
            delta_text="hi",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[100, 200],
            model_status="think_start",
        )
        self.assertEqual(msg.content, "hi")
        self.assertEqual(msg.reasoning_content, "")
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="hellohi",
            delta_text="hellohi",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[100, 200],
            model_status="think_start",
        )
        self.assertEqual(msg.content, "hi")
        self.assertEqual(msg.reasoning_content, "hello")
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="hello",
            current_text="hellohi",
            delta_text="hi",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[100, 200],
            model_status="think_end",
        )
        self.assertEqual(msg.content, "hi")
        self.assertIsNone(msg.reasoning_content)
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="hello",
            current_text="hellohi",
            delta_text="hi",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[100, 200],
            model_status="response_start",
        )
        self.assertEqual(msg.content, "hi")
        self.assertIsNone(msg.reasoning_content)
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="hellohi",
            current_text="hellohiend",
            delta_text="end",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[100, 200],
            model_status="response_start",
        )
        self.assertIsNone(msg)

    def test_streaming_tool_call(self):
        msg = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="",
            delta_text="",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[self.parser.vocab[""]],
            model_status="think_start",
        )
        self.assertIsNone(msg)

    # ---- Batch parsing ----
    def test_batch_reasoning_and_response(self):
        text = "abc\n\nhello\nworld"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
        self.assertEqual(reasoning, "abc\n")
        self.assertEqual(response, "hello\nworld")

    def test_batch_reasoning_and_tool_call(self):
        text = "abccall_here"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
        self.assertEqual(reasoning, "abc")
        self.assertEqual(response, "")

    def test_batch_no_thinking_tag(self):
        text = "no_thinking_here"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
        self.assertEqual(reasoning, "no_thinking_here")
        self.assertEqual(response, "")

    def test_batch_response_without_end_tag(self):
        text = "abcpartial response"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
        self.assertEqual(reasoning, "abc")
        self.assertEqual(response, "partial response")

    def test_batch_preserve_all_newlines(self):
        text = "abc\n\nline1\nline2\n"
        reasoning, response = self.parser.extract_reasoning_content(text, self.request, "think_start")
        self.assertEqual(reasoning, "abc\n")
        self.assertEqual(response, "line1\nline2\n")

    def test_extract_reasoning_content(self):
        reasoning_content, response_content = self.parser.extract_reasoning_content(
            model_output="hello", request=self.request, model_status="response_start"
        )
        self.assertEqual(reasoning_content, "")
        self.assertEqual(response_content, "hello")
class TestErnie45VLThinkingReasoningParser(unittest.TestCase):
    """Tests for Ernie45VLThinkingReasoningParser streaming and batch extraction."""

    def setUp(self):
        self.tokenizer = DummyTokenizer()
        self.parser = Ernie45VLThinkingReasoningParser(tokenizer=self.tokenizer)
        self.test_request = ChatCompletionRequest(
            model="ernie-test", messages=[{"role": "user", "content": "test prompt"}]
        )
        # Override the parser's mapping so token id 100 maps to "think_start"
        # for every test in this class.
        self.parser.token_status_mapping = {
            100: "think_start",
        }

    def test_streaming_non_reasoning(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="a",
            delta_text="a",
            previous_token_ids=[],
            current_token_ids=[200],
            delta_token_ids=[200],
            model_status="think_start",
        )
        self.assertIsInstance(result, DeltaMessage)
        self.assertEqual(result.reasoning_content, "a")
        self.assertIsNone(result.content)

    def test_streaming_with_reasoning(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="ab",
            current_text="ab",
            delta_text="",
            previous_token_ids=[200, 201],
            current_token_ids=[200, 201, 100],
            delta_token_ids=[100],
            model_status="think_start",
        )
        self.assertIsNone(result)

    def test_streaming_with_reasoning_and_content(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="ab",
            current_text="ab\n\ncd",
            delta_text="\n\ncd",
            previous_token_ids=[200, 201],
            current_token_ids=[200, 201, 100, 300, 400],
            delta_token_ids=[100, 300, 400],
            model_status="think_start",
        )
        self.assertIsInstance(result, DeltaMessage)
        self.assertIsNone(result.reasoning_content)
        self.assertEqual(result.content, "\n\ncd")

    def test_streaming_with_reasoning_new_line(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="abc",
            current_text="abc\n\n",
            delta_text="\n\n",
            previous_token_ids=[200, 201, 202],
            current_token_ids=[200, 201, 202, 100],
            delta_token_ids=[100],
            model_status="think_start",
        )
        self.assertIsNone(result)

    def test_streaming_with_reasoning_and_tool(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="abc",
            current_text="abc\n\n",
            delta_text="\n\n",
            previous_token_ids=[200, 201, 202],
            current_token_ids=[200, 201, 202, 100, 200, 101],
            delta_token_ids=[100, 200, 101],
            model_status="think_start",
        )
        self.assertIsInstance(result, DeltaMessage)
        self.assertIsNone(result.reasoning_content)

    def test_streaming_with_reasoning_and_illegal_tool(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="abc",
            current_text="abc\n\nhello",
            delta_text="\n\nhello",
            previous_token_ids=[200, 201, 202],
            current_token_ids=[200, 201, 202, 100, 200, 101],
            delta_token_ids=[109, 200, 101],
            model_status="think_start",
        )
        self.assertIsInstance(result, DeltaMessage)
        self.assertEqual(result.content, "\n\nhello")

    def test_streaming_with_reasoning_no_tool(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="abc",
            current_text="abchello\nworld",
            delta_text="hello\nworld",
            previous_token_ids=[200, 201, 202],
            current_token_ids=[200, 201, 202, 100, 200, 110],
            delta_token_ids=[100, 200, 110],
            model_status="think_start",
        )
        self.assertIsInstance(result, DeltaMessage)
        self.assertEqual(result.reasoning_content, "hello")
        self.assertEqual(result.content, "\nworld")

    def test_streaming_reasoning_previous_no_tool(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="\nhello",
            delta_text="\nhello",
            previous_token_ids=[100],
            current_token_ids=[100, 110, 111],
            delta_token_ids=[110, 111],
            model_status="think_start",
        )
        self.assertIsInstance(result, DeltaMessage)
        self.assertIsNone(result.reasoning_content)
        self.assertEqual(result.content, "\nhello")

    def test_streaming_no_reasoning_previous_tool(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="hello",
            delta_text="hello",
            previous_token_ids=[101],
            current_token_ids=[101, 110],
            delta_token_ids=[110],
            model_status="think_start",
        )
        self.assertIsInstance(result, DeltaMessage)
        self.assertEqual(result.reasoning_content, "hello")

    def test_think_end_status_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="",
            current_text="hello",
            delta_text="hello",
            previous_token_ids=[101],
            current_token_ids=[101, 110],
            delta_token_ids=[110],
            model_status="think_end",
        )
        self.assertIsNone(result)
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="hello, ",
            current_text="hello, hi",
            delta_text="hi",
            previous_token_ids=[101],
            current_token_ids=[101, 110],
            delta_token_ids=[110],
            model_status="think_end",
        )
        self.assertIsInstance(result, DeltaMessage)
        self.assertEqual(result.content, "hi")

    def test_other_status_streaming(self):
        result = self.parser.extract_reasoning_content_streaming(
            previous_text="hello, ",
            current_text="hello, hi",
            delta_text="hi",
            previous_token_ids=[101],
            current_token_ids=[101, 110],
            delta_token_ids=[110],
            model_status="tool_call_start",
        )
        self.assertIsNone(result)

    def test_batch_no_think_end(self):
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="direct response", request=self.test_request, model_status="think_start"
        )
        self.assertEqual(reasoning, "direct response")
        self.assertEqual(content, "")

    def test_batch_no_think_end_with_tool(self):
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="direct responseabc", request=self.test_request, model_status="think_start"
        )
        self.assertEqual(reasoning, "direct responseabc")
        self.assertEqual(content, "")

    def test_batch_think_end_normal_content(self):
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="reasoning\nresponse", request=self.test_request, model_status="think_start"
        )
        self.assertEqual(reasoning, "reasoning")
        self.assertEqual(content, "\nresponse")

    def test_batch_think_end_with_tool(self):
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="reasoning\ntool params",
            request=self.test_request,
            model_status="think_start",
        )
        self.assertEqual(reasoning, "reasoning")
        self.assertEqual(content, "")

    def test_batch_think_end_with_illegal_tool(self):
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="reasoning\nABC\ntool params",
            request=self.test_request,
            model_status="think_start",
        )
        self.assertEqual(reasoning, "reasoning")
        self.assertEqual(content, "\nABC\ntool params")

    def test_batch_think_end_content_with_newline(self):
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="reasoning\n\n actual response",
            request=self.test_request,
            model_status="think_start",
        )
        self.assertEqual(reasoning, "reasoning")
        self.assertEqual(content, "\n\n actual response")

    def test_think_end_status_non_streaming(self):
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="response", request=self.test_request, model_status="think_end"
        )
        self.assertEqual(reasoning, "")
        self.assertEqual(content, "response")
        # NOTE(review): this call repeats the model_output and model_status of
        # the call above but expects content "" instead of "response". Unless
        # the parser keeps state between calls, one of the two expectations must
        # fail. The literal probably lost special-token text (e.g. tool-call
        # markers) -- TODO confirm and restore.
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="response", request=self.test_request, model_status="think_end"
        )
        self.assertEqual(reasoning, "")
        self.assertEqual(content, "")
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="\n 1response", request=self.test_request, model_status="think_end"
        )
        self.assertEqual(reasoning, "")
        self.assertEqual(content, "\n 1response")

    def test_other_status_non_streaming(self):
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="response", request=self.test_request, model_status="tool_call_start"
        )
        self.assertEqual(reasoning, "")
        self.assertEqual(content, "")
        reasoning, content = self.parser.extract_reasoning_content(
            model_output="response", request=self.test_request, model_status="tool_call_end"
        )
        self.assertEqual(reasoning, "")
        self.assertEqual(content, "")

    def test_find_last_special_token(self):
        result = self.parser.find_last_special_token([100, 110, 120, 130])
        self.assertEqual(result, 100)
        result = self.parser.find_last_special_token([0])
        self.assertEqual(result, -1)

    def test_get_model_status(self):
        result = self.parser.get_model_status([100, 110, 120, 130])
        self.assertEqual(result, "think_start")
        result = self.parser.get_model_status([0])
        self.assertEqual(result, "think_start")
class TestErnieVLReasoningParser(unittest.TestCase):
    """Tests for ErnieVLReasoningParser streaming and non-streaming extraction."""

    def setUp(self):
        self.tokenizer = DummyTokenizer()
        self.parser = ErnieVLReasoningParser(tokenizer=self.tokenizer)
        self.test_request = ChatCompletionRequest(
            model="ernie-test", messages=[{"role": "user", "content": "test prompt"}]
        )

    def _stream(self, previous_ids, current_ids, delta_ids):
        """Run one streaming step over the shared "abc" -> "abcxyz" text delta.

        All streaming tests in this class use the same text triple and the
        "think_start" status; only the token-id lists differ.
        """
        return self.parser.extract_reasoning_content_streaming(
            previous_text="abc",
            current_text="abcxyz",
            delta_text="xyz",
            previous_token_ids=previous_ids,
            current_token_ids=current_ids,
            delta_token_ids=delta_ids,
            model_status="think_start",
        )

    def test_extract_reasoning_content_stream(self):
        delta = self._stream(
            [200, 201, 202],
            [200, 201, 202, 100, 110, 120, 130],
            [100, 110, 120, 130],
        )
        self.assertIsInstance(delta, DeltaMessage)
        self.assertEqual(delta.reasoning_content, "")
        self.assertEqual(delta.content, "xyz")

    def test_extract_reasoning_content_stream_think_in_previous(self):
        delta = self._stream(
            [200, 201, 202, 100],
            [200, 201, 202, 100, 110, 120, 130],
            [110, 120, 130],
        )
        self.assertIsInstance(delta, DeltaMessage)
        self.assertIsNone(delta.reasoning_content)
        self.assertEqual(delta.content, "xyz")

    def test_extract_reasoning_content_stream_no_think_token(self):
        delta = self._stream(
            [200, 201, 202],
            [200, 201, 202, 110, 120, 130],
            [110, 120, 130],
        )
        self.assertIsInstance(delta, DeltaMessage)
        self.assertIsNone(delta.content)
        self.assertEqual(delta.reasoning_content, "xyz")

    def test_extract_reasoning_content(self):
        reasoning_part, content_part = self.parser.extract_reasoning_content(
            model_output="reasoning\nactual response", request=self.test_request, model_status="think_start"
        )
        self.assertEqual(reasoning_part, "reasoning")
        self.assertEqual(content_part, "\nactual response")
if __name__ == "__main__":
    # Allow running this module directly, in addition to discovery by a test runner.
    unittest.main()