Sync the v2.0 code to the GitHub repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -0,0 +1,73 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# from fastdeploy.config import FDConfig
__all__ = ['get_guided_backend', 'schema_checker']
def get_guided_backend(
fd_config,
**kwargs,
):
"""
Get the guided decoding backend instance based on configuration.
Args:
fd_config (FDConfig): FastDeploy configuration object containing backend settings
**kwargs: Additional arguments passed to the backend constructor
Returns:
BaseBackend: An instance of the specified guided decoding backend
Raises:
ValueError: If the specified backend is not supported
"""
if fd_config.parallel_config.guided_decoding_backend.lower() == "xgrammar":
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import \
XGrammarBackend
return XGrammarBackend(
fd_config=fd_config,
**kwargs,
)
else:
raise ValueError(
f"Get unsupported backend {fd_config.parallel_config.guided_decoding_backend},"
f" please check your configuration.")
def schema_checker(backend_name: str, **kwargs):
"""
Get the schema checker instance for the specified backend.
Args:
backend_name (str): Name of the backend (e.g. "xgrammar")
**kwargs: Additional arguments passed to the checker constructor
Returns:
BaseChecker: An instance of the specified schema checker
Raises:
ValueError: If the specified backend is not supported
"""
if backend_name.lower() == "xgrammar":
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import \
XGrammarChecker
return XGrammarChecker(**kwargs)
else:
raise ValueError(
f"Get unsupported backend {backend_name}, please check your configuration."
)
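A minimal usage sketch of the two factory functions above. This is hedged: it assumes this file is the package's __init__ module (so the functions are importable from fastdeploy.model_executor.guided_decoding, matching the import paths used elsewhere in this commit), that fd_config is an already-populated FDConfig with guided_decoding_backend set to "xgrammar", and the disable_any_whitespace value is purely illustrative.

from fastdeploy.model_executor.guided_decoding import (get_guided_backend,
                                                       schema_checker)

def build_guided_decoding(fd_config):
    # Returns an XGrammarBackend when guided_decoding_backend == "xgrammar";
    # any other value raises ValueError.
    backend = get_guided_backend(fd_config)
    # Checker used to validate/normalize request schemas before compilation.
    checker = schema_checker("xgrammar", disable_any_whitespace=False)
    return backend, checker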

View File

@@ -0,0 +1,347 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
from concurrent.futures import ThreadPoolExecutor
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import llm_logger
class LogitsProcessorBase:
"""
Abstract base class for logits processors in guided decoding.
This class defines the interface for logits processors that modify token probabilities
during generation to enforce schema constraints. Subclasses should implement all
abstract methods to provide specific constraint enforcement logic.
Attributes:
None (all state should be managed by subclasses)
"""
def __init__(self):
pass
def fill_token_bitmask(self, token_bitmask, idx):
"""
Fill the vocabulary mask.
Args:
token_bitmask (tensor): The vocabulary mask tensor.
idx (tensor): The tensor index.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError()
def apply_token_mask(self, logits, token_bitmask):
"""
Apply the vocabulary mask to logits.
Args:
logits (tensor): The logits tensor.
token_bitmask (tensor): The vocabulary mask tensor.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError()
def allocate_token_bitmask(self, batch_size, vocab_size):
"""
Allocate a token bitmask for the given batch size and vocabulary size.
Args:
batch_size (int): The batch size.
vocab_size (int): The vocabulary size.
Returns:
tensor: The allocated token bitmask.
"""
raise NotImplementedError()
def accept_token(self, token):
"""
Accept a generated token and update the processor's constraint state.
Args:
token (int): The token id.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError()
def is_terminated(self):
"""
Check if the processor has been terminated.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError()
def reset(self):
"""
Reset the matcher state.
"""
raise NotImplementedError()
def copy(self):
"""
Create a copy of the backend instance.
Returns:
BackendBase: A copy of the backend instance.
"""
raise NotImplementedError()
class BackendBase:
"""
Abstract base class for guided decoding backends.
This class provides the core infrastructure for managing schema processors and
their caching. It handles:
- Processor creation and caching
- Tokenizer initialization
- Thread pool management for async operations
Attributes:
cache (dict): Cache of schema processors
fd_config (FDConfig): FastDeploy configuration
executor (ThreadPoolExecutor): Thread pool for async operations
max_cache_size (int): Maximum number of processors to cache
hf_tokenizer: HuggingFace tokenizer instance
"""
def __init__(self, fd_config: FDConfig):
self.cache = {}
self.fd_config = fd_config
self.executor = ThreadPoolExecutor()
self.max_cache_size = 2048
self.hf_tokenizer = self._get_tokenizer_hf()
def _create_processor(self):
"""
Create a specific logits processor instance.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError()
def _json_processor(self, schemata):
"""
Process JSON schemata.
Args:
schemata (str): The schemata string.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError()
def _regex_processor(self, schemata):
"""
Process regular expression schemata.
Args:
schemata (str): The schemata string.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError()
def _grammar_processor(self, schemata):
"""
Process grammar schemata.
Args:
schemata (str): The schemata string.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError()
def _structural_tag_processor(self, schemata):
"""
Process structural tag schemata.
Args:
schemata (str): The schemata string.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError()
def _unsupported_processor_type(self, key_type, schemata):
"""
Process unsupported type.
Args:
key_type (str): The key type string.
schemata (str): The schemata string.
"""
raise Exception(f"Unsupported processor type {key_type}.")
def _init_logits_processor(
self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
"""
Initialize a logits processor for the given schema type and content.
Args:
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
Returns:
LogitsProcessorBase: Initialized logits processor instance
Raises:
ValueError: If processor type is not supported
"""
key_type, schemata = schemata_key
if key_type == "json":
return self._json_processor(schemata)
elif key_type == "regex":
return self._regex_processor(schemata)
elif key_type == "grammar":
return self._grammar_processor(schemata)
elif key_type == "structural_tag":
return self._structural_tag_processor(schemata)
else:
llm_logger.error(f"Unsupported processor type {key_type}.")
return None
def get_logits_processor(
self,
schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
"""
Get a logits processor for the given key from the cache, or schedule creation of a new one on the thread pool.
Args:
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
Returns:
tuple[LogitsProcessorBase, bool]: Tuple containing:
- LogitsProcessorBase: The logits processor instance
- bool: True if processor was from cache, False if newly created
"""
value = self.cache.get(schemata_key, None)
if value:
return value.copy(), True
value = self.executor.submit(self._init_logits_processor, schemata_key)
return value, False
def _get_tokenizer_hf(self):
"""
Initialize and return a HuggingFace tokenizer instance.
This method handles special cases for Ernie models and falls back to standard
AutoTokenizer for other models. It also ensures fast tokenizer is used when possible.
Returns:
Tokenizer: Initialized HuggingFace tokenizer instance
Raises:
Exception: If tokenizer initialization fails
"""
try:
architectures = self.fd_config.model_config.architectures
if "Ernie4_5_MoeForCausalLM" not in architectures \
and "Ernie4_5_ForCausalLM" not in architectures:
from transformers import AutoTokenizer, PreTrainedTokenizerFast
tokenizer = AutoTokenizer.from_pretrained(
self.fd_config.parallel_config.model_name_or_path,
use_fast=False,
)
if not isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer = PreTrainedTokenizerFast(
__slow_tokenizer=tokenizer)
else:
from fastdeploy.model_executor.guided_decoding.ernie_tokenizer import \
ErnieBotTokenizer
vocab_file_names = [
"tokenizer.model", "spm.model", "ernie_token_100k.model"
]
for i in range(len(vocab_file_names)):
if os.path.exists(
os.path.join(
self.fd_config.parallel_config.
model_name_or_path, vocab_file_names[i])):
ErnieBotTokenizer.vocab_files_names[
"vocab_file"] = vocab_file_names[i]
break
tokenizer = ErnieBotTokenizer.from_pretrained(
self.fd_config.parallel_config.model_name_or_path)
return tokenizer
except Exception as e:
raise Exception(f"Fail to initialize hf tokenizer: {e}")
def add_cache(self, schemata_key: tuple[str, str],
processor: LogitsProcessorBase) -> None:
"""
Add a logits processor to the cache (skipped once the cache is full).
Args:
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
processor (LogitsProcessorBase): Logits processor instance to cache
Returns:
None: No return value
"""
if len(self.cache) >= self.max_cache_size:
return
self.cache[schemata_key] = processor.copy()
class BaseChecker:
"""
Abstract base class for schema checkers.
This class defines the interface for validating and formatting schemas
before they are used by logits processors. Subclasses should implement
schema-specific validation and formatting logic.
Attributes:
None (all state should be managed by subclasses)
"""
def __init__(self):
pass
def schema_format(self, request: Request):
"""
Format the request's schema into the backend-specific format.
Args:
request (Request): request object.
Returns:
request (Request): request object with formatted schema.
"""
raise NotImplementedError()
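The caching contract in BackendBase.get_logits_processor returns either a ready-to-use processor copy (cache hit) or a concurrent.futures.Future scheduled on the executor (cache miss). A hedged sketch of how a caller might consume that contract; resolve_processor is a hypothetical helper, not part of this file.

def resolve_processor(backend, schemata_key):
    # Cache hit: get_logits_processor returns (processor_copy, True).
    # Cache miss: it returns (Future, False); the Future resolves to the new
    # processor, or None if the schema type was unsupported.
    value, cached = backend.get_logits_processor(schemata_key)
    if cached:
        return value
    processor = value.result()  # block until _init_logits_processor finishes
    if processor is not None:
        backend.add_cache(schemata_key, processor)  # store a copy for reuse
    return processor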

View File

@@ -0,0 +1,266 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
VOCAB_FILES_NAMES = {"vocab_file": "spm.model"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {},
"tokenizer_file": {},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
class ErnieBotTokenizer(PreTrainedTokenizer):
"""
Construct an ErnieBot tokenizer, based on byte-level Byte-Pair-Encoding.
Args:
vocab_file (`str`):
Path to the vocabulary file.
"""
vocab_files_names = VOCAB_FILES_NAMES
resource_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
unk_token="<unk>",
bos_token="<s>",
eos_token="</s>",
pad_token="<pad>",
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
clean_up_tokenization_spaces=False,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
bos_token = AddedToken(bos_token,
lstrip=False, rstrip=False) if isinstance(
bos_token, str) else bos_token
eos_token = AddedToken(eos_token,
lstrip=False, rstrip=False) if isinstance(
eos_token, str) else eos_token
unk_token = AddedToken(unk_token,
lstrip=False, rstrip=False) if isinstance(
unk_token, str) else unk_token
pad_token = AddedToken(pad_token,
lstrip=False, rstrip=False) if isinstance(
pad_token, str) else pad_token
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
# for eb35 reader
self.bos_id = self.bos_token_id
self.eos_id = self.eos_token_id
self.sep_id = self.sep_token_id
self.pad_id = self.pad_token_id
self.unk_id = self.unk_token_id
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state
def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)
@property
def vocab_size(self):
"""Returns vocab size"""
return self.sp_model.get_piece_size()
def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {
self.convert_ids_to_tokens(i): i
for i in range(self.vocab_size)
}
vocab.update(self.added_tokens_encoder)
return vocab
def tokenize(self, text):
"""Returns a tokenized string."""
return self._tokenize(text)
def _tokenize(self, text):
"""Returns a tokenized string."""
return self.sp_model.encode(text, out_type=str)
def decode(self,
tokens,
skip_special_tokens=False,
clean_up_tokenization_spaces=False):
"""Returns a tokenized string."""
return self.sp_model.decode(tokens)
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ""
prev_is_special = False
for i, token in enumerate(tokens):
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special and i != 0:
out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
current_sub_tokens = []
else:
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
return out_string
def save_vocabulary(self,
save_directory,
filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
return
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "") +
VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(
out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file, )
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
"""
Build model inputs from a sequence or a pair of sequences by adding the BOS/EOS special tokens.
"""
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id +
bos_token_id + ([0] * len(token_ids_1)) + eos_token_id)
def create_token_type_ids_from_sequences(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
sequence pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
if token_ids_1 is None, only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of ids.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
if token_ids_1 is not None:
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
return output
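A short usage sketch for the tokenizer above. The local checkpoint directory "./ernie_ckpt" is hypothetical; it only needs to contain one of the recognised SentencePiece files (e.g. spm.model) for from_pretrained to resolve the vocab_file.

from fastdeploy.model_executor.guided_decoding.ernie_tokenizer import \
    ErnieBotTokenizer

tokenizer = ErnieBotTokenizer.from_pretrained("./ernie_ckpt")
ids = tokenizer("guided decoding")["input_ids"]  # BOS id is prepended when add_bos_token=True
print(tokenizer.decode(ids))                     # delegates to sp_model.decode
print(tokenizer.vocab_size)                      # size reported by the SentencePiece model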

View File

@@ -0,0 +1,457 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import re
from typing import Any, List, Optional
import paddle
import torch
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
BackendBase, BaseChecker, LogitsProcessorBase)
from fastdeploy.utils import llm_logger
try:
from xgrammar import (CompiledGrammar, Grammar, GrammarCompiler,
GrammarMatcher, StructuralTagItem, TokenizerInfo,
allocate_token_bitmask, apply_token_bitmask_inplace)
except Exception as e:
raise Exception(
f"import XGrammar failed, please check your environment:\n\t {e}")
class XGrammarProcessor(LogitsProcessorBase):
"""
XGrammar-specific implementation of LogitsProcessorBase.
This processor enforces grammar constraints during token generation using XGrammar.
It manages the grammar matching state and applies token masks to logits.
Attributes:
max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch
vocab_size (int): Size of the vocabulary
batch_size (int): Batch size for processing
splitwise_role (str): Role for splitwise processing
compiled_grammar (CompiledGrammar): Compiled grammar rules
terminate_without_stop_token (bool): Whether to terminate without stop token
override_stop_tokens (Optional[List[int]]): Custom stop tokens
matcher (GrammarMatcher): Grammar matching engine
"""
def __init__(
self,
compiled_grammar: CompiledGrammar,
terminate_without_stop_token: bool = False,
override_stop_tokens: Optional[List[int]] = None,
vocab_size: Optional[int] = None,
batch_size: Optional[int] = None,
splitwise_role: str = "mixed",
):
super().__init__()
self.max_rollback_tokens = 200
self.vocab_size = vocab_size
self.batch_size = batch_size
self.splitwise_role = splitwise_role
self.compiled_grammar = compiled_grammar
self.terminate_without_stop_token = terminate_without_stop_token
self.override_stop_tokens = override_stop_tokens
self.matcher = GrammarMatcher(
compiled_grammar=compiled_grammar,
max_rollback_tokens=self.max_rollback_tokens,
terminate_without_stop_token=terminate_without_stop_token,
override_stop_tokens=override_stop_tokens,
)
def allocate_token_bitmask(self) -> torch.Tensor:
"""
Allocate a token bitmask tensor for grammar constraints.
Returns:
torch.Tensor: A tensor of shape (batch_size, vocab_size) initialized to 0
"""
return allocate_token_bitmask(self.batch_size, self.vocab_size)
def fill_token_bitmask(self, token_bitmask: torch.Tensor,
idx: int) -> None:
"""
Fill the token bitmask with allowed tokens for the given index.
Args:
token_bitmask (torch.Tensor): The token bitmask tensor to fill
idx (int): The batch index to fill the mask for
Returns:
None: Modifies the token_bitmask in-place
"""
self.matcher.fill_next_token_bitmask(token_bitmask, idx)
def apply_token_mask(
self,
logits: paddle.Tensor,
token_bitmask: torch.Tensor,
indices: Optional[List[int]] = None,
) -> paddle.Tensor:
"""
Apply the token mask to the logits, modifying probabilities of invalid tokens.
Args:
logits (paddle.Tensor): The logits tensor to modify
token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
indices (Optional[List[int]]): Optional list of batch indices to apply mask to
Returns:
paddle.Tensor: The modified logits tensor
"""
origin_place = logits.place
origin_dtype = logits.dtype
logits = torch.from_numpy(logits.numpy())
logits = logits.float() # cpu
apply_token_bitmask_inplace(
logits=logits,
bitmask=token_bitmask.to(logits.device, non_blocking=True),
indices=indices,
)
return paddle.to_tensor(
logits.numpy(),
dtype=origin_dtype,
place=origin_place,
)
def reset(self) -> None:
"""
Reset the grammar matcher state to initial conditions.
Returns:
None: No return value
"""
self.matcher.reset()
def accept_token(self, token: int) -> None:
"""
Validate and accept a generated token against the grammar constraints.
Args:
token (int): The token ID to validate
Raises:
AssertionError: If token is not allowed by the grammar
"""
assert self.matcher.accept_token(
token), f"Failed to accept token {token}"
def is_terminated(self) -> bool:
"""
Check if the grammar matching process has terminated.
Returns:
bool: True if matching has terminated, False otherwise
"""
return self.matcher.is_terminated()
def copy(self) -> "XGrammarProcessor":
"""
Create a deep copy of this processor instance.
Returns:
XGrammarProcessor: A new processor instance with identical state
"""
return XGrammarProcessor(
compiled_grammar=self.compiled_grammar,
terminate_without_stop_token=self.terminate_without_stop_token,
override_stop_tokens=self.override_stop_tokens,
vocab_size=self.vocab_size,
batch_size=self.batch_size,
splitwise_role=self.splitwise_role,
)
class XGrammarBackend(BackendBase):
"""
XGrammar-specific implementation of BackendBase.
This backend handles compilation of various schema types (JSON, regex, grammar)
into XGrammar processors. It manages the grammar compiler and tokenizer info.
Attributes:
vocab_size (int): Size of the vocabulary from config
batch_size (int): Maximum batch size from config
any_whitespace (bool): Whether to allow any whitespace in JSON
splitwise_role (str): Role for splitwise processing
grammar_compiler (GrammarCompiler): Grammar compilation engine
"""
def __init__(
self,
fd_config: FDConfig,
**kwargs,
):
super().__init__(fd_config=fd_config)
self.vocab_size = fd_config.model_config.vocab_size
self.batch_size = fd_config.parallel_config.max_num_seqs
self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
self.splitwise_role = fd_config.parallel_config.splitwise_role
try:
tokenizer_info = TokenizerInfo.from_huggingface(
self.hf_tokenizer, vocab_size=self.vocab_size)
self.grammar_compiler = GrammarCompiler(
tokenizer_info=tokenizer_info)
except Exception as e:
raise Exception(f"Failed to load XGrammar tokenizer: {e}")
def _create_processor(
self,
compiled_grammar: CompiledGrammar,
terminate_without_stop_token: bool = False,
override_stop_tokens: Optional[List[int]] = None,
) -> XGrammarProcessor:
"""
Create a logits processor instance for the given compiled grammar.
Args:
compiled_grammar (CompiledGrammar): Compiled grammar rules
terminate_without_stop_token (bool): Whether to terminate without stop token
override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults
Returns:
XGrammarProcessor: Configured grammar processor instance
"""
return XGrammarProcessor(
compiled_grammar=compiled_grammar,
terminate_without_stop_token=terminate_without_stop_token,
override_stop_tokens=override_stop_tokens,
vocab_size=self.vocab_size,
batch_size=self.batch_size,
splitwise_role=self.splitwise_role,
)
def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
"""
Compile JSON schema into a grammar processor.
Args:
schemata (str): JSON schema string to compile
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
"""
try:
compiled_grammar = self.grammar_compiler.compile_json_schema(
schemata, any_whitespace=self.any_whitespace)
except Exception as e:
llm_logger.error(f"Failed to compile json schema: {e}")
return None
return self._create_processor(compiled_grammar)
def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
"""
Compile regex pattern into a grammar processor.
Args:
schemata (str): Regex pattern string to compile
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
"""
try:
compiled_grammar = self.grammar_compiler.compile_regex(schemata)
except Exception as e:
llm_logger.error(f"Failed to compile regex schema: {e}")
return None
return self._create_processor(compiled_grammar)
def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
"""
Compile grammar (EBNF) into a grammar processor.
Args:
schemata (str): Grammar string in EBNF format
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
"""
try:
compiled_grammar = self.grammar_compiler.compile_grammar(schemata)
except Exception as e:
llm_logger.error(f"Failed to compile ebnf schema: {e}")
return None
return self._create_processor(compiled_grammar)
def _structural_tag_processor(
self, schemata: str) -> Optional[XGrammarProcessor]:
"""
Compile structural tags into a grammar processor.
Args:
schemata (str): JSON string containing structural tag definitions
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
"""
try:
structural_tag = json.loads(schemata)
tags = [
StructuralTagItem(
begin=structure["begin"],
schema=json.dumps(structure["schema"]),
end=structure["end"],
) for structure in structural_tag["structures"]
]
compiled_grammar = self.grammar_compiler.compile_structural_tag(
tags, structural_tag["triggers"])
except Exception as e:
llm_logger.error(f"Failed to compile structural tags schema: {e}")
return None
return self._create_processor(compiled_grammar)
class XGrammarChecker(BaseChecker):
"""
XGrammar-specific implementation of BaseChecker.
This validator checks and formats various schema types (JSON, regex, grammar)
for compatibility with XGrammar before processing.
Attributes:
any_whitespace (bool): Whether to allow any whitespace in JSON
"""
def __init__(self, **kwargs):
super().__init__()
self.any_whitespace = not kwargs.get("disable_any_whitespace", True)
def _unsupported_json_schema(self, schema: dict[str, Any]) -> bool:
"""
Check if JSON schema contains unsupported features.
Args:
schema (dict[str, Any]): JSON schema to validate
Returns:
bool: True if schema contains unsupported features, False otherwise
"""
def check_object(obj: dict[str, Any]) -> bool:
if not isinstance(obj, dict):
return False
if obj.get("type") in ("integer", "number") and ("multipleOf"
in obj):
return True
if obj.get("type") == "array" and any(
key in obj for key in ("uniqueItems", "contains",
"minContains", "maxContains")):
return True
if obj.get("type") == "string" and "format" in obj:
return True
if obj.get("type") == "object" and any(
key in obj
for key in ("minProperties", "maxProperties",
"propertyNames", "patternProperties")):
return True
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True
return False
return check_object(schema)
def schema_format(self, request: Request):
"""
Format the request's schema into the XGrammar-specific format.
"""
if request.guided_json:
try:
if not isinstance(request.guided_json, str):
guided_json = json.dumps(request.guided_json)
else:
guided_json = request.guided_json
Grammar.from_json_schema(guided_json,
any_whitespace=self.any_whitespace)
except RuntimeError as e:
err_msg = f"Invalid JSON format: {guided_json}, error message: {str(e)}"
return request, err_msg
if self._unsupported_json_schema(guided_json):
err_msg = f"unsupported JSON schema: {guided_json}"
return request, err_msg
request.guided_json = guided_json
return request, None
elif request.guided_grammar:
# TODO: XGrammar only supports GBNF grammars, convert Lark to GBNF
guided_grammar = request.guided_grammar
try:
Grammar.from_ebnf(guided_grammar)
except RuntimeError as e:
err_msg = f"Invalid grammar format: {guided_grammar}, error message: {str(e)}"
return request, err_msg
request.guided_grammar = guided_grammar
return request, None
elif request.guided_json_object:
request.guided_json = '{"type": "object"}'
return request, None
elif request.guided_choice:
try:
escaped_choices = (re.sub(r'(["\\])', r'\\\1', c)
for c in request.guided_choice)
guided_choice = ('root ::= ' +
' | '.join(f'"{c}"' for c in escaped_choices))
Grammar.from_ebnf(guided_choice)
except RuntimeError as e:
err_msg = f"Invalid choice format: {guided_choice}, error message: {str(e)}"
return request, err_msg
request.guided_grammar = guided_choice
return request, None
elif request.structural_tag:
try:
structural_tag = json.loads(request.structural_tag)
tags = [
StructuralTagItem(
begin=s["begin"],
schema=json.dumps(s["schema"]),
end=s["end"],
) for s in structural_tag["structures"]
]
Grammar.from_structural_tag(tags, structural_tag["triggers"])
except RuntimeError as e:
err_msg = f"Invalid structural_tag format: {structural_tag}, error message: {str(e)}"
return request, err_msg
return request, None
else:
# regex schemas are passed through without validation or reformatting
return request, None
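Tying the checker, backend and processor together, a hedged end-to-end sketch for the guided_json path. It is illustrative only: fd_config, request and the logits tensor are assumed to be constructed elsewhere, error handling is reduced to a ValueError, and batch row 0 is a placeholder index.

import paddle

def guided_step(backend, checker, request, logits):
    # 1) Validate and normalize the schema carried by the request.
    request, err = checker.schema_format(request)
    if err is not None:
        raise ValueError(err)
    # 2) Compile (or fetch from cache) a processor for the JSON schema.
    value, cached = backend.get_logits_processor(("json", request.guided_json))
    processor = value if cached else value.result()
    # 3) Mask the logits of batch row 0, then accept the sampled token.
    bitmask = processor.allocate_token_bitmask()   # shape: (batch_size, vocab_size)
    processor.fill_token_bitmask(bitmask, 0)
    logits = processor.apply_token_mask(logits, bitmask)
    next_token = int(paddle.argmax(logits[0]))
    processor.accept_token(next_token)
    return logits, next_token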