mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
73
fastdeploy/model_executor/guided_decoding/__init__.py
Normal file
73
fastdeploy/model_executor/guided_decoding/__init__.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
# from fastdeploy.config import FDConfig
|
||||
|
||||
__all__ = ['get_guided_backend', 'schema_checker']
|
||||
|
||||
|
||||
def get_guided_backend(
|
||||
fd_config,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Get the guided decoding backend instance based on configuration.
|
||||
|
||||
Args:
|
||||
fd_config (FDConfig): FastDeploy configuration object containing backend settings
|
||||
**kwargs: Additional arguments passed to the backend constructor
|
||||
|
||||
Returns:
|
||||
BaseBackend: An instance of the specified guided decoding backend
|
||||
|
||||
Raises:
|
||||
ValueError: If the specified backend is not supported
|
||||
"""
|
||||
if fd_config.parallel_config.guided_decoding_backend.lower() == "xgrammar":
|
||||
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import \
|
||||
XGrammarBackend
|
||||
return XGrammarBackend(
|
||||
fd_config=fd_config,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Get unsupported backend {fd_config.parallel_config.guided_decoding_backend},"
|
||||
f" please check your configuration.")
|
||||
|
||||
|
||||
def schema_checker(backend_name: str, **kwargs):
|
||||
"""
|
||||
Get the schema checker instance for the specified backend.
|
||||
|
||||
Args:
|
||||
backend_name (str): Name of the backend (e.g. "xgrammar")
|
||||
**kwargs: Additional arguments passed to the checker constructor
|
||||
|
||||
Returns:
|
||||
BaseChecker: An instance of the specified schema checker
|
||||
|
||||
Raises:
|
||||
ValueError: If the specified backend is not supported
|
||||
"""
|
||||
if backend_name.lower() == "xgrammar":
|
||||
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import \
|
||||
XGrammarChecker
|
||||
return XGrammarChecker(**kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Get unsupported backend {backend_name}, please check your configuration."
|
||||
)
|
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
|
||||
class LogitsProcessorBase:
|
||||
"""
|
||||
Abstract base class for logits processors in guided decoding.
|
||||
|
||||
This class defines the interface for logits processors that modify token probabilities
|
||||
during generation to enforce schema constraints. Subclasses should implement all
|
||||
abstract methods to provide specific constraint enforcement logic.
|
||||
|
||||
Attributes:
|
||||
None (all state should be managed by subclasses)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def fill_token_bitmask(self, token_bitmask, idx):
|
||||
"""
|
||||
Fill the vocabulary mask.
|
||||
|
||||
Args:
|
||||
token_bitmask (tensor): The vocabulary mask tensor.
|
||||
idx (tensor): The tensor index.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_token_mask(self, logits, token_bitmask):
|
||||
"""
|
||||
Apply the vocabulary mask to logits.
|
||||
|
||||
Args:
|
||||
logits (tensor): The logits tensor.
|
||||
token_bitmask (tensor): The vocabulary mask tensor.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def allocate_token_bitmask(self, batch_size, vocab_size):
|
||||
"""
|
||||
Allocate a token bitmask for the given batch size and vocabulary size.
|
||||
|
||||
Args:
|
||||
batch_size (int): The batch size.
|
||||
vocab_size (int): The vocabulary size.
|
||||
|
||||
Returns:
|
||||
tensor: The allocated token bitmask.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def accept_token(self, token):
|
||||
"""
|
||||
Accept tokens based on the token bitmask
|
||||
|
||||
Args:
|
||||
token (int): The token id.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_terminated(self):
|
||||
"""
|
||||
Check if the processor has been terminated.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the matcher state.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def copy(self):
|
||||
"""
|
||||
Create a copy of the backend instance.
|
||||
|
||||
Returns:
|
||||
BackendBase: A copy of the backend instance.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class BackendBase:
|
||||
"""
|
||||
Abstract base class for guided decoding backends.
|
||||
|
||||
This class provides the core infrastructure for managing schema processors and
|
||||
their caching. It handles:
|
||||
- Processor creation and caching
|
||||
- Tokenizer initialization
|
||||
- Thread pool management for async operations
|
||||
|
||||
Attributes:
|
||||
cache (dict): Cache of schema processors
|
||||
fd_config (FDConfig): FastDeploy configuration
|
||||
executor (ThreadPoolExecutor): Thread pool for async operations
|
||||
max_cache_size (int): Maximum number of processors to cache
|
||||
hf_tokenizer: HuggingFace tokenizer instance
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
self.cache = {}
|
||||
self.fd_config = fd_config
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.max_cache_size = 2048
|
||||
|
||||
self.hf_tokenizer = self._get_tokenizer_hf()
|
||||
|
||||
def _create_processor(self):
|
||||
"""
|
||||
Create a specific logits processor instance.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _json_processor(self, schemata):
|
||||
"""
|
||||
Process JSON schemata.
|
||||
|
||||
Args:
|
||||
schemata (str): The schemata string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _regex_processor(self, schemata):
|
||||
"""
|
||||
Process regular expression schemata.
|
||||
|
||||
Args:
|
||||
schemata (str): The schemata string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _grammar_processor(self, schemata):
|
||||
"""
|
||||
Process grammar schemata.
|
||||
|
||||
Args:
|
||||
schemata (str): The schemata string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _structural_tag_processor(self, schemata):
|
||||
"""
|
||||
Process structural tag schemata.
|
||||
|
||||
Args:
|
||||
schemata (str): The schemata string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _unsupported_processor_type(self, key_type, schemata):
|
||||
"""
|
||||
Process unsupported type.
|
||||
|
||||
Args:
|
||||
key_type (str): The key type string.
|
||||
schemata (str): The schemata string.
|
||||
"""
|
||||
raise Exception(f"Unsupported processor type {key_type}.")
|
||||
|
||||
def _init_logits_processor(
|
||||
self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
|
||||
"""
|
||||
init logits processor by type and schemata.
|
||||
|
||||
Args:
|
||||
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
|
||||
|
||||
Returns:
|
||||
LogitsProcessorBase: Initialized logits processor instance
|
||||
|
||||
Raises:
|
||||
ValueError: If processor type is not supported
|
||||
"""
|
||||
key_type, schemata = schemata_key
|
||||
if key_type == "json":
|
||||
return self._json_processor(schemata)
|
||||
elif key_type == "regex":
|
||||
return self._regex_processor(schemata)
|
||||
elif key_type == "grammar":
|
||||
return self._grammar_processor(schemata)
|
||||
elif key_type == "structural_tag":
|
||||
return self._structural_tag_processor(schemata)
|
||||
else:
|
||||
llm_logger.error(f"Unsupported processor type {key_type}.")
|
||||
return None
|
||||
|
||||
def get_logits_processor(
|
||||
self,
|
||||
schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
|
||||
"""
|
||||
get logits processor by key from cache or create new one.
|
||||
|
||||
Args:
|
||||
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
|
||||
|
||||
Returns:
|
||||
tuple[LogitsProcessorBase, bool]: Tuple containing:
|
||||
- LogitsProcessorBase: The logits processor instance
|
||||
- bool: True if processor was from cache, False if newly created
|
||||
"""
|
||||
value = self.cache.get(schemata_key, None)
|
||||
if value:
|
||||
return value.copy(), True
|
||||
value = self.executor.submit(self._init_logits_processor, schemata_key)
|
||||
return value, False
|
||||
|
||||
def _get_tokenizer_hf(self):
|
||||
"""
|
||||
Initialize and return a HuggingFace tokenizer instance.
|
||||
|
||||
This method handles special cases for Ernie models and falls back to standard
|
||||
AutoTokenizer for other models. It also ensures fast tokenizer is used when possible.
|
||||
|
||||
Returns:
|
||||
Tokenizer: Initialized HuggingFace tokenizer instance
|
||||
|
||||
Raises:
|
||||
Exception: If tokenizer initialization fails
|
||||
"""
|
||||
try:
|
||||
architectures = self.fd_config.model_config.architectures
|
||||
if "Ernie4_5_MoeForCausalLM" not in architectures \
|
||||
and "Ernie4_5_ForCausalLM" not in architectures:
|
||||
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerFast
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.fd_config.parallel_config.model_name_or_path,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
__slow_tokenizer=tokenizer)
|
||||
else:
|
||||
from fastdeploy.model_executor.guided_decoding.ernie_tokenizer import \
|
||||
ErnieBotTokenizer
|
||||
|
||||
vocab_file_names = [
|
||||
"tokenizer.model", "spm.model", "ernie_token_100k.model"
|
||||
]
|
||||
for i in range(len(vocab_file_names)):
|
||||
if os.path.exists(
|
||||
os.path.join(
|
||||
self.fd_config.parallel_config.
|
||||
model_name_or_path, vocab_file_names[i])):
|
||||
ErnieBotTokenizer.vocab_files_names[
|
||||
"vocab_file"] = vocab_file_names[i]
|
||||
break
|
||||
|
||||
tokenizer = ErnieBotTokenizer.from_pretrained(
|
||||
self.fd_config.parallel_config.model_name_or_path)
|
||||
|
||||
return tokenizer
|
||||
except Exception as e:
|
||||
raise Exception(f"Fail to initialize hf tokenizer: {e}")
|
||||
|
||||
def add_cache(self, schemata_key: tuple[str, str],
|
||||
processor: LogitsProcessorBase) -> None:
|
||||
"""
|
||||
add logits processor to cache.
|
||||
|
||||
Args:
|
||||
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
|
||||
processor (LogitsProcessorBase): Logits processor instance to cache
|
||||
|
||||
Returns:
|
||||
None: No return value
|
||||
"""
|
||||
if len(self.cache) >= self.max_cache_size:
|
||||
return
|
||||
self.cache[schemata_key] = processor.copy()
|
||||
|
||||
|
||||
class BaseChecker:
|
||||
"""
|
||||
Abstract base class for schema checkers.
|
||||
|
||||
This class defines the interface for validating and formatting schemas
|
||||
before they are used by logits processors. Subclasses should implement
|
||||
schema-specific validation and formatting logic.
|
||||
|
||||
Attributes:
|
||||
None (all state should be managed by subclasses)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def schema_format(self, request: Request):
|
||||
"""
|
||||
format schema to backend specific format.
|
||||
Args:
|
||||
request (Request): request object.
|
||||
|
||||
Returns:
|
||||
request (Request): request object with formatted schema.
|
||||
"""
|
||||
raise NotImplementedError()
|
266
fastdeploy/model_executor/guided_decoding/ernie_tokenizer.py
Normal file
266
fastdeploy/model_executor/guided_decoding/ernie_tokenizer.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import sentencepiece as spm
|
||||
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "spm.model"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {},
|
||||
"tokenizer_file": {},
|
||||
}
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
|
||||
|
||||
|
||||
class ErnieBotTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a ErnieBot tokenizer. Based on byte-level Byte-Pair-Encoding.
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
resource_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
unk_token="<unk>",
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
pad_token="<pad>",
|
||||
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
add_bos_token=True,
|
||||
add_eos_token=False,
|
||||
clean_up_tokenization_spaces=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
||||
self.vocab_file = vocab_file
|
||||
self.add_bos_token = add_bos_token
|
||||
self.add_eos_token = add_eos_token
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(vocab_file)
|
||||
|
||||
bos_token = AddedToken(bos_token,
|
||||
lstrip=False, rstrip=False) if isinstance(
|
||||
bos_token, str) else bos_token
|
||||
eos_token = AddedToken(eos_token,
|
||||
lstrip=False, rstrip=False) if isinstance(
|
||||
eos_token, str) else eos_token
|
||||
unk_token = AddedToken(unk_token,
|
||||
lstrip=False, rstrip=False) if isinstance(
|
||||
unk_token, str) else unk_token
|
||||
pad_token = AddedToken(pad_token,
|
||||
lstrip=False, rstrip=False) if isinstance(
|
||||
pad_token, str) else pad_token
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
add_bos_token=add_bos_token,
|
||||
add_eos_token=add_eos_token,
|
||||
sp_model_kwargs=self.sp_model_kwargs,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# for eb35 reader
|
||||
self.bos_id = self.bos_token_id
|
||||
self.eos_id = self.eos_token_id
|
||||
self.sep_id = self.sep_token_id
|
||||
self.pad_id = self.pad_token_id
|
||||
self.unk_id = self.unk_token_id
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["sp_model"] = None
|
||||
return state
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.__dict__ = d
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(self.vocab_file)
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
"""Returns vocab size"""
|
||||
return self.sp_model.get_piece_size()
|
||||
|
||||
def get_vocab(self):
|
||||
"""Returns vocab as a dict"""
|
||||
vocab = {
|
||||
self.convert_ids_to_tokens(i): i
|
||||
for i in range(self.vocab_size)
|
||||
}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Returns a tokenized string."""
|
||||
return self._tokenize(text)
|
||||
|
||||
def _tokenize(self, text):
|
||||
"""Returns a tokenized string."""
|
||||
return self.sp_model.encode(text, out_type=str)
|
||||
|
||||
def decode(self,
|
||||
tokens,
|
||||
skip_special_tokens=False,
|
||||
clean_up_tokenization_spaces=False):
|
||||
"""Returns a tokenized string."""
|
||||
return self.sp_model.decode(tokens)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.sp_model.piece_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
token = self.sp_model.IdToPiece(index)
|
||||
return token
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
current_sub_tokens = []
|
||||
out_string = ""
|
||||
prev_is_special = False
|
||||
for i, token in enumerate(tokens):
|
||||
# make sure that special tokens are not decoded using sentencepiece model
|
||||
if token in self.all_special_tokens:
|
||||
if not prev_is_special and i != 0:
|
||||
out_string += " "
|
||||
out_string += self.sp_model.decode(current_sub_tokens) + token
|
||||
prev_is_special = True
|
||||
current_sub_tokens = []
|
||||
else:
|
||||
current_sub_tokens.append(token)
|
||||
prev_is_special = False
|
||||
out_string += self.sp_model.decode(current_sub_tokens)
|
||||
return out_string
|
||||
|
||||
def save_vocabulary(self,
|
||||
save_directory,
|
||||
filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
return
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "") +
|
||||
VOCAB_FILES_NAMES["vocab_file"])
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(
|
||||
out_vocab_file) and os.path.isfile(self.vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
elif not os.path.isfile(self.vocab_file):
|
||||
with open(out_vocab_file, "wb") as fi:
|
||||
content_spiece_model = self.sp_model.serialized_model_proto()
|
||||
fi.write(content_spiece_model)
|
||||
|
||||
return (out_vocab_file, )
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
"""
|
||||
build inputs with special tokens
|
||||
"""
|
||||
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
||||
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
||||
|
||||
output = bos_token_id + token_ids_0 + eos_token_id
|
||||
|
||||
if token_ids_1 is not None:
|
||||
output = output + bos_token_id + token_ids_1 + eos_token_id
|
||||
|
||||
return output
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self,
|
||||
token_ids_0: List[int],
|
||||
token_ids_1: Optional[List[int]] = None,
|
||||
already_has_special_tokens: bool = False) -> List[int]:
|
||||
"""
|
||||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer `prepare_for_model` method.
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the token list is already formatted with special tokens for the model.
|
||||
Returns:
|
||||
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0=token_ids_0,
|
||||
token_ids_1=token_ids_1,
|
||||
already_has_special_tokens=True)
|
||||
|
||||
bos_token_id = [1] if self.add_bos_token else []
|
||||
eos_token_id = [1] if self.add_eos_token else []
|
||||
|
||||
if token_ids_1 is None:
|
||||
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
||||
return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id +
|
||||
bos_token_id + ([0] * len(token_ids_1)) + eos_token_id)
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self,
|
||||
token_ids_0: List[int],
|
||||
token_ids_1: Optional[List[int]] = None) -> List[int]:
|
||||
"""
|
||||
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
||||
sequence pair mask has the following format:
|
||||
```
|
||||
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
||||
| first sequence | second sequence |
|
||||
```
|
||||
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of ids.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
Returns:
|
||||
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
||||
"""
|
||||
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
||||
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
||||
|
||||
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
||||
|
||||
if token_ids_1 is not None:
|
||||
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
||||
|
||||
return output
|
457
fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
Normal file
457
fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import paddle
|
||||
import torch
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
|
||||
BackendBase, BaseChecker, LogitsProcessorBase)
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
try:
|
||||
from xgrammar import (CompiledGrammar, Grammar, GrammarCompiler,
|
||||
GrammarMatcher, StructuralTagItem, TokenizerInfo,
|
||||
allocate_token_bitmask, apply_token_bitmask_inplace)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"import XGrammar failed, please check your environment:\n\t {e}")
|
||||
|
||||
|
||||
class XGrammarProcessor(LogitsProcessorBase):
|
||||
"""
|
||||
XGrammar-specific implementation of LogitsProcessorBase.
|
||||
|
||||
This processor enforces grammar constraints during token generation using XGrammar.
|
||||
It manages the grammar matching state and applies token masks to logits.
|
||||
|
||||
Attributes:
|
||||
max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch
|
||||
vocab_size (int): Size of the vocabulary
|
||||
batch_size (int): Batch size for processing
|
||||
splitwise_role (str): Role for splitwise processing
|
||||
compiled_grammar (CompiledGrammar): Compiled grammar rules
|
||||
terminate_without_stop_token (bool): Whether to terminate without stop token
|
||||
override_stop_tokens (Optional[List[int]]): Custom stop tokens
|
||||
matcher (GrammarMatcher): Grammar matching engine
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
compiled_grammar: CompiledGrammar,
|
||||
terminate_without_stop_token: bool = False,
|
||||
override_stop_tokens: Optional[List[int]] = None,
|
||||
vocab_size: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
splitwise_role: str = "mixed",
|
||||
):
|
||||
super().__init__()
|
||||
self.max_rollback_tokens = 200
|
||||
self.vocab_size = vocab_size
|
||||
self.batch_size = batch_size
|
||||
self.splitwise_role = splitwise_role
|
||||
self.compiled_grammar = compiled_grammar
|
||||
self.terminate_without_stop_token = terminate_without_stop_token
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
|
||||
self.matcher = GrammarMatcher(
|
||||
compiled_grammar=compiled_grammar,
|
||||
max_rollback_tokens=self.max_rollback_tokens,
|
||||
terminate_without_stop_token=terminate_without_stop_token,
|
||||
override_stop_tokens=override_stop_tokens,
|
||||
)
|
||||
|
||||
def allocate_token_bitmask(self) -> torch.Tensor:
|
||||
"""
|
||||
Allocate a token bitmask tensor for grammar constraints.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor of shape (batch_size, vocab_size) initialized to 0
|
||||
"""
|
||||
return allocate_token_bitmask(self.batch_size, self.vocab_size)
|
||||
|
||||
def fill_token_bitmask(self, token_bitmask: torch.Tensor,
|
||||
idx: int) -> None:
|
||||
"""
|
||||
Fill the token bitmask with allowed tokens for the given index.
|
||||
|
||||
Args:
|
||||
token_bitmask (torch.Tensor): The token bitmask tensor to fill
|
||||
idx (int): The batch index to fill the mask for
|
||||
|
||||
Returns:
|
||||
None: Modifies the token_bitmask in-place
|
||||
"""
|
||||
self.matcher.fill_next_token_bitmask(token_bitmask, idx)
|
||||
|
||||
def apply_token_mask(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
token_bitmask: torch.Tensor,
|
||||
indices: Optional[List[int]] = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the token mask to the logits, modifying probabilities of invalid tokens.
|
||||
|
||||
Args:
|
||||
logits (paddle.Tensor): The logits tensor to modify
|
||||
token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
|
||||
indices (Optional[List[int]]): Optional list of batch indices to apply mask to
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: The modified logits tensor
|
||||
"""
|
||||
origin_place = logits.place
|
||||
origin_dtype = logits.dtype
|
||||
logits = torch.from_numpy(logits.numpy())
|
||||
|
||||
logits = logits.float() # cpu
|
||||
apply_token_bitmask_inplace(
|
||||
logits=logits,
|
||||
bitmask=token_bitmask.to(logits.device, non_blocking=True),
|
||||
indices=indices,
|
||||
)
|
||||
|
||||
return paddle.to_tensor(
|
||||
logits.numpy(),
|
||||
dtype=origin_dtype,
|
||||
place=origin_place,
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the grammar matcher state to initial conditions.
|
||||
|
||||
Returns:
|
||||
None: No return value
|
||||
"""
|
||||
self.matcher.reset()
|
||||
|
||||
def accept_token(self, token: int) -> None:
|
||||
"""
|
||||
Validate and accept a generated token against the grammar constraints.
|
||||
|
||||
Args:
|
||||
token (int): The token ID to validate
|
||||
|
||||
Raises:
|
||||
AssertionError: If token is not allowed by the grammar
|
||||
"""
|
||||
assert self.matcher.accept_token(
|
||||
token), f"Failed to accept token {token}"
|
||||
|
||||
def is_terminated(self) -> bool:
|
||||
"""
|
||||
Check if the grammar matching process has terminated.
|
||||
|
||||
Returns:
|
||||
bool: True if matching has terminated, False otherwise
|
||||
"""
|
||||
return self.matcher.is_terminated()
|
||||
|
||||
def copy(self) -> "XGrammarProcessor":
|
||||
"""
|
||||
Create a deep copy of this processor instance.
|
||||
|
||||
Returns:
|
||||
XGrammarProcessor: A new processor instance with identical state
|
||||
"""
|
||||
return XGrammarProcessor(
|
||||
compiled_grammar=self.compiled_grammar,
|
||||
terminate_without_stop_token=self.terminate_without_stop_token,
|
||||
override_stop_tokens=self.override_stop_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
batch_size=self.batch_size,
|
||||
splitwise_role=self.splitwise_role,
|
||||
)
|
||||
|
||||
|
||||
class XGrammarBackend(BackendBase):
|
||||
"""
|
||||
XGrammar-specific implementation of BackendBase.
|
||||
|
||||
This backend handles compilation of various schema types (JSON, regex, grammar)
|
||||
into XGrammar processors. It manages the grammar compiler and tokenizer info.
|
||||
|
||||
Attributes:
|
||||
vocab_size (int): Size of the vocabulary from config
|
||||
batch_size (int): Maximum batch size from config
|
||||
any_whitespace (bool): Whether to allow any whitespace in JSON
|
||||
splitwise_role (str): Role for splitwise processing
|
||||
grammar_compiler (GrammarCompiler): Grammar compilation engine
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(fd_config=fd_config)
|
||||
self.vocab_size = fd_config.model_config.vocab_size
|
||||
self.batch_size = fd_config.parallel_config.max_num_seqs
|
||||
|
||||
self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
|
||||
self.splitwise_role = fd_config.parallel_config.splitwise_role
|
||||
|
||||
try:
|
||||
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||
self.hf_tokenizer, vocab_size=self.vocab_size)
|
||||
self.grammar_compiler = GrammarCompiler(
|
||||
tokenizer_info=tokenizer_info)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load XGrammar tokenizer: {e}")
|
||||
|
||||
def _create_processor(
|
||||
self,
|
||||
compiled_grammar: CompiledGrammar,
|
||||
terminate_without_stop_token: bool = False,
|
||||
override_stop_tokens: Optional[List[int]] = None,
|
||||
) -> XGrammarProcessor:
|
||||
"""
|
||||
Create a logits processor instance for the given compiled grammar.
|
||||
|
||||
Args:
|
||||
compiled_grammar (CompiledGrammar): Compiled grammar rules
|
||||
terminate_without_stop_token (bool): Whether to terminate without stop token
|
||||
override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults
|
||||
|
||||
Returns:
|
||||
XGrammarProcessor: Configured grammar processor instance
|
||||
"""
|
||||
return XGrammarProcessor(
|
||||
compiled_grammar=compiled_grammar,
|
||||
terminate_without_stop_token=terminate_without_stop_token,
|
||||
override_stop_tokens=override_stop_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
batch_size=self.batch_size,
|
||||
splitwise_role=self.splitwise_role,
|
||||
)
|
||||
|
||||
def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
"""
|
||||
Compile JSON schema into a grammar processor.
|
||||
|
||||
Args:
|
||||
schemata (str): JSON schema string to compile
|
||||
|
||||
Returns:
|
||||
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
||||
"""
|
||||
try:
|
||||
compiled_grammar = self.grammar_compiler.compile_json_schema(
|
||||
schemata, any_whitespace=self.any_whitespace)
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile json schema: {e}")
|
||||
return None
|
||||
return self._create_processor(compiled_grammar)
|
||||
|
||||
def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
"""
|
||||
Compile regex pattern into a grammar processor.
|
||||
|
||||
Args:
|
||||
schemata (str): Regex pattern string to compile
|
||||
|
||||
Returns:
|
||||
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
||||
"""
|
||||
try:
|
||||
compiled_grammar = self.grammar_compiler.compile_regex(schemata)
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile regex schema: {e}")
|
||||
return None
|
||||
return self._create_processor(compiled_grammar)
|
||||
|
||||
def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
"""
|
||||
Compile grammar (EBNF) into a grammar processor.
|
||||
|
||||
Args:
|
||||
schemata (str): Grammar string in EBNF format
|
||||
|
||||
Returns:
|
||||
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
||||
"""
|
||||
try:
|
||||
compiled_grammar = self.grammar_compiler.compile_grammar(schemata)
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile ebnf schema: {e}")
|
||||
return None
|
||||
return self._create_processor(compiled_grammar)
|
||||
|
||||
def _structural_tag_processor(
|
||||
self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
"""
|
||||
Compile structural tags into a grammar processor.
|
||||
|
||||
Args:
|
||||
schemata (str): JSON string containing structural tag definitions
|
||||
|
||||
Returns:
|
||||
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
||||
"""
|
||||
try:
|
||||
structural_tag = json.loads(schemata)
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=structure["begin"],
|
||||
schema=json.dumps(structure["schema"]),
|
||||
end=structure["end"],
|
||||
) for structure in structural_tag["structures"]
|
||||
]
|
||||
|
||||
compiled_grammar = self.grammar_compiler.compile_structural_tag(
|
||||
tags, structural_tag["triggers"])
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile structural tags schema: {e}")
|
||||
return None
|
||||
return self._create_processor(compiled_grammar)
|
||||
|
||||
|
||||
class XGrammarChecker(BaseChecker):
|
||||
"""
|
||||
XGrammar-specific implementation of BaseChecker.
|
||||
|
||||
This validator checks and formats various schema types (JSON, regex, grammar)
|
||||
for compatibility with XGrammar before processing.
|
||||
|
||||
Attributes:
|
||||
any_whitespace (bool): Whether to allow any whitespace in JSON
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.any_whitespace = not kwargs.get("disable_any_whitespace", True)
|
||||
|
||||
def _unsupported_json_schema(self, schema: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if JSON schema contains unsupported features.
|
||||
|
||||
Args:
|
||||
schema (dict[str, Any]): JSON schema to validate
|
||||
|
||||
Returns:
|
||||
bool: True if schema contains unsupported features, False otherwise
|
||||
"""
|
||||
|
||||
def check_object(obj: dict[str, Any]) -> bool:
|
||||
if not isinstance(obj, dict):
|
||||
return False
|
||||
|
||||
if obj.get("type") in ("integer", "number") and ("multipleOf"
|
||||
in obj):
|
||||
return True
|
||||
|
||||
if obj.get("type") == "array" and any(
|
||||
key in obj for key in ("uniqueItems", "contains",
|
||||
"minContains", "maxContains")):
|
||||
return True
|
||||
|
||||
if obj.get("type") == "string" and "format" in obj:
|
||||
return True
|
||||
|
||||
if obj.get("type") == "object" and any(
|
||||
key in obj
|
||||
for key in ("minProperties", "maxProperties",
|
||||
"propertyNames", "patternProperties")):
|
||||
return True
|
||||
|
||||
for value in obj.values():
|
||||
if isinstance(value, dict):
|
||||
if check_object(value):
|
||||
return True
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and check_object(item):
|
||||
return True
|
||||
return False
|
||||
|
||||
return check_object(schema)
|
||||
|
||||
def schema_format(self, request: Request):
|
||||
"""
|
||||
format schema to backend specific format.
|
||||
"""
|
||||
if request.guided_json:
|
||||
try:
|
||||
if not isinstance(request.guided_json, str):
|
||||
guided_json = json.dumps(request.guided_json)
|
||||
else:
|
||||
guided_json = request.guided_json
|
||||
|
||||
Grammar.from_json_schema(guided_json,
|
||||
any_whitespace=self.any_whitespace)
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid JSON format: {guided_json}, error message: {str(e)}"
|
||||
return request, err_msg
|
||||
|
||||
if self._unsupported_json_schema(guided_json):
|
||||
err_msg = f"unsupported JSON schema: {guided_json}"
|
||||
return request, err_msg
|
||||
|
||||
request.guided_json = guided_json
|
||||
return request, None
|
||||
elif request.guided_grammar:
|
||||
# TODO: XGrammar only supports GBNF grammars, convert Lark to GBNF
|
||||
guided_grammar = request.guided_grammar
|
||||
try:
|
||||
Grammar.from_ebnf(guided_grammar)
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid grammar format: {guided_grammar}, error message: {str(e)}"
|
||||
return request, err_msg
|
||||
request.guided_grammar = guided_grammar
|
||||
return request, None
|
||||
elif request.guided_json_object:
|
||||
request.guided_json = '{"type": "object"}'
|
||||
return request, None
|
||||
elif request.guided_choice:
|
||||
try:
|
||||
escaped_choices = (re.sub(r'(["\\])', r'\\\1', c)
|
||||
for c in request.guided_choice)
|
||||
guided_choice = ('root ::= ' +
|
||||
' | '.join(f'"{c}"' for c in escaped_choices))
|
||||
|
||||
Grammar.from_ebnf(guided_choice)
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid choice format: {guided_choice}, error message: {str(e)}"
|
||||
return request, err_msg
|
||||
|
||||
request.guided_grammar = guided_choice
|
||||
return request, None
|
||||
elif request.structural_tag:
|
||||
try:
|
||||
structural_tag = json.loads(request.structural_tag)
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=s["begin"],
|
||||
schema=json.dumps(s["schema"]),
|
||||
end=s["end"],
|
||||
) for s in structural_tag["structures"]
|
||||
]
|
||||
Grammar.from_structural_tag(tags, structural_tag["triggers"])
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid structural_tag format: {structural_tag}, error message: {str(e)}"
|
||||
return request, err_msg
|
||||
return request, None
|
||||
else:
|
||||
# regex is not format
|
||||
return request, None
|
Reference in New Issue
Block a user