mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
457
fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
Normal file
457
fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import paddle
|
||||
import torch
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
|
||||
BackendBase, BaseChecker, LogitsProcessorBase)
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
try:
|
||||
from xgrammar import (CompiledGrammar, Grammar, GrammarCompiler,
|
||||
GrammarMatcher, StructuralTagItem, TokenizerInfo,
|
||||
allocate_token_bitmask, apply_token_bitmask_inplace)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"import XGrammar failed, please check your environment:\n\t {e}")
|
||||
|
||||
|
||||
class XGrammarProcessor(LogitsProcessorBase):
|
||||
"""
|
||||
XGrammar-specific implementation of LogitsProcessorBase.
|
||||
|
||||
This processor enforces grammar constraints during token generation using XGrammar.
|
||||
It manages the grammar matching state and applies token masks to logits.
|
||||
|
||||
Attributes:
|
||||
max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch
|
||||
vocab_size (int): Size of the vocabulary
|
||||
batch_size (int): Batch size for processing
|
||||
splitwise_role (str): Role for splitwise processing
|
||||
compiled_grammar (CompiledGrammar): Compiled grammar rules
|
||||
terminate_without_stop_token (bool): Whether to terminate without stop token
|
||||
override_stop_tokens (Optional[List[int]]): Custom stop tokens
|
||||
matcher (GrammarMatcher): Grammar matching engine
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
compiled_grammar: CompiledGrammar,
|
||||
terminate_without_stop_token: bool = False,
|
||||
override_stop_tokens: Optional[List[int]] = None,
|
||||
vocab_size: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
splitwise_role: str = "mixed",
|
||||
):
|
||||
super().__init__()
|
||||
self.max_rollback_tokens = 200
|
||||
self.vocab_size = vocab_size
|
||||
self.batch_size = batch_size
|
||||
self.splitwise_role = splitwise_role
|
||||
self.compiled_grammar = compiled_grammar
|
||||
self.terminate_without_stop_token = terminate_without_stop_token
|
||||
self.override_stop_tokens = override_stop_tokens
|
||||
|
||||
self.matcher = GrammarMatcher(
|
||||
compiled_grammar=compiled_grammar,
|
||||
max_rollback_tokens=self.max_rollback_tokens,
|
||||
terminate_without_stop_token=terminate_without_stop_token,
|
||||
override_stop_tokens=override_stop_tokens,
|
||||
)
|
||||
|
||||
def allocate_token_bitmask(self) -> torch.Tensor:
|
||||
"""
|
||||
Allocate a token bitmask tensor for grammar constraints.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor of shape (batch_size, vocab_size) initialized to 0
|
||||
"""
|
||||
return allocate_token_bitmask(self.batch_size, self.vocab_size)
|
||||
|
||||
def fill_token_bitmask(self, token_bitmask: torch.Tensor,
|
||||
idx: int) -> None:
|
||||
"""
|
||||
Fill the token bitmask with allowed tokens for the given index.
|
||||
|
||||
Args:
|
||||
token_bitmask (torch.Tensor): The token bitmask tensor to fill
|
||||
idx (int): The batch index to fill the mask for
|
||||
|
||||
Returns:
|
||||
None: Modifies the token_bitmask in-place
|
||||
"""
|
||||
self.matcher.fill_next_token_bitmask(token_bitmask, idx)
|
||||
|
||||
def apply_token_mask(
|
||||
self,
|
||||
logits: paddle.Tensor,
|
||||
token_bitmask: torch.Tensor,
|
||||
indices: Optional[List[int]] = None,
|
||||
) -> paddle.Tensor:
|
||||
"""
|
||||
Apply the token mask to the logits, modifying probabilities of invalid tokens.
|
||||
|
||||
Args:
|
||||
logits (paddle.Tensor): The logits tensor to modify
|
||||
token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
|
||||
indices (Optional[List[int]]): Optional list of batch indices to apply mask to
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: The modified logits tensor
|
||||
"""
|
||||
origin_place = logits.place
|
||||
origin_dtype = logits.dtype
|
||||
logits = torch.from_numpy(logits.numpy())
|
||||
|
||||
logits = logits.float() # cpu
|
||||
apply_token_bitmask_inplace(
|
||||
logits=logits,
|
||||
bitmask=token_bitmask.to(logits.device, non_blocking=True),
|
||||
indices=indices,
|
||||
)
|
||||
|
||||
return paddle.to_tensor(
|
||||
logits.numpy(),
|
||||
dtype=origin_dtype,
|
||||
place=origin_place,
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset the grammar matcher state to initial conditions.
|
||||
|
||||
Returns:
|
||||
None: No return value
|
||||
"""
|
||||
self.matcher.reset()
|
||||
|
||||
def accept_token(self, token: int) -> None:
|
||||
"""
|
||||
Validate and accept a generated token against the grammar constraints.
|
||||
|
||||
Args:
|
||||
token (int): The token ID to validate
|
||||
|
||||
Raises:
|
||||
AssertionError: If token is not allowed by the grammar
|
||||
"""
|
||||
assert self.matcher.accept_token(
|
||||
token), f"Failed to accept token {token}"
|
||||
|
||||
def is_terminated(self) -> bool:
|
||||
"""
|
||||
Check if the grammar matching process has terminated.
|
||||
|
||||
Returns:
|
||||
bool: True if matching has terminated, False otherwise
|
||||
"""
|
||||
return self.matcher.is_terminated()
|
||||
|
||||
def copy(self) -> "XGrammarProcessor":
|
||||
"""
|
||||
Create a deep copy of this processor instance.
|
||||
|
||||
Returns:
|
||||
XGrammarProcessor: A new processor instance with identical state
|
||||
"""
|
||||
return XGrammarProcessor(
|
||||
compiled_grammar=self.compiled_grammar,
|
||||
terminate_without_stop_token=self.terminate_without_stop_token,
|
||||
override_stop_tokens=self.override_stop_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
batch_size=self.batch_size,
|
||||
splitwise_role=self.splitwise_role,
|
||||
)
|
||||
|
||||
|
||||
class XGrammarBackend(BackendBase):
|
||||
"""
|
||||
XGrammar-specific implementation of BackendBase.
|
||||
|
||||
This backend handles compilation of various schema types (JSON, regex, grammar)
|
||||
into XGrammar processors. It manages the grammar compiler and tokenizer info.
|
||||
|
||||
Attributes:
|
||||
vocab_size (int): Size of the vocabulary from config
|
||||
batch_size (int): Maximum batch size from config
|
||||
any_whitespace (bool): Whether to allow any whitespace in JSON
|
||||
splitwise_role (str): Role for splitwise processing
|
||||
grammar_compiler (GrammarCompiler): Grammar compilation engine
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fd_config: FDConfig,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(fd_config=fd_config)
|
||||
self.vocab_size = fd_config.model_config.vocab_size
|
||||
self.batch_size = fd_config.parallel_config.max_num_seqs
|
||||
|
||||
self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
|
||||
self.splitwise_role = fd_config.parallel_config.splitwise_role
|
||||
|
||||
try:
|
||||
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||
self.hf_tokenizer, vocab_size=self.vocab_size)
|
||||
self.grammar_compiler = GrammarCompiler(
|
||||
tokenizer_info=tokenizer_info)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load XGrammar tokenizer: {e}")
|
||||
|
||||
def _create_processor(
|
||||
self,
|
||||
compiled_grammar: CompiledGrammar,
|
||||
terminate_without_stop_token: bool = False,
|
||||
override_stop_tokens: Optional[List[int]] = None,
|
||||
) -> XGrammarProcessor:
|
||||
"""
|
||||
Create a logits processor instance for the given compiled grammar.
|
||||
|
||||
Args:
|
||||
compiled_grammar (CompiledGrammar): Compiled grammar rules
|
||||
terminate_without_stop_token (bool): Whether to terminate without stop token
|
||||
override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults
|
||||
|
||||
Returns:
|
||||
XGrammarProcessor: Configured grammar processor instance
|
||||
"""
|
||||
return XGrammarProcessor(
|
||||
compiled_grammar=compiled_grammar,
|
||||
terminate_without_stop_token=terminate_without_stop_token,
|
||||
override_stop_tokens=override_stop_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
batch_size=self.batch_size,
|
||||
splitwise_role=self.splitwise_role,
|
||||
)
|
||||
|
||||
def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
"""
|
||||
Compile JSON schema into a grammar processor.
|
||||
|
||||
Args:
|
||||
schemata (str): JSON schema string to compile
|
||||
|
||||
Returns:
|
||||
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
||||
"""
|
||||
try:
|
||||
compiled_grammar = self.grammar_compiler.compile_json_schema(
|
||||
schemata, any_whitespace=self.any_whitespace)
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile json schema: {e}")
|
||||
return None
|
||||
return self._create_processor(compiled_grammar)
|
||||
|
||||
def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
"""
|
||||
Compile regex pattern into a grammar processor.
|
||||
|
||||
Args:
|
||||
schemata (str): Regex pattern string to compile
|
||||
|
||||
Returns:
|
||||
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
||||
"""
|
||||
try:
|
||||
compiled_grammar = self.grammar_compiler.compile_regex(schemata)
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile regex schema: {e}")
|
||||
return None
|
||||
return self._create_processor(compiled_grammar)
|
||||
|
||||
def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
"""
|
||||
Compile grammar (EBNF) into a grammar processor.
|
||||
|
||||
Args:
|
||||
schemata (str): Grammar string in EBNF format
|
||||
|
||||
Returns:
|
||||
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
||||
"""
|
||||
try:
|
||||
compiled_grammar = self.grammar_compiler.compile_grammar(schemata)
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile ebnf schema: {e}")
|
||||
return None
|
||||
return self._create_processor(compiled_grammar)
|
||||
|
||||
def _structural_tag_processor(
|
||||
self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
"""
|
||||
Compile structural tags into a grammar processor.
|
||||
|
||||
Args:
|
||||
schemata (str): JSON string containing structural tag definitions
|
||||
|
||||
Returns:
|
||||
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
||||
"""
|
||||
try:
|
||||
structural_tag = json.loads(schemata)
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=structure["begin"],
|
||||
schema=json.dumps(structure["schema"]),
|
||||
end=structure["end"],
|
||||
) for structure in structural_tag["structures"]
|
||||
]
|
||||
|
||||
compiled_grammar = self.grammar_compiler.compile_structural_tag(
|
||||
tags, structural_tag["triggers"])
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile structural tags schema: {e}")
|
||||
return None
|
||||
return self._create_processor(compiled_grammar)
|
||||
|
||||
|
||||
class XGrammarChecker(BaseChecker):
|
||||
"""
|
||||
XGrammar-specific implementation of BaseChecker.
|
||||
|
||||
This validator checks and formats various schema types (JSON, regex, grammar)
|
||||
for compatibility with XGrammar before processing.
|
||||
|
||||
Attributes:
|
||||
any_whitespace (bool): Whether to allow any whitespace in JSON
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.any_whitespace = not kwargs.get("disable_any_whitespace", True)
|
||||
|
||||
def _unsupported_json_schema(self, schema: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if JSON schema contains unsupported features.
|
||||
|
||||
Args:
|
||||
schema (dict[str, Any]): JSON schema to validate
|
||||
|
||||
Returns:
|
||||
bool: True if schema contains unsupported features, False otherwise
|
||||
"""
|
||||
|
||||
def check_object(obj: dict[str, Any]) -> bool:
|
||||
if not isinstance(obj, dict):
|
||||
return False
|
||||
|
||||
if obj.get("type") in ("integer", "number") and ("multipleOf"
|
||||
in obj):
|
||||
return True
|
||||
|
||||
if obj.get("type") == "array" and any(
|
||||
key in obj for key in ("uniqueItems", "contains",
|
||||
"minContains", "maxContains")):
|
||||
return True
|
||||
|
||||
if obj.get("type") == "string" and "format" in obj:
|
||||
return True
|
||||
|
||||
if obj.get("type") == "object" and any(
|
||||
key in obj
|
||||
for key in ("minProperties", "maxProperties",
|
||||
"propertyNames", "patternProperties")):
|
||||
return True
|
||||
|
||||
for value in obj.values():
|
||||
if isinstance(value, dict):
|
||||
if check_object(value):
|
||||
return True
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and check_object(item):
|
||||
return True
|
||||
return False
|
||||
|
||||
return check_object(schema)
|
||||
|
||||
def schema_format(self, request: Request):
|
||||
"""
|
||||
format schema to backend specific format.
|
||||
"""
|
||||
if request.guided_json:
|
||||
try:
|
||||
if not isinstance(request.guided_json, str):
|
||||
guided_json = json.dumps(request.guided_json)
|
||||
else:
|
||||
guided_json = request.guided_json
|
||||
|
||||
Grammar.from_json_schema(guided_json,
|
||||
any_whitespace=self.any_whitespace)
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid JSON format: {guided_json}, error message: {str(e)}"
|
||||
return request, err_msg
|
||||
|
||||
if self._unsupported_json_schema(guided_json):
|
||||
err_msg = f"unsupported JSON schema: {guided_json}"
|
||||
return request, err_msg
|
||||
|
||||
request.guided_json = guided_json
|
||||
return request, None
|
||||
elif request.guided_grammar:
|
||||
# TODO: XGrammar only supports GBNF grammars, convert Lark to GBNF
|
||||
guided_grammar = request.guided_grammar
|
||||
try:
|
||||
Grammar.from_ebnf(guided_grammar)
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid grammar format: {guided_grammar}, error message: {str(e)}"
|
||||
return request, err_msg
|
||||
request.guided_grammar = guided_grammar
|
||||
return request, None
|
||||
elif request.guided_json_object:
|
||||
request.guided_json = '{"type": "object"}'
|
||||
return request, None
|
||||
elif request.guided_choice:
|
||||
try:
|
||||
escaped_choices = (re.sub(r'(["\\])', r'\\\1', c)
|
||||
for c in request.guided_choice)
|
||||
guided_choice = ('root ::= ' +
|
||||
' | '.join(f'"{c}"' for c in escaped_choices))
|
||||
|
||||
Grammar.from_ebnf(guided_choice)
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid choice format: {guided_choice}, error message: {str(e)}"
|
||||
return request, err_msg
|
||||
|
||||
request.guided_grammar = guided_choice
|
||||
return request, None
|
||||
elif request.structural_tag:
|
||||
try:
|
||||
structural_tag = json.loads(request.structural_tag)
|
||||
tags = [
|
||||
StructuralTagItem(
|
||||
begin=s["begin"],
|
||||
schema=json.dumps(s["schema"]),
|
||||
end=s["end"],
|
||||
) for s in structural_tag["structures"]
|
||||
]
|
||||
Grammar.from_structural_tag(tags, structural_tag["triggers"])
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid structural_tag format: {structural_tag}, error message: {str(e)}"
|
||||
return request, err_msg
|
||||
return request, None
|
||||
else:
|
||||
# regex is not format
|
||||
return request, None
|
Reference in New Issue
Block a user