# Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00).
# 356 lines, 14 KiB, Python.
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

"""
Unit tests for the GuidedDecoding class.
"""

import sys
|
||
import unittest
|
||
from concurrent.futures import Future
|
||
from unittest.mock import MagicMock, Mock, patch
|
||
|
||
import paddle
|
||
|
||
# Register stand-in modules for torch and xgrammar in sys.modules. This must happen
# *before* the fastdeploy imports below, presumably so those modules can be imported
# without the real packages being available.
# NOTE(review): these entries are installed process-wide and are never removed.
mock_torch, mock_xgrammar = MagicMock(), MagicMock()
sys.modules.update(torch=mock_torch, xgrammar=mock_xgrammar)

from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
|
||
from fastdeploy.model_executor.layers.sample.sampler import GuidedDecoding
|
||
from fastdeploy.reasoning import ReasoningParser
|
||
|
||
|
||
class TestGuidedDecoding(unittest.TestCase):
    """Test cases for GuidedDecoding class.

    Every test gets a fresh ``GuidedDecoding`` built from a mocked config with
    ``max_num_seqs = 5``, plus a ``Mock(spec=LogitsProcessorBase)`` standing in for a
    grammar processor. Pending processors are modelled with ``Mock(spec=Future)`` (see
    ``_make_future``) so that ``isinstance(obj, Future)`` is True for them, exactly as it
    would be for a real ``concurrent.futures.Future``.
    """

    def setUp(self):
        """Setup for each test case."""
        # Minimal stand-in for FDConfig: only scheduler_config.max_num_seqs is set.
        self.fd_config = Mock()
        self.fd_config.scheduler_config = Mock()
        self.fd_config.scheduler_config.max_num_seqs = 5

        # Object under test.
        self.guided_decoding = GuidedDecoding(self.fd_config)

        # Mocked LogitsProcessorBase: not terminated, reasoning already ended,
        # reasoning mode disabled.
        self.mock_processor = Mock(spec=LogitsProcessorBase)
        self.mock_processor.is_terminated = False
        self.mock_processor.reasoning_ended = True
        self.mock_processor.enable_reasoning = False

        # allocate_token_bitmask returns a dummy int32 bitmask.
        self.mock_processor.allocate_token_bitmask.return_value = paddle.zeros([5, 10], dtype="int32")

        # fill_token_bitmask is a no-op.
        self.mock_processor.fill_token_bitmask.return_value = None

        # accept_token reports success by default.
        self.mock_processor.accept_token.return_value = True

    @staticmethod
    def _make_future(done=None, result=None):
        """Return a ``Mock(spec=Future)``, optionally configuring ``done()`` / ``result()``.

        ``spec=Future`` makes ``isinstance(mock, Future)`` True. A plain ``Mock()`` is
        NOT treated as a Future by isinstance-based checks.

        Args:
            done: If not None, the value returned by ``future.done()``.
            result: If not None, the value returned by ``future.result()``.
        """
        future = Mock(spec=Future)
        if done is not None:
            future.done.return_value = done
        if result is not None:
            future.result.return_value = result
        return future

    def test_init(self):
        """Test initialization."""
        self.assertIsNone(self.guided_decoding.token_bitmask)
        self.assertEqual(len(self.guided_decoding.logits_processors), 5)
        self.assertIsNone(self.guided_decoding.reasoning_parser)
        self.assertEqual(len(self.guided_decoding._prefill_done_idxs), 5)
        self.assertEqual(len(self.guided_decoding._tokens_to_acc), 5)

    def test_apply_reasoning_parser(self):
        """Test apply_reasoning_parser method."""
        mock_parser = Mock(spec=ReasoningParser)
        self.guided_decoding.apply_reasoning_parser(mock_parser)
        self.assertIs(self.guided_decoding.reasoning_parser, mock_parser)

    def test_add_logits_processor_no_future(self):
        """Test add_logits_processor method without future."""
        self.guided_decoding.add_logits_processor(0, None, [])
        self.assertFalse(self.guided_decoding._prefill_done_idxs[0])
        self.assertIsNone(self.guided_decoding.logits_processors[0])

    def test_add_logits_processor_with_prefill_tokens(self):
        """Test add_logits_processor method with prefill tokens."""
        # An already-completed future that resolves to the mocked processor.
        mock_future = self._make_future(done=True, result=self.mock_processor)

        prefill_tokens = [1, 2, 3]
        self.guided_decoding.add_logits_processor(0, mock_future, prefill_tokens)

        self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
        self.assertIs(self.guided_decoding.logits_processors[0], self.mock_processor)
        for token in prefill_tokens:
            self.mock_processor.accept_token.assert_any_call(token)

    def test_add_logits_processor_with_async_future(self):
        """Test add_logits_processor method with async future."""
        # A future that has not completed yet.
        mock_future = self._make_future(done=False)

        prefill_tokens = [1, 2, 3]
        self.guided_decoding.add_logits_processor(0, mock_future, prefill_tokens)

        self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
        # The pending future itself is stored, and the prefill tokens are queued.
        self.assertIs(self.guided_decoding.logits_processors[0], mock_future)
        self.assertEqual(self.guided_decoding._tokens_to_acc[0], prefill_tokens)

    def test_should_fill_bitmask_no_reasoning_parser(self):
        """Test should_fill_bitmask method with no reasoning parser."""
        self.guided_decoding.logits_processors[0] = self.mock_processor
        self.assertTrue(self.guided_decoding.should_fill_bitmask(0))

    def test_should_fill_bitmask_with_reasoning_parser(self):
        """Test should_fill_bitmask method with reasoning parser."""
        mock_parser = Mock(spec=ReasoningParser)
        self.guided_decoding.reasoning_parser = mock_parser

        # Case: enable_reasoning=True.
        self.mock_processor.enable_reasoning = True
        self.guided_decoding.logits_processors[0] = self.mock_processor
        self.assertTrue(self.guided_decoding.should_fill_bitmask(0))

        # Case: enable_reasoning=False, reasoning_ended=False.
        self.mock_processor.enable_reasoning = False
        self.mock_processor.reasoning_ended = False
        self.assertFalse(self.guided_decoding.should_fill_bitmask(0))

        # Case: enable_reasoning=False, reasoning_ended=True.
        self.mock_processor.reasoning_ended = True
        self.assertTrue(self.guided_decoding.should_fill_bitmask(0))

    def test_reset_processor(self):
        """Test reset_processor method."""
        self.guided_decoding.logits_processors[0] = self.mock_processor
        self.guided_decoding._prefill_done_idxs[0] = True

        self.guided_decoding.reset_processor(0)

        self.assertFalse(self.guided_decoding._prefill_done_idxs[0])
        self.assertIsNone(self.guided_decoding.logits_processors[0])

    def test_update_vocab_mask_with_new_prefill_done(self):
        """Test update_vocab_mask method with new prefill_done_idxs."""
        # Install a processor at index 0 whose prefill has not finished yet.
        self.guided_decoding.logits_processors[0] = self.mock_processor
        self.guided_decoding._prefill_done_idxs[0] = False

        # Report index 0 as newly prefill-done.
        self.guided_decoding.update_vocab_mask([0])

        # _prefill_done_idxs[0] is now set ...
        self.assertTrue(self.guided_decoding._prefill_done_idxs[0])

        # ... and the bitmask was filled for it.
        self.mock_processor.fill_token_bitmask.assert_called_once()

    def test_update_vocab_mask_with_future_processor(self):
        """Test update_vocab_mask method with future processor."""
        # Must be spec'd as a Future: a plain Mock() fails isinstance(..., Future), so
        # it would be treated as a regular processor and this test would not exercise
        # the Future path at all.
        mock_future = self._make_future()

        # Index 0 holds a still-pending Future.
        self.guided_decoding.logits_processors[0] = mock_future
        self.guided_decoding._prefill_done_idxs[0] = True

        self.guided_decoding.update_vocab_mask([])

        # The pending Future is left in place and is not joined here.
        self.assertIs(self.guided_decoding.logits_processors[0], mock_future)
        mock_future.result.assert_not_called()
        # No real processor was involved, so no bitmask fill happened.
        self.mock_processor.fill_token_bitmask.assert_not_called()

    def test_accept_tokens_from_prefill_node(self):
        """Test accept_tokens_from_prefill_node method."""
        # Processor at index 0 with three queued tokens.
        self.guided_decoding.logits_processors[0] = self.mock_processor
        self.guided_decoding._tokens_to_acc[0] = [1, 2, 3]

        self.guided_decoding.accept_tokens_from_prefill_node(0)

        # Each queued token was accepted exactly once ...
        self.assertEqual(self.mock_processor.accept_token.call_count, 3)
        for token in (1, 2, 3):
            self.mock_processor.accept_token.assert_any_call(token)

        # ... and the queue for index 0 was cleared.
        self.assertIsNone(self.guided_decoding._tokens_to_acc[0])

    @patch("fastdeploy.model_executor.guided_decoding.xgrammar_backend.apply_token_mask")
    def test_apply_token_mask(self, mock_apply_token_mask):
        """Test apply_token_mask method."""
        logits = paddle.zeros([5, 10], dtype="float32")
        mock_apply_token_mask.return_value = paddle.ones([5, 10], dtype="float32")

        # Index 0 holds an already-resolved (non-Future) processor.
        self.guided_decoding.logits_processors[0] = self.mock_processor
        self.guided_decoding._prefill_done_idxs[0] = True

        result = self.guided_decoding.apply_token_mask(logits, [])

        # Expect: no bitmask fill from apply_token_mask for a non-Future processor.
        self.mock_processor.fill_token_bitmask.assert_not_called()

        # The backend mask function is invoked once ...
        mock_apply_token_mask.assert_called_once()

        # ... and its output is returned.
        self.assertTrue((result == paddle.ones([5, 10], dtype="float32")).all())

    def test_apply_token_mask_with_future_processor(self):
        """Test apply_token_mask method with future processor."""
        logits = paddle.zeros([5, 10], dtype="float32")

        # A completed Future that resolves to the mocked processor.
        mock_future = self._make_future(done=True, result=self.mock_processor)

        # Index 0 holds the Future, with three prefill tokens queued.
        self.guided_decoding.logits_processors[0] = mock_future

        self.guided_decoding._prefill_done_idxs[0] = True
        self.assertTrue(self.guided_decoding._prefill_done_idxs[0])
        self.assertIsNotNone(self.guided_decoding.logits_processors[0])
        self.assertIsInstance(self.guided_decoding.logits_processors[0], Future)
        self.guided_decoding._tokens_to_acc[0] = [1, 2, 3]

        with patch(
            "fastdeploy.model_executor.guided_decoding.xgrammar_backend.apply_token_mask"
        ) as mock_apply_token_mask:
            mock_apply_token_mask.return_value = paddle.ones([5, 10], dtype="float32")

            self.guided_decoding.apply_token_mask(logits, [])

            # The Future is joined exactly once ...
            mock_future.result.assert_called_once()

            # ... the three queued tokens are accepted ...
            self.assertEqual(self.mock_processor.accept_token.call_count, 3)

            # ... and the queue for index 0 is cleared.
            self.assertIsNone(self.guided_decoding._tokens_to_acc[0])

    def test_accept_token(self):
        """Test _accept_token method."""
        self.guided_decoding.logits_processors[0] = self.mock_processor

        self.guided_decoding._accept_token(0, 1)

        self.mock_processor.accept_token.assert_called_once_with(1)

    def test_accept_token_with_reasoning_parser(self):
        """Test _accept_token method with reasoning parser."""
        # The parser reports that reasoning ends at this token.
        mock_parser = Mock(spec=ReasoningParser)
        mock_parser.is_reasoning_end.return_value = True
        self.guided_decoding.reasoning_parser = mock_parser

        # Processor still inside its reasoning phase.
        self.mock_processor.enable_reasoning = False
        self.mock_processor.reasoning_ended = False
        self.guided_decoding.logits_processors[0] = self.mock_processor

        self.guided_decoding._accept_token(0, 1)

        # The parser is consulted with the single token ...
        mock_parser.is_reasoning_end.assert_called_once_with([1])

        # ... reasoning_ended is flipped ...
        self.assertTrue(self.mock_processor.reasoning_ended)

        # ... and the token itself is not fed to the grammar, because
        # reasoning_ended was only just set to True.
        self.mock_processor.accept_token.assert_not_called()

    def test_accept_token_processor_terminated(self):
        """Test _accept_token method when processor is terminated."""
        # accept_token rejecting the token signals termination.
        self.mock_processor.accept_token.return_value = False
        self.guided_decoding.logits_processors[0] = self.mock_processor

        self.guided_decoding._accept_token(0, 1)

        # The processor slot is reset.
        self.assertIsNone(self.guided_decoding.logits_processors[0])

    def test_update_output_tokens(self):
        """Test update_output_tokens method."""
        next_tokens = paddle.to_tensor([[1], [2], [3], [4], [5]])

        # Processors at indices 0 and 1 only.
        self.guided_decoding.logits_processors[0] = self.mock_processor
        self.guided_decoding.logits_processors[1] = self.mock_processor
        self.guided_decoding._prefill_done_idxs[0] = True
        self.guided_decoding._prefill_done_idxs[1] = True

        self.guided_decoding.update_output_tokens(next_tokens)

        # Exactly the two tokens for indices 0 and 1 are accepted.
        self.assertEqual(self.mock_processor.accept_token.call_count, 2)
        self.mock_processor.accept_token.assert_any_call(1)
        self.mock_processor.accept_token.assert_any_call(2)

    def test_update_output_tokens_with_negative_token(self):
        """Test update_output_tokens method with negative token."""
        # Index 0 receives a negative token.
        next_tokens = paddle.to_tensor([[-1], [2]])

        self.guided_decoding.logits_processors[0] = self.mock_processor
        self.guided_decoding.logits_processors[1] = self.mock_processor
        self.guided_decoding._prefill_done_idxs[0] = True
        self.guided_decoding._prefill_done_idxs[1] = True

        self.guided_decoding.update_output_tokens(next_tokens)

        # The processor at index 0 is reset ...
        self.assertIsNone(self.guided_decoding.logits_processors[0])

        # ... and only index 1's token reaches accept_token.
        self.mock_processor.accept_token.assert_called_once_with(2)

    def test_pre_process(self):
        """Test pre_process method."""
        with patch.object(self.guided_decoding, "update_vocab_mask") as mock_update_vocab_mask:
            self.guided_decoding.pre_process([0, 1])

            # pre_process forwards the index list to update_vocab_mask.
            mock_update_vocab_mask.assert_called_once_with([0, 1])


if __name__ == "__main__":
    # Run every TestCase defined in this module when executed directly as a script;
    # module="__main__" is unittest.main's default, spelled out for clarity.
    unittest.main(module="__main__")