mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] Guided Decoding add LLguidance backend (#5124)
* llguidance * add requirements_guided_decoding.txt and doc * fix test_guidance_*.py * fix test_guidance_*.py && mv * fix llguidance choice * test_guidance_* * rm lazy loader --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
178
tests/model_executor/guided_decoding/test_guidance_backend.py
Normal file
178
tests/model_executor/guided_decoding/test_guidance_backend.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastdeploy.model_executor.guided_decoding import BackendBase
|
||||
|
||||
# Stand-ins for the optional llguidance / torch packages. They are registered
# in sys.modules before the backend module is imported further down.
mock_llguidance = MagicMock()
mock_llguidancehf = MagicMock()
mock_llguidancetorch = MagicMock()
mock_torch = MagicMock()

# Attribute access ``llguidance.hf`` resolves to the same object as the
# "llguidance.hf" entry registered below.
mock_llguidance.hf = mock_llguidancehf

sys.modules.update(
    {
        "llguidance": mock_llguidance,
        "llguidance.hf": mock_llguidancehf,
        "llguidance.torch": mock_llguidancetorch,
        "torch": mock_torch,
    }
)
|
||||
|
||||
# Import the module to be tested
|
||||
from fastdeploy.model_executor.guided_decoding.guidance_backend import (
|
||||
LLGuidanceBackend,
|
||||
LLGuidanceProcessor,
|
||||
process_for_additional_properties,
|
||||
)
|
||||
|
||||
|
||||
class TestProcessForAdditionalProperties(unittest.TestCase):
    """Checks on the schema returned by process_for_additional_properties."""

    def test_process_json_string(self):
        # Schema supplied as a JSON-encoded string.
        processed = process_for_additional_properties(
            '{"type": "object", "properties": {"name": {"type": "string"}}}'
        )
        self.assertFalse(processed["additionalProperties"])

    def test_process_json_dict(self):
        # Schema supplied as a dict.
        schema = {"type": "object", "properties": {"name": {"type": "string"}}}
        processed = process_for_additional_properties(schema)
        self.assertFalse(processed["additionalProperties"])
        # The caller's dict must be left untouched.
        self.assertNotIn("additionalProperties", schema)

    def test_nested_objects(self):
        # An object schema nested inside another object schema.
        person_schema = {"type": "object", "properties": {"name": {"type": "string"}}}
        schema = {"type": "object", "properties": {"person": person_schema}}
        processed = process_for_additional_properties(schema)
        self.assertFalse(processed["additionalProperties"])
        self.assertFalse(processed["properties"]["person"]["additionalProperties"])
|
||||
|
||||
|
||||
@patch("llguidance.LLMatcher")
@patch("llguidance.LLTokenizer")
class TestLLGuidanceProcessor(unittest.TestCase):
    """Unit tests for LLGuidanceProcessor driven by mocked matcher/tokenizer objects.

    The class-level patches hand every ``test_*`` method two extra positional
    arguments: the LLTokenizer mock first, then the LLMatcher mock.
    """

    def setUp(self):
        self.vocab_size, self.batch_size = 100, 2

    def _build_processor(self, matcher, tokenizer):
        """Construct the processor under test around the supplied mocks."""
        return LLGuidanceProcessor(
            ll_matcher=matcher,
            ll_tokenizer=tokenizer,
            serialized_grammar="test_grammar",
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
        )

    def test_initialization(self, mock_tokenizer, mock_matcher):
        # A freshly built processor exposes its sizes and is not terminated.
        processor = self._build_processor(mock_matcher, mock_tokenizer)

        self.assertEqual(processor.vocab_size, self.vocab_size)
        self.assertEqual(processor.batch_size, self.batch_size)
        self.assertFalse(processor.is_terminated)

    def test_reset(self, mock_tokenizer, mock_matcher):
        # reset() must reset the matcher and clear the terminated flag.
        processor = self._build_processor(mock_matcher, mock_tokenizer)
        processor.is_terminated = True

        processor.reset()

        mock_matcher.reset.assert_called_once()
        self.assertFalse(processor.is_terminated)

    def test_accept_token(self, mock_tokenizer, mock_matcher):
        # Configure the mocks before the processor is constructed.
        mock_matcher.is_stopped.return_value = False
        mock_matcher.consume_tokens.return_value = True
        mock_tokenizer.eos_token = 1

        processor = self._build_processor(mock_matcher, mock_tokenizer)

        # Ordinary token: accepted and passed to the matcher as a one-element list.
        self.assertTrue(processor.accept_token(0))
        mock_matcher.consume_tokens.assert_called_with([0])

        # EOS token: accepted and marks the processor as terminated.
        self.assertTrue(processor.accept_token(1))
        self.assertTrue(processor.is_terminated)
|
||||
|
||||
|
||||
@patch("llguidance.LLMatcher")
@patch("llguidance.hf.from_tokenizer")
class TestLLGuidanceBackend(unittest.TestCase):
    """Unit tests for LLGuidanceBackend using a mocked FDConfig and mocked llguidance.

    The class-level patches pass the ``from_tokenizer`` mock and then the
    LLMatcher mock to every ``test_*`` method.
    """

    def setUp(self):
        # Mock FDConfig carrying only the fields these tests set explicitly.
        config = MagicMock()
        config.model_config.vocab_size = 100
        config.scheduler_config.max_num_seqs = 2
        structured = config.structured_outputs_config
        structured.disable_any_whitespace = False
        structured.disable_additional_properties = False
        structured.reasoning_parser = None
        self.fd_config = config

    def test_initialization(self, mock_from_tokenizer, mock_matcher):
        # Build the backend with the HF tokenizer lookup stubbed out.
        hf_tokenizer = MagicMock()
        with patch.object(BackendBase, "_get_tokenizer_hf", return_value=hf_tokenizer):
            backend = LLGuidanceBackend(fd_config=self.fd_config)

        self.assertEqual(backend.vocab_size, 100)
        self.assertEqual(backend.batch_size, 2)
        self.assertTrue(backend.any_whitespace)

    @patch("llguidance.LLMatcher")
    def test_create_processor(self, mock_matcher_class, mock_from_tokenizer, mock_matcher):
        # __init__ is replaced, so the constructor argument is irrelevant and
        # every attribute has to be set by hand.
        with patch.object(LLGuidanceBackend, "__init__", return_value=None):
            backend = LLGuidanceBackend(fd_config=None)

            manual_attributes = {
                "hf_tokenizer": MagicMock(),
                "ll_tokenizer": MagicMock(),
                "vocab_size": 100,
                "batch_size": 2,
                "any_whitespace": True,
                "disable_additional_properties": False,
            }
            for attr_name, attr_value in manual_attributes.items():
                setattr(backend, attr_name, attr_value)

            # Instantiating the patched matcher class yields a fresh mock.
            mock_matcher_class.return_value = MagicMock()

            processor = backend._create_processor("test_grammar")

            self.assertIsInstance(processor, LLGuidanceProcessor)
            self.assertEqual(processor.vocab_size, 100)
            self.assertEqual(processor.batch_size, 2)
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
595
tests/model_executor/guided_decoding/test_guidance_checker.py
Normal file
595
tests/model_executor/guided_decoding/test_guidance_checker.py
Normal file
@@ -0,0 +1,595 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Check if llguidance can be imported
|
||||
# Check if llguidance can be imported; register mocks in sys.modules when it cannot.
try:
    import llguidance  # noqa: F401
except ImportError:
    HAS_LLGUIDANCE = False
    mock_llguidance = MagicMock()
    mock_llguidancehf = MagicMock()
    mock_llguidancetorch = MagicMock()
    mock_torch = MagicMock()
    sys.modules.update(
        {
            "llguidance": mock_llguidance,
            "llguidance.hf": mock_llguidancehf,
            "llguidance.torch": mock_llguidancetorch,
            "torch": mock_torch,
        }
    )
else:
    HAS_LLGUIDANCE = True
|
||||
|
||||
|
||||
@pytest.fixture
def llguidance_checker():
    """Provide a default-constructed LLGuidanceChecker to the requesting test."""
    checker = LLGuidanceChecker()
    return checker
|
||||
|
||||
|
||||
@pytest.fixture
def llguidance_checker_with_options():
    """Provide an LLGuidanceChecker constructed with ``disable_any_whitespace=True``."""
    checker = LLGuidanceChecker(disable_any_whitespace=True)
    return checker
|
||||
|
||||
|
||||
from fastdeploy.model_executor.guided_decoding.guidance_backend import LLGuidanceChecker
|
||||
|
||||
|
||||
def MockRequest():
    """Return a MagicMock request whose guided-decoding fields are all None.

    Tests set the single field they want to exercise on the returned object.
    """
    request = MagicMock()
    guided_fields = (
        "guided_json",
        "guided_json_object",
        "structural_tag",
        "guided_regex",
        "guided_choice",
        "guided_grammar",
    )
    for field_name in guided_fields:
        setattr(request, field_name, None)
    return request
|
||||
|
||||
|
||||
class TestLLGuidanceCheckerMocked:
    """Exercise LLGuidanceChecker with the llguidance entry points patched.

    Suitable for environments where the llguidance library is not installed.
    """

    @patch("llguidance.LLMatcher.grammar_from_json_schema")
    @patch("llguidance.LLMatcher.validate_grammar")
    def test_serialize_guided_json_as_string(self, mock_validate, mock_from_schema, llguidance_checker):
        """guided_json given as a JSON string."""
        mock_validate.return_value = None
        mock_from_schema.return_value = "serialized_grammar"

        req = MockRequest()
        req.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}'

        serialized = llguidance_checker.serialize_guidance_grammar(req)

        mock_from_schema.assert_called_once()
        assert serialized == "serialized_grammar"

    @patch("llguidance.LLMatcher.grammar_from_json_schema")
    @patch("llguidance.LLMatcher.validate_grammar")
    def test_serialize_guided_json_as_dict(self, mock_validate, mock_from_schema, llguidance_checker):
        """guided_json given as a dictionary."""
        mock_validate.return_value = None
        mock_from_schema.return_value = "serialized_grammar"

        req = MockRequest()
        req.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}}

        serialized = llguidance_checker.serialize_guidance_grammar(req)

        mock_from_schema.assert_called_once()
        # The request's guided_json must still be a dict after serialization.
        assert isinstance(req.guided_json, dict)
        assert serialized == "serialized_grammar"

    @patch("llguidance.LLMatcher.grammar_from_json_schema")
    @patch("llguidance.LLMatcher.validate_grammar")
    def test_serialize_guided_json_object(self, mock_validate, mock_from_schema, llguidance_checker):
        """guided_json_object set to True."""
        mock_validate.return_value = None
        mock_from_schema.return_value = "serialized_grammar"

        req = MockRequest()
        req.guided_json_object = True

        serialized = llguidance_checker.serialize_guidance_grammar(req)

        mock_from_schema.assert_called_once()
        assert req.guided_json_object
        assert serialized == "serialized_grammar"

    @patch("llguidance.grammar_from")
    @patch("llguidance.LLMatcher.validate_grammar")
    def test_serialize_guided_regex(self, mock_validate, mock_grammar_from, llguidance_checker):
        """guided_regex is forwarded to grammar_from as a "regex" grammar."""
        mock_validate.return_value = None
        mock_grammar_from.return_value = "serialized_regex_grammar"

        req = MockRequest()
        req.guided_regex = "[a-zA-Z]+"

        serialized = llguidance_checker.serialize_guidance_grammar(req)

        mock_grammar_from.assert_called_once_with("regex", "[a-zA-Z]+")
        assert serialized == "serialized_regex_grammar"

    @patch("llguidance.grammar_from")
    @patch("llguidance.LLMatcher.validate_grammar")
    def test_serialize_guided_choice(self, mock_validate, mock_grammar_from, llguidance_checker):
        """guided_choice is forwarded to grammar_from as a "choice" grammar."""
        mock_validate.return_value = None
        mock_grammar_from.return_value = "serialized_choice_grammar"

        req = MockRequest()
        req.guided_choice = ["option1", "option2"]

        serialized = llguidance_checker.serialize_guidance_grammar(req)

        mock_grammar_from.assert_called_once_with("choice", ["option1", "option2"])
        assert serialized == "serialized_choice_grammar"

    @patch("llguidance.grammar_from")
    @patch("llguidance.LLMatcher.validate_grammar")
    def test_serialize_guided_grammar(self, mock_validate, mock_grammar_from, llguidance_checker):
        """guided_grammar is forwarded to grammar_from as a "grammar" spec."""
        mock_validate.return_value = None
        mock_grammar_from.return_value = "serialized_grammar_spec"

        req = MockRequest()
        req.guided_grammar = "grammar specification"

        serialized = llguidance_checker.serialize_guidance_grammar(req)

        mock_grammar_from.assert_called_once_with("grammar", "grammar specification")
        assert serialized == "serialized_grammar_spec"

    @patch("llguidance.StructTag")
    @patch("llguidance.LLMatcher.grammar_from_json_schema")
    def test_serialize_structural_tag(self, mock_from_schema, mock_struct_tag, llguidance_checker):
        """structural_tag whose structure begins with a listed trigger."""
        # Configure the patched llguidance objects.
        mock_from_schema.return_value = "serialized_schema"
        mock_struct_tag.to_grammar.return_value = "serialized_structural_grammar"
        mock_struct_tag.return_value = MagicMock()

        req = MockRequest()
        req.structural_tag = {
            "triggers": ["<json>"],
            "structures": [{"begin": "<json>", "schema": {"type": "object"}, "end": "</json>"}],
        }

        serialized = llguidance_checker.serialize_guidance_grammar(req)

        mock_from_schema.assert_called_once()
        mock_struct_tag.assert_called_once()
        mock_struct_tag.to_grammar.assert_called_once()
        assert serialized == "serialized_structural_grammar"

    @patch("llguidance.StructTag")
    def test_serialize_structural_tag_missing_trigger(self, mock_struct_tag, llguidance_checker):
        """structural_tag whose structure does not begin with any listed trigger."""
        req = MockRequest()
        req.structural_tag = {
            "triggers": ["<xml>"],
            "structures": [{"begin": "<json>", "schema": {"type": "object"}, "end": "</json>"}],
        }

        with pytest.raises(ValueError, match="Trigger .* not found in triggers"):
            llguidance_checker.serialize_guidance_grammar(req)

    @patch("llguidance.StructTag")
    def test_serialize_structural_tag_empty_structures(self, mock_struct_tag, llguidance_checker):
        """structural_tag with an empty ``structures`` list."""
        req = MockRequest()
        req.structural_tag = {"triggers": ["<json>"], "structures": []}

        with pytest.raises(ValueError, match="No structural tags found in the grammar spec"):
            llguidance_checker.serialize_guidance_grammar(req)

    def test_serialize_invalid_grammar_type(self, llguidance_checker):
        """A request with no grammar field set is rejected."""
        req = MockRequest()  # every guided-decoding field stays None

        with pytest.raises(ValueError, match="grammar is not of valid supported types"):
            llguidance_checker.serialize_guidance_grammar(req)

    @patch("llguidance.LLMatcher.grammar_from_json_schema")
    @patch("llguidance.LLMatcher.validate_grammar")
    def test_schema_format_valid_json(self, mock_validate, mock_from_schema, llguidance_checker):
        """schema_format with valid JSON returns the same request and no error."""
        mock_validate.return_value = None
        mock_from_schema.return_value = "serialized_grammar"

        req = MockRequest()
        req.guided_json = '{"type": "object"}'

        returned, error = llguidance_checker.schema_format(req)

        assert error is None
        assert returned is req

    @patch("llguidance.LLMatcher.grammar_from_json_schema")
    @patch("llguidance.LLMatcher.validate_grammar")
    def test_schema_format_invalid_grammar(self, mock_validate, mock_from_schema, llguidance_checker):
        """schema_format reports the message returned by validate_grammar."""
        mock_validate.return_value = "Invalid grammar"
        mock_from_schema.return_value = "serialized_grammar"

        req = MockRequest()
        req.guided_json = '{"type": "object"}'

        _, error = llguidance_checker.schema_format(req)

        assert error is not None
        assert "Grammar error: Invalid grammar" in error

    @patch("llguidance.LLMatcher.grammar_from_json_schema")
    def test_schema_format_json_decode_error(self, mock_from_schema, llguidance_checker):
        """schema_format when serialization raises json.JSONDecodeError."""
        mock_from_schema.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)

        req = MockRequest()
        req.guided_json = "{invalid json}"

        _, error = llguidance_checker.schema_format(req)

        assert error is not None
        assert "Invalid format for guided decoding" in error

    @patch("llguidance.LLMatcher.grammar_from_json_schema")
    def test_schema_format_unexpected_error(self, mock_from_schema, llguidance_checker):
        """schema_format when serialization raises a generic Exception."""
        mock_from_schema.side_effect = Exception("Unexpected error")

        req = MockRequest()
        req.guided_json = '{"type": "object"}'

        _, error = llguidance_checker.schema_format(req)

        assert error is not None
        assert "An unexpected error occurred during schema validation" in error

    def test_init_with_disable_whitespace(self, llguidance_checker_with_options):
        """Constructor options: disable_any_whitespace and the additional-properties default."""
        checker = llguidance_checker_with_options
        assert checker.any_whitespace is False
        assert checker.disable_additional_properties is True
        assert LLGuidanceChecker(disable_any_whitespace=True).any_whitespace is False
        assert LLGuidanceChecker(disable_any_whitespace=False).any_whitespace is True

        # default check
        from fastdeploy.envs import FD_GUIDANCE_DISABLE_ADDITIONAL

        assert FD_GUIDANCE_DISABLE_ADDITIONAL

        assert LLGuidanceChecker().disable_additional_properties is True
        with patch("fastdeploy.model_executor.guided_decoding.guidance_backend.FD_GUIDANCE_DISABLE_ADDITIONAL", False):
            assert LLGuidanceChecker().disable_additional_properties is False
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_LLGUIDANCE, reason="llguidance library not installed, skipping actual dependency tests")
|
||||
class TestLLGuidanceCheckerReal:
|
||||
"""Test using the actual llguidance library, suitable for development environments."""
|
||||
|
||||
def test_serialize_guided_json_string_real(self, llguidance_checker):
|
||||
"""Test processing guided_json string using the actual library."""
|
||||
request = MockRequest()
|
||||
request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}'
|
||||
|
||||
grammar = llguidance_checker.serialize_guidance_grammar(request)
|
||||
|
||||
# Verify if the returned grammar is a valid string
|
||||
assert isinstance(grammar, str)
|
||||
assert len(grammar) > 0
|
||||
print("grammar", grammar)
|
||||
|
||||
def test_serialize_guided_json_dict_real(self, llguidance_checker):
|
||||
"""Test processing guided_json dictionary using the actual library."""
|
||||
request = MockRequest()
|
||||
request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
|
||||
grammar = llguidance_checker.serialize_guidance_grammar(request)
|
||||
|
||||
assert isinstance(request.guided_json, dict)
|
||||
assert isinstance(grammar, str)
|
||||
assert len(grammar) > 0
|
||||
|
||||
def test_serialize_guided_json_object_real(self, llguidance_checker):
|
||||
"""Test processing guided_json_object using the actual library."""
|
||||
request = MockRequest()
|
||||
request.guided_json_object = True
|
||||
|
||||
grammar = llguidance_checker.serialize_guidance_grammar(request)
|
||||
|
||||
assert request.guided_json_object
|
||||
assert isinstance(grammar, str)
|
||||
assert len(grammar) > 0
|
||||
|
||||
def test_serialize_guided_regex_real(self, llguidance_checker):
|
||||
"""Test processing guided_regex using the actual library."""
|
||||
request = MockRequest()
|
||||
request.guided_regex = "[a-zA-Z]+"
|
||||
|
||||
grammar = llguidance_checker.serialize_guidance_grammar(request)
|
||||
|
||||
assert isinstance(grammar, str)
|
||||
assert len(grammar) > 0
|
||||
|
||||
def test_serialize_guided_choice_real(self, llguidance_checker):
|
||||
"""Test processing guided_choice using the actual library."""
|
||||
request = MockRequest()
|
||||
request.guided_choice = ["option1", "option2"]
|
||||
|
||||
grammar = llguidance_checker.serialize_guidance_grammar(request)
|
||||
|
||||
assert isinstance(grammar, str)
|
||||
assert len(grammar) > 0
|
||||
|
||||
def test_serialize_guided_grammar_real(self, llguidance_checker):
|
||||
"""Test processing guided_grammar using the actual library."""
|
||||
request = MockRequest()
|
||||
# Use a simple CFG grammar example
|
||||
request.guided_grammar = """
|
||||
root ::= greeting name
|
||||
greeting ::= "Hello" | "Hi"
|
||||
name ::= "world" | "everyone"
|
||||
"""
|
||||
|
||||
grammar = llguidance_checker.serialize_guidance_grammar(request)
|
||||
|
||||
assert isinstance(grammar, str)
|
||||
assert len(grammar) > 0
|
||||
|
||||
def test_serialize_structural_tag_real(self, llguidance_checker):
|
||||
"""Test processing structural_tag using the actual library."""
|
||||
request = MockRequest()
|
||||
request.structural_tag = {
|
||||
"triggers": ["<json>"],
|
||||
"structures": [{"begin": "<json>", "schema": {"type": "object"}, "end": "</json>"}],
|
||||
}
|
||||
|
||||
grammar = llguidance_checker.serialize_guidance_grammar(request)
|
||||
|
||||
assert isinstance(grammar, str)
|
||||
assert len(grammar) > 0
|
||||
|
||||
def test_schema_format_valid_json_real(self, llguidance_checker):
|
||||
"""Test schema_format method processing valid JSON using the actual library."""
|
||||
request = MockRequest()
|
||||
request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}'
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is None
|
||||
assert result_request is request
|
||||
assert result_request.guided_json != '{"type": "object", "properties": {"name": {"type": "string"}}}'
|
||||
|
||||
def test_schema_format_invalid_json_real(self, llguidance_checker):
|
||||
"""Test schema_format method processing invalid JSON using the actual library."""
|
||||
request = MockRequest()
|
||||
request.guided_json = "{invalid json}"
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is not None
|
||||
assert "Invalid format for guided decoding" in error
|
||||
|
||||
def test_whitespace_flexibility_option_real(self):
|
||||
"""Test the impact of the whitespace flexibility option using the actual library."""
|
||||
# Create two instances with different configurations
|
||||
flexible = LLGuidanceChecker(disable_any_whitespace=False)
|
||||
strict = LLGuidanceChecker(disable_any_whitespace=True)
|
||||
|
||||
request_flexible = MockRequest()
|
||||
request_flexible.guided_json = '{"type": "object"}'
|
||||
|
||||
request_strict = MockRequest()
|
||||
request_strict.guided_json = '{"type": "object"}'
|
||||
|
||||
grammar_flexible = flexible.serialize_guidance_grammar(request_flexible)
|
||||
grammar_strict = strict.serialize_guidance_grammar(request_strict)
|
||||
print("grammar_flexible", grammar_flexible)
|
||||
print("grammar_strict", grammar_strict)
|
||||
|
||||
# Expect grammars generated by the two configurations to be different
|
||||
assert grammar_flexible != grammar_strict
|
||||
|
||||
def test_schema_format_guided_json_object_real(self, llguidance_checker):
|
||||
"""Test schema_format processing guided_json_object."""
|
||||
request = MockRequest()
|
||||
request.guided_json_object = True
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is None
|
||||
assert result_request is request
|
||||
|
||||
def test_schema_format_guided_regex_real(self, llguidance_checker):
|
||||
"""Test schema_format processing valid regular expressions."""
|
||||
request = MockRequest()
|
||||
request.guided_regex = r"[a-zA-Z0-9]+"
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is None
|
||||
assert result_request is request
|
||||
assert result_request.guided_regex != r"[a-zA-Z0-9]+" # Should be converted to grammar format
|
||||
|
||||
def test_schema_format_invalid_guided_regex_real(self, llguidance_checker):
|
||||
"""Test schema_format processing invalid regular expressions."""
|
||||
request = MockRequest()
|
||||
request.guided_regex = r"[" # Invalid regular expression
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is not None
|
||||
assert "Invalid format for guided decoding" in error
|
||||
|
||||
def test_schema_format_guided_choice_real(self, llguidance_checker):
|
||||
"""Test schema_format processing guided_choice."""
|
||||
request = MockRequest()
|
||||
request.guided_choice = ["option1", "option2", "option3"]
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is None
|
||||
assert result_request is request
|
||||
assert result_request.guided_choice != [
|
||||
"option1",
|
||||
"option2",
|
||||
"option3",
|
||||
] # Should be converted to grammar format
|
||||
|
||||
def test_schema_format_guided_grammar_real(self, llguidance_checker):
|
||||
"""Test schema_format processing guided_grammar."""
|
||||
request = MockRequest()
|
||||
# Use the correct grammar format supported by LLGuidance
|
||||
request.guided_grammar = """
|
||||
start: number
|
||||
number: DIGIT+
|
||||
DIGIT: "0"|"1"|"2"|"3"|"4"|"5"|"6"|"7"|"8"|"9"
|
||||
"""
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is None
|
||||
assert result_request is request
|
||||
assert isinstance(result_request.guided_grammar, str)
|
||||
|
||||
def test_schema_format_structural_tag_real(self, llguidance_checker):
|
||||
"""Test schema_format processing structural_tag."""
|
||||
request = MockRequest()
|
||||
request.structural_tag = {
|
||||
"triggers": ["```json"],
|
||||
"structures": [
|
||||
{
|
||||
"begin": "```json",
|
||||
"schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"end": "```",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is None
|
||||
assert result_request is request
|
||||
|
||||
def test_schema_format_structural_tag_string_real(self, llguidance_checker):
|
||||
"""Test schema_format processing structural_tag in string format."""
|
||||
request = MockRequest()
|
||||
request.structural_tag = json.dumps(
|
||||
{
|
||||
"triggers": ["```json"],
|
||||
"structures": [
|
||||
{
|
||||
"begin": "```json",
|
||||
"schema": {"type": "object", "properties": {"name": {"type": "string"}}},
|
||||
"end": "```",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is None
|
||||
assert result_request is request
|
||||
|
||||
def test_schema_format_structural_tag_invalid_trigger_real(self, llguidance_checker):
|
||||
"""Test schema_format processing structural_tag with invalid triggers."""
|
||||
request = MockRequest()
|
||||
request.structural_tag = {
|
||||
"triggers": ["```xml"], # Trigger does not match begin
|
||||
"structures": [
|
||||
{
|
||||
"begin": "```json",
|
||||
"schema": {"type": "object"},
|
||||
"end": "```",
|
||||
} # Does not contain any prefix from triggers here
|
||||
],
|
||||
}
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is not None
|
||||
assert "Invalid format for guided decoding" in error
|
||||
|
||||
def test_schema_format_structural_tag_empty_structures_real(self, llguidance_checker):
|
||||
"""Test schema_format processing structural_tag with empty structures."""
|
||||
request = MockRequest()
|
||||
request.structural_tag = {"triggers": ["```json"], "structures": []} # Empty structure
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is not None
|
||||
assert "Invalid format for guided decoding" in error
|
||||
|
||||
def test_schema_format_json_dict_real(self, llguidance_checker):
|
||||
"""Test schema_format processing guided_json in dictionary format."""
|
||||
request = MockRequest()
|
||||
request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is None
|
||||
assert result_request is request
|
||||
|
||||
def test_schema_format_disable_additional_properties_real(self):
|
||||
"""Test schema_format processing disable_additional_properties parameter."""
|
||||
checker = LLGuidanceChecker(disable_additional_properties=True)
|
||||
request = MockRequest()
|
||||
request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}}
|
||||
|
||||
result_request, error = checker.schema_format(request)
|
||||
|
||||
assert error is None
|
||||
assert result_request is request
|
||||
|
||||
def test_schema_format_unexpected_error_real(self, monkeypatch, llguidance_checker):
|
||||
"""Test schema_format processing unexpected errors."""
|
||||
request = MockRequest()
|
||||
request.guided_json = '{"type": "object"}'
|
||||
|
||||
# Mock unexpected exception
|
||||
def mock_serialize_grammar(*args, **kwargs):
|
||||
raise Exception("Unexpected error")
|
||||
|
||||
monkeypatch.setattr(llguidance_checker, "serialize_guidance_grammar", mock_serialize_grammar)
|
||||
|
||||
result_request, error = llguidance_checker.schema_format(request)
|
||||
|
||||
assert error is not None
|
||||
assert "An unexpected error occurred during schema validation" in error
|
||||
|
||||
def test_schema_format_no_valid_grammar_real(self, llguidance_checker):
    """Test schema_format processing requests without valid grammar.

    A bare MockRequest (no grammar-related attributes set) is passed first to
    serialize_guidance_grammar, which is expected to raise ValueError, and then
    to schema_format, which is expected to return a non-None error.
    """
    request = MockRequest()
    # No grammar-related attributes set

    with pytest.raises(ValueError, match="grammar is not of valid supported types"):
        llguidance_checker.serialize_guidance_grammar(request)
    # NOTE(review): the schema_format check is placed outside the pytest.raises
    # block so that it actually executes; inside the block it would be
    # unreachable once the ValueError is raised. Confirm against the upstream
    # file's indentation.
    result_request, error = llguidance_checker.schema_format(request)
    assert error is not None
|
||||
|
||||
|
||||
# Run the tests via unittest's command-line runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
172
tests/model_executor/guided_decoding/test_guidance_processor.py
Normal file
172
tests/model_executor/guided_decoding/test_guidance_processor.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# --- Mocking Setup ---
# Register stand-ins for torch and llguidance in sys.modules *before* the
# module under test is imported, so that these tests can run in environments
# where those libraries are not installed.
mock_torch = MagicMock()
# NOTE(review): __spec__ is presumably set so that import-machinery checks
# against the fake torch module find a spec -- confirm.
mock_torch.__spec__ = MagicMock()
mock_torch.distributed = MagicMock()

mock_llguidance = MagicMock()
mock_llguidance_hf = MagicMock()
mock_llguidance_torch = MagicMock()

sys.modules.update(
    {
        "torch": mock_torch,
        "llguidance": mock_llguidance,
        "llguidance.hf": mock_llguidance_hf,
        "llguidance.torch": mock_llguidance_torch,
    }
)
|
||||
|
||||
# Import the module to be tested after the mock setup is complete
|
||||
from fastdeploy.model_executor.guided_decoding.guidance_backend import (
|
||||
LLGuidanceProcessor,
|
||||
)
|
||||
|
||||
|
||||
def MockFDConfig():
    """Build a MagicMock standing in for an FDConfig object.

    Returns:
        A MagicMock whose model, scheduler and structured-output settings are
        pinned to concrete values; every other attribute stays an auto-created
        mock.
    """
    cfg = MagicMock()

    model_cfg = cfg.model_config
    # A real string here (rather than a Mock) is needed to pass HF validation,
    # per the original author's note.
    model_cfg.model = "test-model-path"
    # An empty list keeps code that iterates over architectures from choking on a Mock.
    model_cfg.architectures = []
    model_cfg.vocab_size = 1000

    cfg.scheduler_config.max_num_seqs = 4

    structured_cfg = cfg.structured_outputs_config
    structured_cfg.disable_any_whitespace = False
    # Backend name chosen so that the backend check logic passes.
    structured_cfg.guided_decoding_backend = "guidance"

    return cfg
|
||||
|
||||
|
||||
def MockHFTokenizer():
    """Return a fresh MagicMock used in place of a Hugging Face tokenizer."""
    hf_tokenizer = MagicMock()
    return hf_tokenizer
|
||||
|
||||
|
||||
class TestLLGuidanceProcessorMocked(unittest.TestCase):
    """
    Mock-based unit tests for LLGuidanceProcessor.

    The matcher and tokenizer handed to the processor are MagicMock objects, so
    this suite is suitable for environments where the llguidance library is not
    installed.
    """

    def setUp(self):
        """Build a fresh LLGuidanceProcessor around mock collaborators for each test case"""
        matcher = MagicMock()
        tokenizer = MagicMock()
        tokenizer.eos_token = 2  # Example EOS token ID

        self.mock_matcher = matcher
        self.mock_tokenizer = tokenizer
        self.processor = LLGuidanceProcessor(
            ll_matcher=matcher,
            ll_tokenizer=tokenizer,
            serialized_grammar="test_grammar",
            vocab_size=1000,
            batch_size=4,
            enable_thinking=False,
        )

    def test_init(self):
        """The constructor keeps the matcher and sizes it was given and starts un-terminated"""
        proc = self.processor
        self.assertIs(proc.matcher, self.mock_matcher)
        self.assertEqual(proc.vocab_size, 1000)
        self.assertEqual(proc.batch_size, 4)
        self.assertFalse(proc.is_terminated)

    @patch("fastdeploy.utils.llm_logger.warning")
    def test_check_error_logs_warning_once(self, mock_log_warning):
        """_check_error logs a warning for a matcher error, and does not log it again on a repeat call"""
        self.mock_matcher.get_error.return_value = "A test error."
        check_error = self.processor._check_error

        # First call: the warning is emitted with the matcher's error text.
        check_error()
        mock_log_warning.assert_called_once_with("LLGuidance Matcher error: A test error.")

        # Second call: the call count must still be exactly one.
        check_error()
        mock_log_warning.assert_called_once()

    @patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance.torch")
    def test_allocate_token_bitmask(self, mock_backend_torch):
        """
        allocate_token_bitmask delegates to llguidance.torch with (batch_size, vocab_size).

        Note: the patch replaces the ``torch`` attribute of the ``llguidance`` object
        referenced by the guidance_backend module, instead of relying on the global mock
        registered in sys.modules. The original author attributes this to inconsistent
        references caused by LazyLoader.
        """
        allocate = mock_backend_torch.allocate_token_bitmask
        allocate.return_value = "fake_bitmask_tensor"

        bitmask = self.processor.allocate_token_bitmask()

        allocate.assert_called_once_with(4, 1000)
        self.assertEqual(bitmask, "fake_bitmask_tensor")

    @patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance.torch")
    def test_fill_token_bitmask(self, mock_backend_torch):
        """fill_token_bitmask forwards matcher, bitmask and index, then checks the matcher error once"""
        bitmask = MagicMock()

        self.processor.fill_token_bitmask(bitmask, idx=2)

        fill = mock_backend_torch.fill_next_token_bitmask
        fill.assert_called_once_with(self.mock_matcher, bitmask, 2)
        self.mock_matcher.get_error.assert_called_once()

    def test_reset(self):
        """reset() resets the matcher and clears the terminated / printed-error flags"""
        proc = self.processor
        proc.is_terminated = True
        proc._printed_error = True
        self.mock_matcher.get_error.return_value = ""

        proc.reset()

        self.mock_matcher.reset.assert_called_once()
        self.assertFalse(proc.is_terminated)
        self.assertFalse(proc._printed_error)

    def test_accept_token_when_terminated(self):
        """accept_token returns False when the processor is already terminated"""
        self.processor.is_terminated = True

        accepted = self.processor.accept_token(123)

        self.assertFalse(accepted)

    def test_accept_token_when_matcher_stopped(self):
        """With the matcher reporting stopped, accept_token returns True and is_terminated stays False"""
        self.mock_matcher.is_stopped.return_value = True

        accepted = self.processor.accept_token(123)

        self.assertTrue(accepted)
        self.assertFalse(self.processor.is_terminated)

    def test_accept_token_is_eos(self):
        """Receiving the tokenizer's EOS token is accepted and marks the processor terminated"""
        self.mock_matcher.is_stopped.return_value = False
        eos = self.mock_tokenizer.eos_token

        accepted = self.processor.accept_token(eos)

        self.assertTrue(accepted)
        self.assertTrue(self.processor.is_terminated)

    def test_accept_token_consumes_and_succeeds(self):
        """A token the matcher consumes successfully is accepted, with one error check"""
        matcher = self.mock_matcher
        matcher.is_stopped.return_value = False
        matcher.consume_tokens.return_value = True

        accepted = self.processor.accept_token(123)

        self.assertTrue(accepted)
        matcher.consume_tokens.assert_called_once_with([123])
        matcher.get_error.assert_called_once()

    def test_accept_token_consumes_and_fails(self):
        """A token the matcher fails to consume is rejected"""
        matcher = self.mock_matcher
        matcher.is_stopped.return_value = False
        matcher.consume_tokens.return_value = False

        accepted = self.processor.accept_token(123)

        self.assertFalse(accepted)
        matcher.consume_tokens.assert_called_once_with([123])
|
||||
|
||||
# Run the tests via unittest's command-line runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user