Files
FastDeploy/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
2025-07-19 23:19:27 +08:00

469 lines
16 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import re
from typing import Any, List, Optional
import paddle
import torch
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
BackendBase,
BaseChecker,
LogitsProcessorBase,
)
from fastdeploy.utils import llm_logger
try:
from xgrammar import (
CompiledGrammar,
Grammar,
GrammarCompiler,
GrammarMatcher,
StructuralTagItem,
TokenizerInfo,
allocate_token_bitmask,
apply_token_bitmask_inplace,
)
except Exception as e:
raise Exception(f"import XGrammar failed, please check your environment:\n\t {e}")
class XGrammarProcessor(LogitsProcessorBase):
    """
    XGrammar-specific implementation of LogitsProcessorBase.

    This processor enforces grammar constraints during token generation using XGrammar.
    It owns a GrammarMatcher that tracks the grammar matching state, uses it to fill
    token bitmasks, and applies those bitmasks to logits.

    Attributes:
        max_rollback_tokens (int): Rollback budget passed to GrammarMatcher (fixed at 200)
        vocab_size (int): Size of the vocabulary
        batch_size (int): Batch size for processing
        splitwise_role (str): Role for splitwise processing (stored and propagated by copy())
        compiled_grammar (CompiledGrammar): Compiled grammar rules
        terminate_without_stop_token (bool): Whether to terminate without stop token
        override_stop_tokens (Optional[List[int]]): Custom stop tokens
        matcher (GrammarMatcher): Grammar matching engine
    """

    def __init__(
        self,
        compiled_grammar: CompiledGrammar,
        terminate_without_stop_token: bool = False,
        override_stop_tokens: Optional[List[int]] = None,
        vocab_size: Optional[int] = None,
        batch_size: Optional[int] = None,
        splitwise_role: str = "mixed",
    ):
        super().__init__()
        self.max_rollback_tokens = 200
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.splitwise_role = splitwise_role
        self.compiled_grammar = compiled_grammar
        self.terminate_without_stop_token = terminate_without_stop_token
        self.override_stop_tokens = override_stop_tokens
        self.matcher = GrammarMatcher(
            compiled_grammar=compiled_grammar,
            max_rollback_tokens=self.max_rollback_tokens,
            terminate_without_stop_token=terminate_without_stop_token,
            override_stop_tokens=override_stop_tokens,
        )

    def allocate_token_bitmask(self) -> torch.Tensor:
        """
        Allocate a token bitmask tensor for grammar constraints.

        The tensor comes from xgrammar's ``allocate_token_bitmask``. Per the xgrammar
        documentation it is a packed int32 CPU tensor of shape
        (batch_size, ceil(vocab_size / 32)), one bit per token, rather than a dense
        (batch_size, vocab_size) tensor.

        NOTE(review): requires ``batch_size`` and ``vocab_size`` to have been provided;
        both default to None in ``__init__``.

        Returns:
            torch.Tensor: Packed token bitmask with one row per batch slot
        """
        return allocate_token_bitmask(self.batch_size, self.vocab_size)

    def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
        """
        Fill one row of the token bitmask with the tokens the grammar currently allows.

        Args:
            token_bitmask (torch.Tensor): Bitmask from ``allocate_token_bitmask``; modified in place
            idx (int): The batch index (row) to fill the mask for

        Returns:
            None: Modifies the token_bitmask in-place
        """
        self.matcher.fill_next_token_bitmask(token_bitmask, idx)

    def apply_token_mask(
        self,
        logits: paddle.Tensor,
        token_bitmask: torch.Tensor,
        indices: Optional[List[int]] = None,
    ) -> paddle.Tensor:
        """
        Apply the token mask to the logits, masking out tokens the grammar disallows.

        Args:
            logits (paddle.Tensor): The logits tensor to modify
            token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
            indices (Optional[List[int]]): Optional list of batch indices to apply mask to

        Returns:
            paddle.Tensor: A new logits tensor with the original dtype and place
        """
        # Remember where the Paddle tensor lived so the result can be moved back.
        origin_place = logits.place
        origin_dtype = logits.dtype
        # Paddle -> NumPy -> Torch: xgrammar's in-place masking operates on torch tensors.
        logits = torch.from_numpy(logits.numpy())
        logits = logits.float()  # cpu, float32
        apply_token_bitmask_inplace(
            logits=logits,
            bitmask=token_bitmask.to(logits.device, non_blocking=True),
            indices=indices,
        )
        # Torch -> NumPy -> Paddle, restoring the caller's dtype and device placement.
        return paddle.to_tensor(
            logits.numpy(),
            dtype=origin_dtype,
            place=origin_place,
        )

    def reset(self) -> None:
        """
        Reset the grammar matcher state to initial conditions.

        Returns:
            None: No return value
        """
        self.matcher.reset()

    def accept_token(self, token: int) -> None:
        """
        Validate and accept a generated token against the grammar constraints.

        Args:
            token (int): The token ID to validate

        Raises:
            AssertionError: If token is not allowed by the grammar
        """
        # The matcher call must NOT live inside an `assert` statement: under
        # `python -O` asserts are stripped entirely, which would silently skip
        # advancing the matcher state as well as the validity check.
        accepted = self.matcher.accept_token(token)
        if not accepted:
            # Keep AssertionError for compatibility with existing callers.
            raise AssertionError(f"Failed to accept token {token}")

    def is_terminated(self) -> bool:
        """
        Check if the grammar matching process has terminated.

        Returns:
            bool: True if matching has terminated, False otherwise
        """
        return self.matcher.is_terminated()

    def copy(self) -> "XGrammarProcessor":
        """
        Create a new processor with the same configuration as this one.

        The new instance shares ``compiled_grammar`` and copies the configuration
        fields, but builds a fresh GrammarMatcher. Tokens already accepted by this
        processor are therefore NOT carried over; the copy starts from the
        grammar's initial state.

        Returns:
            XGrammarProcessor: A new processor instance with identical configuration
        """
        return XGrammarProcessor(
            compiled_grammar=self.compiled_grammar,
            terminate_without_stop_token=self.terminate_without_stop_token,
            override_stop_tokens=self.override_stop_tokens,
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
            splitwise_role=self.splitwise_role,
        )
class XGrammarBackend(BackendBase):
    """
    XGrammar-specific implementation of BackendBase.

    This backend compiles the supported schema types (JSON schema, regex, EBNF
    grammar, structural tags) into XGrammar processors. It owns the grammar
    compiler built from the tokenizer info.

    Attributes:
        vocab_size (int): Size of the vocabulary from config
        batch_size (int): Maximum batch size from config (``max_num_seqs``)
        any_whitespace (bool): Whether to allow any whitespace in JSON
        splitwise_role (str): Role for splitwise processing
        grammar_compiler (GrammarCompiler): Grammar compilation engine
    """

    def __init__(
        self,
        fd_config: FDConfig,
        **kwargs,  # accepted for interface compatibility; not used here
    ):
        super().__init__(fd_config=fd_config)
        self.vocab_size = fd_config.model_config.vocab_size
        self.batch_size = fd_config.parallel_config.max_num_seqs
        self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
        self.splitwise_role = fd_config.parallel_config.splitwise_role
        try:
            # NOTE(review): `self.hf_tokenizer` is not assigned in this class;
            # presumably it is provided by BackendBase.__init__ -- confirm.
            tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
            self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
        except Exception as e:
            # Chain explicitly so the original failure stays visible as the cause.
            raise Exception(f"Failed to load XGrammar tokenizer: {e}") from e

    def _create_processor(
        self,
        compiled_grammar: CompiledGrammar,
        terminate_without_stop_token: bool = False,
        override_stop_tokens: Optional[List[int]] = None,
    ) -> XGrammarProcessor:
        """
        Create a logits processor instance for the given compiled grammar.

        Args:
            compiled_grammar (CompiledGrammar): Compiled grammar rules
            terminate_without_stop_token (bool): Whether to terminate without stop token
            override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults

        Returns:
            XGrammarProcessor: Processor configured with this backend's vocab size,
            batch size and splitwise role
        """
        return XGrammarProcessor(
            compiled_grammar=compiled_grammar,
            terminate_without_stop_token=terminate_without_stop_token,
            override_stop_tokens=override_stop_tokens,
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
            splitwise_role=self.splitwise_role,
        )

    def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile a JSON schema into a grammar processor.

        Args:
            schemata (str): JSON schema string to compile

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on
            failure (the error is logged)
        """
        try:
            compiled_grammar = self.grammar_compiler.compile_json_schema(schemata, any_whitespace=self.any_whitespace)
        except Exception as e:
            llm_logger.error(f"Failed to compile json schema: {e}")
            return None
        return self._create_processor(compiled_grammar)

    def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile a regex pattern into a grammar processor.

        Args:
            schemata (str): Regex pattern string to compile

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on
            failure (the error is logged)
        """
        try:
            compiled_grammar = self.grammar_compiler.compile_regex(schemata)
        except Exception as e:
            llm_logger.error(f"Failed to compile regex schema: {e}")
            return None
        return self._create_processor(compiled_grammar)

    def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile an EBNF grammar into a grammar processor.

        Args:
            schemata (str): Grammar string in EBNF format

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on
            failure (the error is logged)
        """
        try:
            compiled_grammar = self.grammar_compiler.compile_grammar(schemata)
        except Exception as e:
            llm_logger.error(f"Failed to compile ebnf schema: {e}")
            return None
        return self._create_processor(compiled_grammar)

    def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile structural tags into a grammar processor.

        Args:
            schemata (str): JSON string with a "structures" list (each item having
                "begin", "schema" and "end") and a "triggers" entry

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on
            failure (the error is logged)
        """
        try:
            structural_tag = json.loads(schemata)
            tags = [
                StructuralTagItem(
                    begin=structure["begin"],
                    # StructuralTagItem takes the schema as a JSON string.
                    schema=json.dumps(structure["schema"]),
                    end=structure["end"],
                )
                for structure in structural_tag["structures"]
            ]
            compiled_grammar = self.grammar_compiler.compile_structural_tag(tags, structural_tag["triggers"])
        except Exception as e:
            llm_logger.error(f"Failed to compile structural tags schema: {e}")
            return None
        return self._create_processor(compiled_grammar)
class XGrammarChecker(BaseChecker):
    """
    XGrammar-specific implementation of BaseChecker.

    This validator checks guided-decoding schemas (JSON schema, EBNF grammar,
    JSON object, choice list, structural tags) with XGrammar and rewrites them
    into the string form used downstream.

    Attributes:
        any_whitespace (bool): Whether to allow any whitespace in JSON
    """

    def __init__(self, **kwargs):
        super().__init__()
        # `disable_any_whitespace` defaults to True here, so any_whitespace is
        # False unless the caller explicitly passes disable_any_whitespace=False.
        self.any_whitespace = not kwargs.get("disable_any_whitespace", True)

    def _unsupported_json_schema(self, schema: Any) -> bool:
        """
        Check if a JSON schema contains features this checker treats as unsupported.

        Args:
            schema (dict[str, Any] | str): JSON schema as a dict, or as a JSON string
                (which is parsed first)

        Returns:
            bool: True if schema contains unsupported features, False otherwise
                (including when a string schema is not valid JSON)
        """
        if isinstance(schema, str):
            # The recursive walk below only inspects dicts, so a raw JSON string
            # must be parsed first -- otherwise the check silently passes.
            try:
                schema = json.loads(schema)
            except ValueError:
                # Not parseable: nothing to inspect here; validity is reported elsewhere.
                return False

        def check_object(obj: dict[str, Any]) -> bool:
            if not isinstance(obj, dict):
                return False
            if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
                return True
            if obj.get("type") == "array" and any(
                key in obj
                for key in (
                    "uniqueItems",
                    "contains",
                    "minContains",
                    "maxContains",
                )
            ):
                return True
            if obj.get("type") == "string" and "format" in obj:
                return True
            if obj.get("type") == "object" and any(
                key in obj
                for key in (
                    "minProperties",
                    "maxProperties",
                    "propertyNames",
                    "patternProperties",
                )
            ):
                return True
            # Recurse into nested sub-schemas (dict values and dicts inside lists).
            for value in obj.values():
                if isinstance(value, dict):
                    if check_object(value):
                        return True
                elif isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict) and check_object(item):
                            return True
            return False

        return check_object(schema)

    def schema_format(self, request: Request):
        """
        Validate the request's guided-decoding schema and convert it to the
        backend-specific format.

        The first truthy field among guided_json, guided_grammar,
        guided_json_object, guided_choice and structural_tag is handled; a
        request with none of them is returned unchanged.

        Args:
            request (Request): Incoming request; its guided_json / guided_grammar
                fields may be rewritten in place.

        Returns:
            tuple: ``(request, err_msg)`` where ``err_msg`` is None on success or a
            human-readable error string when the schema is invalid or unsupported.
        """
        if request.guided_json:
            try:
                if not isinstance(request.guided_json, str):
                    guided_json = json.dumps(request.guided_json)
                else:
                    guided_json = request.guided_json
                Grammar.from_json_schema(guided_json, any_whitespace=self.any_whitespace)
            except RuntimeError as e:
                err_msg = f"Invalid JSON format: {guided_json}, error message: {e!s}"
                return request, err_msg
            # _unsupported_json_schema parses the JSON string before inspecting it.
            if self._unsupported_json_schema(guided_json):
                err_msg = f"unsupported JSON schema: {guided_json}"
                return request, err_msg
            request.guided_json = guided_json
            return request, None
        elif request.guided_grammar:
            # TODO: XGrammar only supports GBNF grammars, convert Lark to GBNF
            guided_grammar = request.guided_grammar
            try:
                Grammar.from_ebnf(guided_grammar)
            except RuntimeError as e:
                err_msg = f"Invalid grammar format: {guided_grammar}, error message: {e!s}"
                return request, err_msg
            request.guided_grammar = guided_grammar
            return request, None
        elif request.guided_json_object:
            request.guided_json = '{"type": "object"}'
            return request, None
        elif request.guided_choice:
            try:
                # Escape `"` and `\` so each choice is a valid EBNF string literal.
                escaped_choices = (re.sub(r'(["\\])', r"\\\1", c) for c in request.guided_choice)
                guided_choice = "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices)
                Grammar.from_ebnf(guided_choice)
            except RuntimeError as e:
                err_msg = f"Invalid choice format: {guided_choice}, error message: {e!s}"
                return request, err_msg
            request.guided_grammar = guided_choice
            return request, None
        elif request.structural_tag:
            try:
                structural_tag = json.loads(request.structural_tag)
                tags = [
                    StructuralTagItem(
                        begin=s["begin"],
                        schema=json.dumps(s["schema"]),
                        end=s["end"],
                    )
                    for s in structural_tag["structures"]
                ]
                Grammar.from_structural_tag(tags, structural_tag["triggers"])
            except (RuntimeError, ValueError, KeyError, TypeError) as e:
                # Malformed user input (bad JSON, missing keys, wrong types) is
                # reported as an error message instead of escaping as an exception.
                # Use the raw field: `structural_tag` is unbound if json.loads failed.
                err_msg = f"Invalid structural_tag format: {request.structural_tag}, error message: {e!s}"
                return request, err_msg
            return request, None
        else:
            # Remaining cases (e.g. regex) need no backend-specific formatting.
            return request, None