[Feature] Guided Decoding: add LLguidance backend (#5124)

* llguidance

* add requirements_guided_decoding.txt and doc

* fix test_guidance_*.py

* fix test_guidance_*.py && mv

* fix llguidance choice

* test_guidance_*

* rm lazy loader

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Author: Daci
Date: 2025-12-03 20:23:57 +08:00
Committed by: GitHub
Parent: 4e8096bd0d
Commit: 83dbc4e5dd
14 changed files with 1307 additions and 8 deletions

View File

@@ -7,6 +7,7 @@
Structured Outputs refer to predefined format constraints that force large language models to generate content strictly following specified structures. This feature significantly improves output controllability and is suitable for scenarios requiring precise format outputs (such as API calls, data parsing, code generation, etc.), while supporting dynamic grammar extensions to balance flexibility and standardization.
FastDeploy supports using the [XGrammar](https://xgrammar.mlc.ai/docs/) backend to generate structured outputs.
FastDeploy supports using the [LLguidance](https://github.com/guidance-ai/llguidance) backend to generate structured outputs.
Supported output formats:

View File

@@ -44,7 +44,7 @@ When using FastDeploy to deploy models (including offline inference and service
| ```disable_sequence_parallel_moe``` | `bool` | Disable sequence parallel moe, default: False |
| ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] |
| ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
-| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |
+| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `guidance`, `off`, default: `off` |
| ```guided_decoding_disable_any_whitespace``` | `bool` | Whether to disable whitespace generation during guided decoding, default: False |
| ```speculative_config``` | `dict[str]` | Speculative decoding configuration, only supports standard format JSON string, default: None |
| ```dynamic_load_weight``` | `int` | Whether to enable dynamic weight loading, default: 0 |
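
For orientation, a minimal usage sketch of the new option follows; the offline `LLM` entry point and its keyword plumbing are assumptions based on the table above, not code from this commit:

```python
# Hedged sketch: only the backend names come from this diff; the LLM
# constructor and the exact keyword path are assumed for illustration.
from fastdeploy import LLM  # assumed offline entry point

llm = LLM(
    model="/path/to/your/model",         # placeholder
    guided_decoding_backend="guidance",  # one of: auto, xgrammar, guidance, off
)
```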

View File

@@ -7,6 +7,7 @@
Structured Outputs refer to predefined format constraints that force large language models to generate content strictly following specified structures. This feature significantly improves output controllability and is suitable for scenarios requiring precise format outputs (such as API calls, data parsing, code generation, etc.), while supporting dynamic grammar extensions to balance flexibility and standardization.
FastDeploy supports using the [XGrammar](https://xgrammar.mlc.ai/docs/) backend to generate structured outputs.
FastDeploy supports using the [LLguidance](https://github.com/guidance-ai/llguidance) backend to generate structured outputs.
Supported output formats:

View File

@@ -42,7 +42,7 @@
| ```disable_sequence_parallel_moe``` | `bool` | Disable sequence parallel MoE in TP+EP, default: False |
| ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] |
| ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
-| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |
+| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `guidance`, `off`, default: `off` |
| ```guided_decoding_disable_any_whitespace``` | `bool` | Whether to disable whitespace generation during guided decoding, default: False |
| ```speculative_config``` | `dict[str]` | Speculative decoding configuration, only supports standard-format JSON strings, default: None |
| ```dynamic_load_weight``` | `int` | Whether to enable dynamic weight loading, default: 0 |

View File

@@ -1664,13 +1664,27 @@ class FDConfig:
if (
self.structured_outputs_config is not None
-and self.structured_outputs_config.guided_decoding_backend == "auto"
+and self.structured_outputs_config.guided_decoding_backend != "off"
):
if current_platform.is_xpu() or self.speculative_config.method is not None:
logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.")
self.structured_outputs_config.guided_decoding_backend = "off"
-else:
+elif self.structured_outputs_config.guided_decoding_backend in ["auto", "xgrammar"]:
self.structured_outputs_config.guided_decoding_backend = "xgrammar"
+elif self.structured_outputs_config.guided_decoding_backend == "guidance":
+try:
+import llguidance.torch
+llguidance.torch
+except ImportError:
+raise ImportError(
+"The 'llguidance' package is required for using guidance as the guided decoding backend. "
+"Please install it, e.g. `pip install -r requirements_guided_decoding.txt`."
+)
+else:
+raise NotImplementedError(
+f"Guided decoding backend '{self.structured_outputs_config.guided_decoding_backend}' is not implemented. Supported values: [auto, xgrammar, guidance, off]."
+)
if self.model_config.enable_mm:
if self.cache_config.max_encoder_cache is None or self.cache_config.max_encoder_cache < 0:
@@ -1790,7 +1804,8 @@ class FDConfig:
"XGrammar",
"auto",
"off",
], f"Only support xgrammar、auto guided decoding backend, but got {self.structured_outputs_config.guided_decoding_backend}."
"guidance",
], f"Only support [auto, xgrammar, guidance, off] guided decoding backend, but got {self.structured_outputs_config.guided_decoding_backend}."
if self.structured_outputs_config.guided_decoding_backend != "off":
# TODO: speculative decoding support guided_decoding
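
The `guidance` branch above probes for the optional dependency at configuration time; a standalone sketch of that probe pattern:

```python
def require_llguidance() -> None:
    """Fail fast with a clear message if the optional dependency is absent."""
    try:
        import llguidance.torch  # noqa: F401  # probe only; used for real at runtime
    except ImportError as e:
        raise ImportError(
            "guided_decoding_backend='guidance' requires the 'llguidance' package; "
            "install it via `pip install -r requirements_guided_decoding.txt`."
        ) from e
```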

View File

@@ -148,6 +148,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")),
"FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
"FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
"FD_GUIDANCE_DISABLE_ADDITIONAL": lambda: bool(int(os.getenv("FD_GUIDANCE_DISABLE_ADDITIONAL", "1"))),
"FD_LLGUIDANCE_LOG_LEVEL": lambda: int(os.getenv("FD_LLGUIDANCE_LOG_LEVEL", "0")),
# "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
"FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")),
"FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
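
Note that `guidance_backend.py` (later in this diff) binds these values with `from fastdeploy.envs import ...`, so they are fixed once that module is imported; a hedged example of setting them beforehand:

```python
import os

# Must run before FastDeploy imports the guidance backend, since the values
# are bound at import time via `from fastdeploy.envs import ...`.
os.environ["FD_GUIDANCE_DISABLE_ADDITIONAL"] = "0"  # keep additionalProperties untouched
os.environ["FD_LLGUIDANCE_LOG_LEVEL"] = "1"         # more verbose llguidance logging
```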

View File

@@ -50,6 +50,15 @@ def get_guided_backend(
fd_config=fd_config,
**kwargs,
)
elif fd_config.structured_outputs_config.guided_decoding_backend.lower() == "guidance":
from fastdeploy.model_executor.guided_decoding.guidance_backend import (
LLGuidanceBackend,
)
return LLGuidanceBackend(
fd_config=fd_config,
**kwargs,
)
else:
raise ValueError(
f"Get unsupported backend {fd_config.structured_outputs_config.guided_decoding_backend},"
@@ -77,5 +86,11 @@ def schema_checker(backend_name: str, **kwargs):
)
return XGrammarChecker(**kwargs)
elif backend_name.lower() == "guidance":
from fastdeploy.model_executor.guided_decoding.guidance_backend import (
LLGuidanceChecker,
)
return LLGuidanceChecker(**kwargs)
else:
raise ValueError(f"Get unsupported backend {backend_name}, please check your configuration.")

View File

@@ -294,7 +294,12 @@ class BackendBase:
"""
try:
architectures = self.fd_config.model_config.architectures
-if not ErnieArchitectures.contains_ernie_arch(architectures):
+is_guidance_backend = (
+self.fd_config.structured_outputs_config is not None
+and self.fd_config.structured_outputs_config.guided_decoding_backend is not None
+and self.fd_config.structured_outputs_config.guided_decoding_backend == "guidance"
+)
+if not ErnieArchitectures.contains_ernie_arch(architectures) or is_guidance_backend:
from transformers import AutoTokenizer, PreTrainedTokenizerFast
tokenizer = AutoTokenizer.from_pretrained(
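
The hunk above routes the guidance backend onto a Hugging Face tokenizer even for ERNIE architectures, because llguidance derives its token tables from an HF tokenizer; a minimal sketch of that handoff (mirroring `llguidance.hf.from_tokenizer` in the new backend file below):

```python
# Hedged sketch of the tokenizer handoff; the model path is a placeholder.
import llguidance.hf
from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("/path/to/your/model")
ll_tokenizer = llguidance.hf.from_tokenizer(hf_tokenizer, hf_tokenizer.vocab_size)
```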

View File

@@ -0,0 +1,314 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import copy
import json
import traceback
from typing import Any, Optional, Tuple, Union
import llguidance
import llguidance.hf
import llguidance.torch
import torch
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.envs import FD_GUIDANCE_DISABLE_ADDITIONAL, FD_LLGUIDANCE_LOG_LEVEL
from fastdeploy.model_executor.guided_decoding import (
BackendBase,
BaseChecker,
LogitsProcessorBase,
)
from fastdeploy.utils import llm_logger
class LLGuidanceProcessor(LogitsProcessorBase):
"""
LLGuidance-specific implementation of LogitsProcessorBase.
This processor enforces grammar constraints during token generation using llguidance.
It manages the grammar matching state and applies token masks to logits.
"""
def __init__(
self,
ll_matcher: llguidance.LLMatcher,
ll_tokenizer: llguidance.LLTokenizer,
serialized_grammar: str,
vocab_size: int,
batch_size: int,
enable_thinking: bool = False,
):
super().__init__(enable_reasoning=enable_thinking)
self.matcher = ll_matcher
self.ll_tokenizer = ll_tokenizer
self.serialized_grammar = serialized_grammar
self.vocab_size = vocab_size
self.batch_size = batch_size
self.is_terminated: bool = False
self._printed_error: bool = False
def _check_error(self):
"""Checks for and logs any errors from the LLMatcher."""
if not self._printed_error:
err = self.matcher.get_error()
if err:
self._printed_error = True
llm_logger.warning(f"LLGuidance Matcher error: {err}")
def allocate_token_bitmask(self) -> torch.Tensor:
"""
Allocate a token bitmask tensor for grammar constraints.
"""
return llguidance.torch.allocate_token_bitmask(self.batch_size, self.vocab_size)
def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
"""
Fill the token bitmask with allowed tokens for the given index.
This will automatically provide an EOS mask if the matcher is stopped.
"""
llguidance.torch.fill_next_token_bitmask(self.matcher, token_bitmask, idx)
self._check_error()
def reset(self) -> None:
"""
Reset the grammar matcher state to initial conditions.
"""
self.matcher.reset()
self.is_terminated = False
self._printed_error = False
self._check_error()
def accept_token(self, token: int) -> bool:
"""
Validate and accept a generated token against the grammar constraints.
Returns True if the token is accepted, False otherwise.
"""
if self.is_terminated:
return False
if self.ll_tokenizer.eos_token == token:
self.is_terminated = True
return True
result = self.matcher.consume_tokens([token])
self._check_error()
return result
class LLGuidanceBackend(BackendBase):
"""
LLGuidance-specific implementation of BackendBase.
This backend handles the compilation of various schema types (JSON, regex, etc.)
into LLGuidance processors.
"""
def __init__(self, fd_config: FDConfig, **kwargs):
super().__init__(fd_config=fd_config)
self.vocab_size = fd_config.model_config.vocab_size
self.batch_size = fd_config.scheduler_config.max_num_seqs
self.any_whitespace = not fd_config.structured_outputs_config.disable_any_whitespace
llm_logger.info(f"LLGuidanceBackend vocab_size={self.vocab_size} batch_size={self.batch_size}")
try:
self.ll_tokenizer = llguidance.hf.from_tokenizer(self.hf_tokenizer, self.vocab_size)
except Exception as e:
import traceback
raise RuntimeError(
f"Failed to initialize llguidance tokenizer from HuggingFace tokenizer: {e} {traceback.format_exc()}"
)
def _create_processor(
self,
compiled_grammar: str,
enable_thinking: bool = False,
) -> Optional[LLGuidanceProcessor]:
"""
Create a logits processor instance for the given grammar schemata.
"""
try:
ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer,
compiled_grammar,
log_level=FD_LLGUIDANCE_LOG_LEVEL,
)
return LLGuidanceProcessor(
ll_matcher=ll_matcher,
ll_tokenizer=self.ll_tokenizer,
serialized_grammar=compiled_grammar,
vocab_size=self.vocab_size,
batch_size=self.batch_size,
enable_thinking=enable_thinking,
)
except Exception as e:
llm_logger.error(f"Failed to create llguidance processor: {e}, {str(traceback.format_exc())}")
return None
def _json_processor(self, compiled_grammar: str, enable_thinking: bool = False) -> Optional[LLGuidanceProcessor]:
return self._create_processor(compiled_grammar, enable_thinking)
def _regex_processor(self, compiled_grammar: str, enable_thinking: bool = False) -> Optional[LLGuidanceProcessor]:
return self._create_processor(compiled_grammar, enable_thinking)
def _grammar_processor(
self, compiled_grammar: str, enable_thinking: bool = False
) -> Optional[LLGuidanceProcessor]:
return self._create_processor(compiled_grammar, enable_thinking)
def _structural_tag_processor(
self, compiled_grammar: str, enable_thinking: bool = False
) -> Optional[LLGuidanceProcessor]:
return self._create_processor(compiled_grammar, enable_thinking)
def _walk_json_for_additional_properties(data: object):
if isinstance(data, dict):
for value in data.values():
_walk_json_for_additional_properties(value)
if "additionalProperties" not in data and ("properties" in data or "patternProperties" in data):
data["additionalProperties"] = False
elif isinstance(data, list):
for item in data:
_walk_json_for_additional_properties(item)
def process_for_additional_properties(guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
if isinstance(guide_json, str):
guide_json_obj = json.loads(guide_json)
else:
# copy for modifications
guide_json_obj = copy.deepcopy(guide_json)
_walk_json_for_additional_properties(guide_json_obj)
return guide_json_obj
class LLGuidanceChecker(BaseChecker):
"""
LLGuidance-specific implementation of BaseChecker.
This checker validates various schema types for compatibility with the
llguidance library before processing.
"""
def __init__(self, **kwargs):
super().__init__()
# Although the backend handles serialization, we can perform a quick
# static check here without a full tokenizer.
self.any_whitespace = not kwargs.get("disable_any_whitespace", False)
self.disable_additional_properties = FD_GUIDANCE_DISABLE_ADDITIONAL
"""If `True`, the `guidance` backend will not use `additionalProperties`
in the JSON schema. This is only supported for the `guidance` backend and
is used to better align its behaviour with `outlines` and `xgrammar`."""
def serialize_guidance_grammar(self, request: Request):
def _process_schema(
grammar_spec: Union[str, dict[str, Any]],
) -> str:
if self.disable_additional_properties:
grammar_spec = process_for_additional_properties(grammar_spec)
return llguidance.LLMatcher.grammar_from_json_schema(
grammar_spec,
defaults={
"whitespace_flexible": self.any_whitespace,
},
)
if request.guided_json:
if isinstance(request.guided_json, dict):
guided_json = json.dumps(request.guided_json)
else:
guided_json = request.guided_json
return _process_schema(guided_json)
elif request.guided_json_object:
return llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}',
defaults={
"whitespace_flexible": self.any_whitespace,
},
)
if request.structural_tag:
if isinstance(request.structural_tag, str):
s_tag = json.loads(request.structural_tag)
else:
s_tag = request.structural_tag
triggers: list[str] = s_tag["triggers"]
tags: list[llguidance.StructTag] = []
for s in s_tag["structures"]:
begin: str = s["begin"]
trig = next((t for t in triggers if begin.startswith(t)), None)
if trig is None:
raise ValueError(f"Trigger {begin} not found in triggers {triggers}")
tags.append(
llguidance.StructTag(
trigger=trig,
begin=s["begin"],
grammar=_process_schema(s["schema"]),
end=s["end"],
)
)
if not tags:
raise ValueError("No structural tags found in the grammar spec.")
return llguidance.StructTag.to_grammar(tags)
if request.guided_regex:
tp = "regex"
grammar_spec = request.guided_regex
elif request.guided_choice:
tp = "choice"
grammar_spec = request.guided_choice
elif request.guided_grammar:
tp = "grammar"
grammar_spec = request.guided_grammar
else:
llm_logger.error("Validation should have already occurred. " "Please file an issue.")
raise ValueError("grammar is not of valid supported types. ")
return llguidance.grammar_from(tp, grammar_spec)
def schema_format(self, request: Request) -> Tuple[Request, Optional[str]]:
"""
Validates and formats the schema for the LLGuidance backend.
"""
try:
guidance_grm = self.serialize_guidance_grammar(request)
err = llguidance.LLMatcher.validate_grammar(guidance_grm, None)
if err:
raise ValueError(f"Grammar error: {err}")
else:
llm_logger.info(f"valid schema_format {guidance_grm} {request}")
if request.guided_regex:
request.guided_regex = guidance_grm
elif request.guided_choice:
request.guided_grammar = guidance_grm
request.guided_choice = None
elif request.guided_grammar:
request.guided_grammar = guidance_grm
elif request.guided_json:
request.guided_json = guidance_grm
except (ValueError, TypeError, json.JSONDecodeError) as e:
err_msg = f"Invalid format for guided decoding: {e!s} request={request}"
return request, err_msg
except Exception as e:
err_msg = f"An unexpected error occurred during schema validation: {e!s}"
return request, err_msg
return request, None
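
Taken together, `LLGuidanceChecker` compiles the request into a serialized grammar and `LLGuidanceProcessor` enforces it with a token bitmask at each decode step. A hedged end-to-end sketch of that loop, using the llguidance calls that appear in this file plus `apply_token_bitmask_inplace`, which is assumed from llguidance's torch helpers:

```python
# Hedged sketch of the decode-time bitmask loop implied by the code above.
import llguidance
import llguidance.hf
import llguidance.torch
import torch
from transformers import AutoTokenizer

hf_tok = AutoTokenizer.from_pretrained("/path/to/your/model")  # placeholder
ll_tok = llguidance.hf.from_tokenizer(hf_tok, hf_tok.vocab_size)

grammar = llguidance.LLMatcher.grammar_from_json_schema(
    '{"type": "object"}', defaults={"whitespace_flexible": True}
)
matcher = llguidance.LLMatcher(ll_tok, grammar, log_level=0)

bitmask = llguidance.torch.allocate_token_bitmask(1, hf_tok.vocab_size)
logits = torch.randn(1, hf_tok.vocab_size)  # stand-in for one step of model output

llguidance.torch.fill_next_token_bitmask(matcher, bitmask, 0)  # row 0 of the batch
llguidance.torch.apply_token_bitmask_inplace(logits, bitmask)  # assumed helper
next_token = int(torch.argmax(logits, dim=-1))
matcher.consume_tokens([next_token])  # advance the grammar state
```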

View File

@@ -73,7 +73,6 @@ class XGrammarProcessor(LogitsProcessorBase):
enable_thinking: bool = False,
):
super().__init__(enable_reasoning=enable_thinking)
-self.max_rollback_tokens = 200
self.vocab_size = vocab_size
self.batch_size = batch_size
self.compiled_grammar = compiled_grammar
@@ -82,7 +81,6 @@ class XGrammarProcessor(LogitsProcessorBase):
self.matcher = GrammarMatcher(
compiled_grammar=compiled_grammar,
-max_rollback_tokens=self.max_rollback_tokens,
terminate_without_stop_token=terminate_without_stop_token,
override_stop_tokens=override_stop_tokens,
)

View File

@@ -0,0 +1,3 @@
xgrammar==0.1.25
llguidance==1.3.0
torch==2.8.0

View File

@@ -0,0 +1,178 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import sys
import unittest
from unittest.mock import MagicMock, patch
from fastdeploy.model_executor.guided_decoding import BackendBase
mock_llguidance = MagicMock()
mock_llguidancehf = MagicMock()
mock_llguidancetorch = MagicMock()
mock_torch = MagicMock()
setattr(mock_llguidance, "hf", mock_llguidancehf)
sys.modules["llguidance"] = mock_llguidance
sys.modules["llguidance.hf"] = mock_llguidancehf
sys.modules["llguidance.torch"] = mock_llguidancetorch
sys.modules["torch"] = mock_torch
# Import the module to be tested
from fastdeploy.model_executor.guided_decoding.guidance_backend import (
LLGuidanceBackend,
LLGuidanceProcessor,
process_for_additional_properties,
)
class TestProcessForAdditionalProperties(unittest.TestCase):
def test_process_json_string(self):
# Test string input
json_str = '{"type": "object", "properties": {"name": {"type": "string"}}}'
result = process_for_additional_properties(json_str)
self.assertFalse(result["additionalProperties"])
def test_process_json_dict(self):
# Test dictionary input
json_dict = {"type": "object", "properties": {"name": {"type": "string"}}}
result = process_for_additional_properties(json_dict)
self.assertFalse(result["additionalProperties"])
# Ensure the original dictionary is not modified
self.assertNotIn("additionalProperties", json_dict)
def test_nested_objects(self):
# Test nested objects
json_dict = {
"type": "object",
"properties": {"person": {"type": "object", "properties": {"name": {"type": "string"}}}},
}
result = process_for_additional_properties(json_dict)
self.assertFalse(result["additionalProperties"])
self.assertFalse(result["properties"]["person"]["additionalProperties"])
@patch("llguidance.LLMatcher")
@patch("llguidance.LLTokenizer")
class TestLLGuidanceProcessor(unittest.TestCase):
def setUp(self):
self.vocab_size = 100
self.batch_size = 2
def test_initialization(self, mock_tokenizer, mock_matcher):
# Test initialization
processor = LLGuidanceProcessor(
ll_matcher=mock_matcher,
ll_tokenizer=mock_tokenizer,
serialized_grammar="test_grammar",
vocab_size=self.vocab_size,
batch_size=self.batch_size,
)
self.assertEqual(processor.vocab_size, self.vocab_size)
self.assertEqual(processor.batch_size, self.batch_size)
self.assertFalse(processor.is_terminated)
def test_reset(self, mock_tokenizer, mock_matcher):
# Test reset functionality
processor = LLGuidanceProcessor(
ll_matcher=mock_matcher,
ll_tokenizer=mock_tokenizer,
serialized_grammar="test_grammar",
vocab_size=self.vocab_size,
batch_size=self.batch_size,
)
processor.is_terminated = True
processor.reset()
mock_matcher.reset.assert_called_once()
self.assertFalse(processor.is_terminated)
def test_accept_token(self, mock_tokenizer, mock_matcher):
# Test accept_token functionality
mock_matcher.is_stopped.return_value = False
mock_matcher.consume_tokens.return_value = True
mock_tokenizer.eos_token = 1
processor = LLGuidanceProcessor(
ll_matcher=mock_matcher,
ll_tokenizer=mock_tokenizer,
serialized_grammar="test_grammar",
vocab_size=self.vocab_size,
batch_size=self.batch_size,
)
# Normal token
result = processor.accept_token(0)
self.assertTrue(result)
mock_matcher.consume_tokens.assert_called_with([0])
# EOS token
result = processor.accept_token(1)
self.assertTrue(result)
self.assertTrue(processor.is_terminated)
@patch("llguidance.LLMatcher")
@patch("llguidance.hf.from_tokenizer")
class TestLLGuidanceBackend(unittest.TestCase):
def setUp(self):
# Create a mock FDConfig
self.fd_config = MagicMock()
self.fd_config.model_config.vocab_size = 100
self.fd_config.scheduler_config.max_num_seqs = 2
self.fd_config.structured_outputs_config.disable_any_whitespace = False
self.fd_config.structured_outputs_config.disable_additional_properties = False
self.fd_config.structured_outputs_config.reasoning_parser = None
def test_initialization(self, mock_from_tokenizer, mock_matcher):
# Test backend initialization
mock_tokenizer = MagicMock()
with patch.object(BackendBase, "_get_tokenizer_hf", return_value=mock_tokenizer):
backend = LLGuidanceBackend(fd_config=self.fd_config)
self.assertEqual(backend.vocab_size, 100)
self.assertEqual(backend.batch_size, 2)
self.assertTrue(backend.any_whitespace)
@patch("llguidance.LLMatcher")
def test_create_processor(self, mock_matcher_class, mock_from_tokenizer, mock_matcher):
# Test creating a processor
with patch.object(LLGuidanceBackend, "__init__", return_value=None):
backend = LLGuidanceBackend(fd_config=None) # Arguments are not important because __init__ is mocked
# Manually set all required attributes
backend.hf_tokenizer = MagicMock()
backend.ll_tokenizer = MagicMock()
backend.vocab_size = 100
backend.batch_size = 2
backend.any_whitespace = True
backend.disable_additional_properties = False
mock_matcher = MagicMock()
mock_matcher_class.return_value = mock_matcher
processor = backend._create_processor("test_grammar")
self.assertIsInstance(processor, LLGuidanceProcessor)
self.assertEqual(processor.vocab_size, 100)
self.assertEqual(processor.batch_size, 2)
if __name__ == "__main__":
unittest.main()
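
The stubbing at the top of this file is order-sensitive; a distilled sketch of the pattern (using `setdefault` so a real installation, when present, is preferred):

```python
# Register stubs in sys.modules BEFORE importing the module under test, so its
# top-level `import llguidance.torch` (and friends) resolve to the stubs.
import sys
from unittest.mock import MagicMock

for name in ("llguidance", "llguidance.hf", "llguidance.torch", "torch"):
    sys.modules.setdefault(name, MagicMock())
```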

View File

@@ -0,0 +1,595 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import sys
import unittest
from unittest.mock import MagicMock, patch
import pytest
# Check if llguidance can be imported
HAS_LLGUIDANCE = False
try:
import llguidance
llguidance
HAS_LLGUIDANCE = True
except ImportError:
mock_llguidance = MagicMock()
mock_llguidancehf = MagicMock()
mock_llguidancetorch = MagicMock()
mock_torch = MagicMock()
sys.modules["llguidance"] = mock_llguidance
sys.modules["llguidance.hf"] = mock_llguidancehf
sys.modules["llguidance.torch"] = mock_llguidancetorch
sys.modules["torch"] = mock_torch
@pytest.fixture
def llguidance_checker():
"""Return an LLGuidanceChecker instance for testing."""
return LLGuidanceChecker()
@pytest.fixture
def llguidance_checker_with_options():
"""Return an LLGuidanceChecker instance configured with specific options."""
return LLGuidanceChecker(disable_any_whitespace=True)
from fastdeploy.model_executor.guided_decoding.guidance_backend import LLGuidanceChecker
def MockRequest():
request = MagicMock()
request.guided_json = None
request.guided_json_object = None
request.structural_tag = None
request.guided_regex = None
request.guided_choice = None
request.guided_grammar = None
return request
class TestLLGuidanceCheckerMocked:
"""Test LLGuidanceChecker using Mock, suitable for environments without the llguidance library."""
@patch("llguidance.LLMatcher.grammar_from_json_schema")
@patch("llguidance.LLMatcher.validate_grammar")
def test_serialize_guided_json_as_string(self, mock_validate, mock_from_schema, llguidance_checker):
"""Test processing guided_json string type."""
mock_from_schema.return_value = "serialized_grammar"
mock_validate.return_value = None
request = MockRequest()
request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}'
grammar = llguidance_checker.serialize_guidance_grammar(request)
mock_from_schema.assert_called_once()
assert grammar == "serialized_grammar"
@patch("llguidance.LLMatcher.grammar_from_json_schema")
@patch("llguidance.LLMatcher.validate_grammar")
def test_serialize_guided_json_as_dict(self, mock_validate, mock_from_schema, llguidance_checker):
"""Test processing guided_json dictionary type."""
mock_from_schema.return_value = "serialized_grammar"
mock_validate.return_value = None
request = MockRequest()
request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}}
grammar = llguidance_checker.serialize_guidance_grammar(request)
mock_from_schema.assert_called_once()
assert isinstance(request.guided_json, dict)  # The original dict is left untouched; serialization operates on a copy
assert grammar == "serialized_grammar"
@patch("llguidance.LLMatcher.grammar_from_json_schema")
@patch("llguidance.LLMatcher.validate_grammar")
def test_serialize_guided_json_object(self, mock_validate, mock_from_schema, llguidance_checker):
"""Test processing guided_json_object."""
mock_from_schema.return_value = "serialized_grammar"
mock_validate.return_value = None
request = MockRequest()
request.guided_json_object = True
grammar = llguidance_checker.serialize_guidance_grammar(request)
mock_from_schema.assert_called_once()
assert request.guided_json_object
assert grammar == "serialized_grammar"
@patch("llguidance.grammar_from")
@patch("llguidance.LLMatcher.validate_grammar")
def test_serialize_guided_regex(self, mock_validate, mock_grammar_from, llguidance_checker):
"""Test processing guided_regex."""
mock_grammar_from.return_value = "serialized_regex_grammar"
mock_validate.return_value = None
request = MockRequest()
request.guided_regex = "[a-zA-Z]+"
grammar = llguidance_checker.serialize_guidance_grammar(request)
mock_grammar_from.assert_called_once_with("regex", "[a-zA-Z]+")
assert grammar == "serialized_regex_grammar"
@patch("llguidance.grammar_from")
@patch("llguidance.LLMatcher.validate_grammar")
def test_serialize_guided_choice(self, mock_validate, mock_grammar_from, llguidance_checker):
"""Test processing guided_choice."""
mock_grammar_from.return_value = "serialized_choice_grammar"
mock_validate.return_value = None
request = MockRequest()
request.guided_choice = ["option1", "option2"]
grammar = llguidance_checker.serialize_guidance_grammar(request)
mock_grammar_from.assert_called_once_with("choice", ["option1", "option2"])
assert grammar == "serialized_choice_grammar"
@patch("llguidance.grammar_from")
@patch("llguidance.LLMatcher.validate_grammar")
def test_serialize_guided_grammar(self, mock_validate, mock_grammar_from, llguidance_checker):
"""Test processing guided_grammar."""
mock_grammar_from.return_value = "serialized_grammar_spec"
mock_validate.return_value = None
request = MockRequest()
request.guided_grammar = "grammar specification"
grammar = llguidance_checker.serialize_guidance_grammar(request)
mock_grammar_from.assert_called_once_with("grammar", "grammar specification")
assert grammar == "serialized_grammar_spec"
@patch("llguidance.StructTag")
@patch("llguidance.LLMatcher.grammar_from_json_schema")
def test_serialize_structural_tag(self, mock_from_schema, mock_struct_tag, llguidance_checker):
"""Test processing structural_tag."""
# Configure mock objects
mock_from_schema.return_value = "serialized_schema"
mock_struct_tag.to_grammar.return_value = "serialized_structural_grammar"
struct_tag_instance = MagicMock()
mock_struct_tag.return_value = struct_tag_instance
request = MockRequest()
request.structural_tag = {
"triggers": ["<json>"],
"structures": [{"begin": "<json>", "schema": {"type": "object"}, "end": "</json>"}],
}
grammar = llguidance_checker.serialize_guidance_grammar(request)
mock_from_schema.assert_called_once()
mock_struct_tag.assert_called_once()
mock_struct_tag.to_grammar.assert_called_once()
assert grammar == "serialized_structural_grammar"
@patch("llguidance.StructTag")
def test_serialize_structural_tag_missing_trigger(self, mock_struct_tag, llguidance_checker):
"""Test processing structural_tag when a trigger is missing."""
request = MockRequest()
request.structural_tag = {
"triggers": ["<xml>"],
"structures": [{"begin": "<json>", "schema": {"type": "object"}, "end": "</json>"}],
}
with pytest.raises(ValueError, match="Trigger .* not found in triggers"):
llguidance_checker.serialize_guidance_grammar(request)
@patch("llguidance.StructTag")
def test_serialize_structural_tag_empty_structures(self, mock_struct_tag, llguidance_checker):
"""Test processing structural_tag when structures are empty."""
request = MockRequest()
request.structural_tag = {"triggers": ["<json>"], "structures": []}
with pytest.raises(ValueError, match="No structural tags found in the grammar spec"):
llguidance_checker.serialize_guidance_grammar(request)
def test_serialize_invalid_grammar_type(self, llguidance_checker):
"""Test processing invalid grammar types."""
request = MockRequest()
# No grammar type set
with pytest.raises(ValueError, match="grammar is not of valid supported types"):
llguidance_checker.serialize_guidance_grammar(request)
@patch("llguidance.LLMatcher.grammar_from_json_schema")
@patch("llguidance.LLMatcher.validate_grammar")
def test_schema_format_valid_json(self, mock_validate, mock_from_schema, llguidance_checker):
"""Test schema_format method processing valid JSON."""
mock_from_schema.return_value = "serialized_grammar"
mock_validate.return_value = None
request = MockRequest()
request.guided_json = '{"type": "object"}'
result_request, error = llguidance_checker.schema_format(request)
assert error is None
assert result_request is request
@patch("llguidance.LLMatcher.grammar_from_json_schema")
@patch("llguidance.LLMatcher.validate_grammar")
def test_schema_format_invalid_grammar(self, mock_validate, mock_from_schema, llguidance_checker):
"""Test schema_format method processing invalid grammar."""
mock_from_schema.return_value = "serialized_grammar"
mock_validate.return_value = "Invalid grammar"
request = MockRequest()
request.guided_json = '{"type": "object"}'
result_request, error = llguidance_checker.schema_format(request)
assert error is not None
assert "Grammar error: Invalid grammar" in error
@patch("llguidance.LLMatcher.grammar_from_json_schema")
def test_schema_format_json_decode_error(self, mock_from_schema, llguidance_checker):
"""Test schema_format method processing JSON decode error."""
mock_from_schema.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
request = MockRequest()
request.guided_json = "{invalid json}"
result_request, error = llguidance_checker.schema_format(request)
assert error is not None
assert "Invalid format for guided decoding" in error
@patch("llguidance.LLMatcher.grammar_from_json_schema")
def test_schema_format_unexpected_error(self, mock_from_schema, llguidance_checker):
"""Test schema_format method processing unexpected errors."""
mock_from_schema.side_effect = Exception("Unexpected error")
request = MockRequest()
request.guided_json = '{"type": "object"}'
result_request, error = llguidance_checker.schema_format(request)
assert error is not None
assert "An unexpected error occurred during schema validation" in error
def test_init_with_disable_whitespace(self, llguidance_checker_with_options):
"""Test setting the disable_any_whitespace option during initialization."""
assert llguidance_checker_with_options.any_whitespace is False
assert llguidance_checker_with_options.disable_additional_properties is True
assert LLGuidanceChecker(disable_any_whitespace=True).any_whitespace is False
assert LLGuidanceChecker(disable_any_whitespace=False).any_whitespace is True
# default check
from fastdeploy.envs import FD_GUIDANCE_DISABLE_ADDITIONAL
assert FD_GUIDANCE_DISABLE_ADDITIONAL
assert LLGuidanceChecker().disable_additional_properties is True
with patch("fastdeploy.model_executor.guided_decoding.guidance_backend.FD_GUIDANCE_DISABLE_ADDITIONAL", False):
assert LLGuidanceChecker().disable_additional_properties is False
@pytest.mark.skipif(not HAS_LLGUIDANCE, reason="llguidance library not installed, skipping actual dependency tests")
class TestLLGuidanceCheckerReal:
"""Test using the actual llguidance library, suitable for development environments."""
def test_serialize_guided_json_string_real(self, llguidance_checker):
"""Test processing guided_json string using the actual library."""
request = MockRequest()
request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}'
grammar = llguidance_checker.serialize_guidance_grammar(request)
# Verify if the returned grammar is a valid string
assert isinstance(grammar, str)
assert len(grammar) > 0
print("grammar", grammar)
def test_serialize_guided_json_dict_real(self, llguidance_checker):
"""Test processing guided_json dictionary using the actual library."""
request = MockRequest()
request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}}
grammar = llguidance_checker.serialize_guidance_grammar(request)
assert isinstance(request.guided_json, dict)
assert isinstance(grammar, str)
assert len(grammar) > 0
def test_serialize_guided_json_object_real(self, llguidance_checker):
"""Test processing guided_json_object using the actual library."""
request = MockRequest()
request.guided_json_object = True
grammar = llguidance_checker.serialize_guidance_grammar(request)
assert request.guided_json_object
assert isinstance(grammar, str)
assert len(grammar) > 0
def test_serialize_guided_regex_real(self, llguidance_checker):
"""Test processing guided_regex using the actual library."""
request = MockRequest()
request.guided_regex = "[a-zA-Z]+"
grammar = llguidance_checker.serialize_guidance_grammar(request)
assert isinstance(grammar, str)
assert len(grammar) > 0
def test_serialize_guided_choice_real(self, llguidance_checker):
"""Test processing guided_choice using the actual library."""
request = MockRequest()
request.guided_choice = ["option1", "option2"]
grammar = llguidance_checker.serialize_guidance_grammar(request)
assert isinstance(grammar, str)
assert len(grammar) > 0
def test_serialize_guided_grammar_real(self, llguidance_checker):
"""Test processing guided_grammar using the actual library."""
request = MockRequest()
# Use a simple CFG grammar example
request.guided_grammar = """
root ::= greeting name
greeting ::= "Hello" | "Hi"
name ::= "world" | "everyone"
"""
grammar = llguidance_checker.serialize_guidance_grammar(request)
assert isinstance(grammar, str)
assert len(grammar) > 0
def test_serialize_structural_tag_real(self, llguidance_checker):
"""Test processing structural_tag using the actual library."""
request = MockRequest()
request.structural_tag = {
"triggers": ["<json>"],
"structures": [{"begin": "<json>", "schema": {"type": "object"}, "end": "</json>"}],
}
grammar = llguidance_checker.serialize_guidance_grammar(request)
assert isinstance(grammar, str)
assert len(grammar) > 0
def test_schema_format_valid_json_real(self, llguidance_checker):
"""Test schema_format method processing valid JSON using the actual library."""
request = MockRequest()
request.guided_json = '{"type": "object", "properties": {"name": {"type": "string"}}}'
result_request, error = llguidance_checker.schema_format(request)
assert error is None
assert result_request is request
assert result_request.guided_json != '{"type": "object", "properties": {"name": {"type": "string"}}}'
def test_schema_format_invalid_json_real(self, llguidance_checker):
"""Test schema_format method processing invalid JSON using the actual library."""
request = MockRequest()
request.guided_json = "{invalid json}"
result_request, error = llguidance_checker.schema_format(request)
assert error is not None
assert "Invalid format for guided decoding" in error
def test_whitespace_flexibility_option_real(self):
"""Test the impact of the whitespace flexibility option using the actual library."""
# Create two instances with different configurations
flexible = LLGuidanceChecker(disable_any_whitespace=False)
strict = LLGuidanceChecker(disable_any_whitespace=True)
request_flexible = MockRequest()
request_flexible.guided_json = '{"type": "object"}'
request_strict = MockRequest()
request_strict.guided_json = '{"type": "object"}'
grammar_flexible = flexible.serialize_guidance_grammar(request_flexible)
grammar_strict = strict.serialize_guidance_grammar(request_strict)
print("grammar_flexible", grammar_flexible)
print("grammar_strict", grammar_strict)
# Expect grammars generated by the two configurations to be different
assert grammar_flexible != grammar_strict
def test_schema_format_guided_json_object_real(self, llguidance_checker):
"""Test schema_format processing guided_json_object."""
request = MockRequest()
request.guided_json_object = True
result_request, error = llguidance_checker.schema_format(request)
assert error is None
assert result_request is request
def test_schema_format_guided_regex_real(self, llguidance_checker):
"""Test schema_format processing valid regular expressions."""
request = MockRequest()
request.guided_regex = r"[a-zA-Z0-9]+"
result_request, error = llguidance_checker.schema_format(request)
assert error is None
assert result_request is request
assert result_request.guided_regex != r"[a-zA-Z0-9]+" # Should be converted to grammar format
def test_schema_format_invalid_guided_regex_real(self, llguidance_checker):
"""Test schema_format processing invalid regular expressions."""
request = MockRequest()
request.guided_regex = r"[" # Invalid regular expression
result_request, error = llguidance_checker.schema_format(request)
assert error is not None
assert "Invalid format for guided decoding" in error
def test_schema_format_guided_choice_real(self, llguidance_checker):
"""Test schema_format processing guided_choice."""
request = MockRequest()
request.guided_choice = ["option1", "option2", "option3"]
result_request, error = llguidance_checker.schema_format(request)
assert error is None
assert result_request is request
assert result_request.guided_choice != [
"option1",
"option2",
"option3",
] # Should be converted to grammar format
def test_schema_format_guided_grammar_real(self, llguidance_checker):
"""Test schema_format processing guided_grammar."""
request = MockRequest()
# Use the correct grammar format supported by LLGuidance
request.guided_grammar = """
start: number
number: DIGIT+
DIGIT: "0"|"1"|"2"|"3"|"4"|"5"|"6"|"7"|"8"|"9"
"""
result_request, error = llguidance_checker.schema_format(request)
assert error is None
assert result_request is request
assert isinstance(result_request.guided_grammar, str)
def test_schema_format_structural_tag_real(self, llguidance_checker):
"""Test schema_format processing structural_tag."""
request = MockRequest()
request.structural_tag = {
"triggers": ["```json"],
"structures": [
{
"begin": "```json",
"schema": {"type": "object", "properties": {"name": {"type": "string"}}},
"end": "```",
}
],
}
result_request, error = llguidance_checker.schema_format(request)
assert error is None
assert result_request is request
def test_schema_format_structural_tag_string_real(self, llguidance_checker):
"""Test schema_format processing structural_tag in string format."""
request = MockRequest()
request.structural_tag = json.dumps(
{
"triggers": ["```json"],
"structures": [
{
"begin": "```json",
"schema": {"type": "object", "properties": {"name": {"type": "string"}}},
"end": "```",
}
],
}
)
result_request, error = llguidance_checker.schema_format(request)
assert error is None
assert result_request is request
def test_schema_format_structural_tag_invalid_trigger_real(self, llguidance_checker):
"""Test schema_format processing structural_tag with invalid triggers."""
request = MockRequest()
request.structural_tag = {
"triggers": ["```xml"], # Trigger does not match begin
"structures": [
{
"begin": "```json",
"schema": {"type": "object"},
"end": "```",
} # Does not contain any prefix from triggers here
],
}
result_request, error = llguidance_checker.schema_format(request)
assert error is not None
assert "Invalid format for guided decoding" in error
def test_schema_format_structural_tag_empty_structures_real(self, llguidance_checker):
"""Test schema_format processing structural_tag with empty structures."""
request = MockRequest()
request.structural_tag = {"triggers": ["```json"], "structures": []} # Empty structure
result_request, error = llguidance_checker.schema_format(request)
assert error is not None
assert "Invalid format for guided decoding" in error
def test_schema_format_json_dict_real(self, llguidance_checker):
"""Test schema_format processing guided_json in dictionary format."""
request = MockRequest()
request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}}
result_request, error = llguidance_checker.schema_format(request)
assert error is None
assert result_request is request
def test_schema_format_disable_additional_properties_real(self):
"""Test schema_format processing disable_additional_properties parameter."""
checker = LLGuidanceChecker(disable_additional_properties=True)
request = MockRequest()
request.guided_json = {"type": "object", "properties": {"name": {"type": "string"}}}
result_request, error = checker.schema_format(request)
assert error is None
assert result_request is request
def test_schema_format_unexpected_error_real(self, monkeypatch, llguidance_checker):
"""Test schema_format processing unexpected errors."""
request = MockRequest()
request.guided_json = '{"type": "object"}'
# Mock unexpected exception
def mock_serialize_grammar(*args, **kwargs):
raise Exception("Unexpected error")
monkeypatch.setattr(llguidance_checker, "serialize_guidance_grammar", mock_serialize_grammar)
result_request, error = llguidance_checker.schema_format(request)
assert error is not None
assert "An unexpected error occurred during schema validation" in error
def test_schema_format_no_valid_grammar_real(self, llguidance_checker):
"""Test schema_format processing requests without valid grammar."""
request = MockRequest()
# No grammar-related attributes set
with pytest.raises(ValueError, match="grammar is not of valid supported types"):
llguidance_checker.serialize_guidance_grammar(request)
result_request, error = llguidance_checker.schema_format(request)
assert error is not None
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,172 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import sys
import unittest
from unittest.mock import MagicMock, patch
# --- Mocking Setup ---
# Mock these optional modules up front so the tests can run in environments where the libraries are not installed.
mock_torch = MagicMock()
mock_llguidance = MagicMock()
mock_llguidance_hf = MagicMock()
mock_llguidance_torch = MagicMock()
mock_torch.__spec__ = MagicMock()
mock_torch.distributed = MagicMock()
sys.modules["torch"] = mock_torch
sys.modules["llguidance"] = mock_llguidance
sys.modules["llguidance.hf"] = mock_llguidance_hf
sys.modules["llguidance.torch"] = mock_llguidance_torch
# Import the module to be tested after the mock setup is complete
from fastdeploy.model_executor.guided_decoding.guidance_backend import (
LLGuidanceProcessor,
)
def MockFDConfig():
"""Create a mock FDConfig object for testing"""
config = MagicMock()
# Explicitly set model to a string so it passes Hugging Face validation
config.model_config.model = "test-model-path"
config.model_config.architectures = [] # Set to empty list to prevent errors when iterating over the Mock
config.model_config.vocab_size = 1000
config.scheduler_config.max_num_seqs = 4
config.structured_outputs_config.disable_any_whitespace = False
# Ensure the backend check logic passes
config.structured_outputs_config.guided_decoding_backend = "guidance"
return config
def MockHFTokenizer():
"""Create a mock Hugging Face Tokenizer object for testing"""
return MagicMock()
class TestLLGuidanceProcessorMocked(unittest.TestCase):
"""
Unit tests for LLGuidanceProcessor using Mock.
This test class is suitable for environments where the llguidance library is not installed.
"""
def setUp(self):
"""Set up a new LLGuidanceProcessor instance for each test case"""
self.mock_matcher = MagicMock()
self.mock_tokenizer = MagicMock()
self.mock_tokenizer.eos_token = 2 # Example EOS token ID
self.processor = LLGuidanceProcessor(
ll_matcher=self.mock_matcher,
ll_tokenizer=self.mock_tokenizer,
serialized_grammar="test_grammar",
vocab_size=1000,
batch_size=4,
enable_thinking=False,
)
def test_init(self):
"""Test the constructor of LLGuidanceProcessor"""
self.assertIs(self.processor.matcher, self.mock_matcher)
self.assertEqual(self.processor.vocab_size, 1000)
self.assertEqual(self.processor.batch_size, 4)
self.assertFalse(self.processor.is_terminated)
@patch("fastdeploy.utils.llm_logger.warning")
def test_check_error_logs_warning_once(self, mock_log_warning):
"""Test that the _check_error method logs a warning when the matcher errors, and only logs it once"""
self.mock_matcher.get_error.return_value = "A test error."
# First call, should log the message
self.processor._check_error()
mock_log_warning.assert_called_once_with("LLGuidance Matcher error: A test error.")
# Second call, should not log repeatedly
self.processor._check_error()
mock_log_warning.assert_called_once()
@patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance.torch")
def test_allocate_token_bitmask(self, mock_backend_torch):
"""
Test the allocation of token bitmask.
Note: We patch the llguidance_torch variable imported in the guidance_backend module here,
instead of the global mock in sys.modules, to resolve inconsistent references caused by LazyLoader.
"""
mock_backend_torch.allocate_token_bitmask.return_value = "fake_bitmask_tensor"
result = self.processor.allocate_token_bitmask()
mock_backend_torch.allocate_token_bitmask.assert_called_once_with(4, 1000)
self.assertEqual(result, "fake_bitmask_tensor")
@patch("fastdeploy.model_executor.guided_decoding.guidance_backend.llguidance.torch")
def test_fill_token_bitmask(self, mock_backend_torch):
"""Test the filling of token bitmask"""
mock_bitmask = MagicMock()
self.processor.fill_token_bitmask(mock_bitmask, idx=2)
mock_backend_torch.fill_next_token_bitmask.assert_called_once_with(self.mock_matcher, mock_bitmask, 2)
self.mock_matcher.get_error.assert_called_once()
def test_reset(self):
"""Test the state reset of the processor"""
self.processor.is_terminated = True
self.processor._printed_error = True
self.mock_matcher.get_error.return_value = ""
self.processor.reset()
self.mock_matcher.reset.assert_called_once()
self.assertFalse(self.processor.is_terminated)
self.assertFalse(self.processor._printed_error)
def test_accept_token_when_terminated(self):
"""Test that accept_token returns False immediately when status is is_terminated"""
self.processor.is_terminated = True
self.assertFalse(self.processor.accept_token(123))
def test_accept_token_when_matcher_stopped(self):
"""Test that accept_token returns False and updates status when the matcher is stopped"""
self.mock_matcher.is_stopped.return_value = True
self.assertTrue(self.processor.accept_token(123))
self.assertFalse(self.processor.is_terminated)
def test_accept_token_is_eos(self):
"""Test the behavior when an EOS token is received"""
self.mock_matcher.is_stopped.return_value = False
self.assertTrue(self.processor.accept_token(self.mock_tokenizer.eos_token))
self.assertTrue(self.processor.is_terminated)
def test_accept_token_consumes_and_succeeds(self):
"""Test successfully consuming a token"""
self.mock_matcher.is_stopped.return_value = False
self.mock_matcher.consume_tokens.return_value = True
self.assertTrue(self.processor.accept_token(123))
self.mock_matcher.consume_tokens.assert_called_once_with([123])
self.mock_matcher.get_error.assert_called_once()
def test_accept_token_consumes_and_fails(self):
"""Test failing to consume a token"""
self.mock_matcher.is_stopped.return_value = False
self.mock_matcher.consume_tokens.return_value = False
self.assertFalse(self.processor.accept_token(123))
self.mock_matcher.consume_tokens.assert_called_once_with([123])
if __name__ == "__main__":
unittest.main()