mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-25 01:20:43 +08:00

Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* add error traceback info * update error msg * update code --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
470 lines
16 KiB
Python
470 lines
16 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import json
|
|
import re
|
|
import traceback
|
|
from typing import Any, List, Optional
|
|
|
|
import paddle
|
|
import torch
|
|
|
|
from fastdeploy.config import FDConfig
|
|
from fastdeploy.engine.request import Request
|
|
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
|
|
BackendBase,
|
|
BaseChecker,
|
|
LogitsProcessorBase,
|
|
)
|
|
from fastdeploy.utils import llm_logger
|
|
|
|
try:
|
|
from xgrammar import (
|
|
CompiledGrammar,
|
|
Grammar,
|
|
GrammarCompiler,
|
|
GrammarMatcher,
|
|
StructuralTagItem,
|
|
TokenizerInfo,
|
|
allocate_token_bitmask,
|
|
apply_token_bitmask_inplace,
|
|
)
|
|
except Exception as e:
|
|
raise Exception(f"import XGrammar failed, please check your environment:\n\t {e}")
|
|
|
|
|
|
class XGrammarProcessor(LogitsProcessorBase):
|
|
"""
|
|
XGrammar-specific implementation of LogitsProcessorBase.
|
|
|
|
This processor enforces grammar constraints during token generation using XGrammar.
|
|
It manages the grammar matching state and applies token masks to logits.
|
|
|
|
Attributes:
|
|
max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch
|
|
vocab_size (int): Size of the vocabulary
|
|
batch_size (int): Batch size for processing
|
|
splitwise_role (str): Role for splitwise processing
|
|
compiled_grammar (CompiledGrammar): Compiled grammar rules
|
|
terminate_without_stop_token (bool): Whether to terminate without stop token
|
|
override_stop_tokens (Optional[List[int]]): Custom stop tokens
|
|
matcher (GrammarMatcher): Grammar matching engine
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
compiled_grammar: CompiledGrammar,
|
|
terminate_without_stop_token: bool = False,
|
|
override_stop_tokens: Optional[List[int]] = None,
|
|
vocab_size: Optional[int] = None,
|
|
batch_size: Optional[int] = None,
|
|
splitwise_role: str = "mixed",
|
|
):
|
|
super().__init__()
|
|
self.max_rollback_tokens = 200
|
|
self.vocab_size = vocab_size
|
|
self.batch_size = batch_size
|
|
self.splitwise_role = splitwise_role
|
|
self.compiled_grammar = compiled_grammar
|
|
self.terminate_without_stop_token = terminate_without_stop_token
|
|
self.override_stop_tokens = override_stop_tokens
|
|
|
|
self.matcher = GrammarMatcher(
|
|
compiled_grammar=compiled_grammar,
|
|
max_rollback_tokens=self.max_rollback_tokens,
|
|
terminate_without_stop_token=terminate_without_stop_token,
|
|
override_stop_tokens=override_stop_tokens,
|
|
)
|
|
|
|
def allocate_token_bitmask(self) -> torch.Tensor:
|
|
"""
|
|
Allocate a token bitmask tensor for grammar constraints.
|
|
|
|
Returns:
|
|
torch.Tensor: A tensor of shape (batch_size, vocab_size) initialized to 0
|
|
"""
|
|
return allocate_token_bitmask(self.batch_size, self.vocab_size)
|
|
|
|
def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
|
|
"""
|
|
Fill the token bitmask with allowed tokens for the given index.
|
|
|
|
Args:
|
|
token_bitmask (torch.Tensor): The token bitmask tensor to fill
|
|
idx (int): The batch index to fill the mask for
|
|
|
|
Returns:
|
|
None: Modifies the token_bitmask in-place
|
|
"""
|
|
self.matcher.fill_next_token_bitmask(token_bitmask, idx)
|
|
|
|
def apply_token_mask(
|
|
self,
|
|
logits: paddle.Tensor,
|
|
token_bitmask: torch.Tensor,
|
|
indices: Optional[List[int]] = None,
|
|
) -> paddle.Tensor:
|
|
"""
|
|
Apply the token mask to the logits, modifying probabilities of invalid tokens.
|
|
|
|
Args:
|
|
logits (paddle.Tensor): The logits tensor to modify
|
|
token_bitmask (torch.Tensor): The token bitmask indicating allowed tokens
|
|
indices (Optional[List[int]]): Optional list of batch indices to apply mask to
|
|
|
|
Returns:
|
|
paddle.Tensor: The modified logits tensor
|
|
"""
|
|
origin_place = logits.place
|
|
origin_dtype = logits.dtype
|
|
logits = torch.from_numpy(logits.numpy())
|
|
|
|
logits = logits.float() # cpu
|
|
apply_token_bitmask_inplace(
|
|
logits=logits,
|
|
bitmask=token_bitmask.to(logits.device, non_blocking=True),
|
|
indices=indices,
|
|
)
|
|
|
|
return paddle.to_tensor(
|
|
logits.numpy(),
|
|
dtype=origin_dtype,
|
|
place=origin_place,
|
|
)
|
|
|
|
def reset(self) -> None:
|
|
"""
|
|
Reset the grammar matcher state to initial conditions.
|
|
|
|
Returns:
|
|
None: No return value
|
|
"""
|
|
self.matcher.reset()
|
|
|
|
def accept_token(self, token: int) -> None:
|
|
"""
|
|
Validate and accept a generated token against the grammar constraints.
|
|
|
|
Args:
|
|
token (int): The token ID to validate
|
|
|
|
Raises:
|
|
AssertionError: If token is not allowed by the grammar
|
|
"""
|
|
assert self.matcher.accept_token(token), f"Failed to accept token {token}"
|
|
|
|
def is_terminated(self) -> bool:
|
|
"""
|
|
Check if the grammar matching process has terminated.
|
|
|
|
Returns:
|
|
bool: True if matching has terminated, False otherwise
|
|
"""
|
|
return self.matcher.is_terminated()
|
|
|
|
def copy(self) -> "XGrammarProcessor":
|
|
"""
|
|
Create a deep copy of this processor instance.
|
|
|
|
Returns:
|
|
XGrammarProcessor: A new processor instance with identical state
|
|
"""
|
|
return XGrammarProcessor(
|
|
compiled_grammar=self.compiled_grammar,
|
|
terminate_without_stop_token=self.terminate_without_stop_token,
|
|
override_stop_tokens=self.override_stop_tokens,
|
|
vocab_size=self.vocab_size,
|
|
batch_size=self.batch_size,
|
|
splitwise_role=self.splitwise_role,
|
|
)
|
|
|
|
|
|
class XGrammarBackend(BackendBase):
|
|
"""
|
|
XGrammar-specific implementation of BackendBase.
|
|
|
|
This backend handles compilation of various schema types (JSON, regex, grammar)
|
|
into XGrammar processors. It manages the grammar compiler and tokenizer info.
|
|
|
|
Attributes:
|
|
vocab_size (int): Size of the vocabulary from config
|
|
batch_size (int): Maximum batch size from config
|
|
any_whitespace (bool): Whether to allow any whitespace in JSON
|
|
splitwise_role (str): Role for splitwise processing
|
|
grammar_compiler (GrammarCompiler): Grammar compilation engine
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
fd_config: FDConfig,
|
|
**kwargs,
|
|
):
|
|
super().__init__(fd_config=fd_config)
|
|
self.vocab_size = fd_config.model_config.vocab_size
|
|
self.batch_size = fd_config.parallel_config.max_num_seqs
|
|
|
|
self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
|
|
self.splitwise_role = fd_config.parallel_config.splitwise_role
|
|
|
|
try:
|
|
tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
|
|
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
|
except Exception as e:
|
|
raise Exception(f"Failed to load XGrammar tokenizer: {e}")
|
|
|
|
def _create_processor(
|
|
self,
|
|
compiled_grammar: CompiledGrammar,
|
|
terminate_without_stop_token: bool = False,
|
|
override_stop_tokens: Optional[List[int]] = None,
|
|
) -> XGrammarProcessor:
|
|
"""
|
|
Create a logits processor instance for the given compiled grammar.
|
|
|
|
Args:
|
|
compiled_grammar (CompiledGrammar): Compiled grammar rules
|
|
terminate_without_stop_token (bool): Whether to terminate without stop token
|
|
override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults
|
|
|
|
Returns:
|
|
XGrammarProcessor: Configured grammar processor instance
|
|
"""
|
|
return XGrammarProcessor(
|
|
compiled_grammar=compiled_grammar,
|
|
terminate_without_stop_token=terminate_without_stop_token,
|
|
override_stop_tokens=override_stop_tokens,
|
|
vocab_size=self.vocab_size,
|
|
batch_size=self.batch_size,
|
|
splitwise_role=self.splitwise_role,
|
|
)
|
|
|
|
def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
|
"""
|
|
Compile JSON schema into a grammar processor.
|
|
|
|
Args:
|
|
schemata (str): JSON schema string to compile
|
|
|
|
Returns:
|
|
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
|
"""
|
|
try:
|
|
compiled_grammar = self.grammar_compiler.compile_json_schema(schemata, any_whitespace=self.any_whitespace)
|
|
except Exception as e:
|
|
llm_logger.error(f"Failed to compile json schema: {e}, {str(traceback.format_exc())}")
|
|
return None
|
|
return self._create_processor(compiled_grammar)
|
|
|
|
def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
|
"""
|
|
Compile regex pattern into a grammar processor.
|
|
|
|
Args:
|
|
schemata (str): Regex pattern string to compile
|
|
|
|
Returns:
|
|
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
|
"""
|
|
try:
|
|
compiled_grammar = self.grammar_compiler.compile_regex(schemata)
|
|
except Exception as e:
|
|
llm_logger.error(f"Failed to compile regex schema: {e}, {str(traceback.format_exc())}")
|
|
return None
|
|
return self._create_processor(compiled_grammar)
|
|
|
|
def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
|
"""
|
|
Compile grammar (EBNF) into a grammar processor.
|
|
|
|
Args:
|
|
schemata (str): Grammar string in EBNF format
|
|
|
|
Returns:
|
|
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
|
"""
|
|
try:
|
|
compiled_grammar = self.grammar_compiler.compile_grammar(schemata)
|
|
except Exception as e:
|
|
llm_logger.error(f"Failed to compile ebnf schema: {e}, {str(traceback.format_exc())}")
|
|
return None
|
|
return self._create_processor(compiled_grammar)
|
|
|
|
def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
|
"""
|
|
Compile structural tags into a grammar processor.
|
|
|
|
Args:
|
|
schemata (str): JSON string containing structural tag definitions
|
|
|
|
Returns:
|
|
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
|
"""
|
|
try:
|
|
structural_tag = json.loads(schemata)
|
|
tags = [
|
|
StructuralTagItem(
|
|
begin=structure["begin"],
|
|
schema=json.dumps(structure["schema"]),
|
|
end=structure["end"],
|
|
)
|
|
for structure in structural_tag["structures"]
|
|
]
|
|
|
|
compiled_grammar = self.grammar_compiler.compile_structural_tag(tags, structural_tag["triggers"])
|
|
except Exception as e:
|
|
llm_logger.error(f"Failed to compile structural tags schema: {e}, {str(traceback.format_exc())}")
|
|
return None
|
|
return self._create_processor(compiled_grammar)
|
|
|
|
|
|
class XGrammarChecker(BaseChecker):
|
|
"""
|
|
XGrammar-specific implementation of BaseChecker.
|
|
|
|
This validator checks and formats various schema types (JSON, regex, grammar)
|
|
for compatibility with XGrammar before processing.
|
|
|
|
Attributes:
|
|
any_whitespace (bool): Whether to allow any whitespace in JSON
|
|
"""
|
|
|
|
def __init__(self, **kwargs):
|
|
super().__init__()
|
|
|
|
self.any_whitespace = not kwargs.get("disable_any_whitespace", True)
|
|
|
|
def _unsupported_json_schema(self, schema: dict[str, Any]) -> bool:
|
|
"""
|
|
Check if JSON schema contains unsupported features.
|
|
|
|
Args:
|
|
schema (dict[str, Any]): JSON schema to validate
|
|
|
|
Returns:
|
|
bool: True if schema contains unsupported features, False otherwise
|
|
"""
|
|
|
|
def check_object(obj: dict[str, Any]) -> bool:
|
|
if not isinstance(obj, dict):
|
|
return False
|
|
|
|
if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
|
|
return True
|
|
|
|
if obj.get("type") == "array" and any(
|
|
key in obj
|
|
for key in (
|
|
"uniqueItems",
|
|
"contains",
|
|
"minContains",
|
|
"maxContains",
|
|
)
|
|
):
|
|
return True
|
|
|
|
if obj.get("type") == "string" and "format" in obj:
|
|
return True
|
|
|
|
if obj.get("type") == "object" and any(
|
|
key in obj
|
|
for key in (
|
|
"minProperties",
|
|
"maxProperties",
|
|
"propertyNames",
|
|
"patternProperties",
|
|
)
|
|
):
|
|
return True
|
|
|
|
for value in obj.values():
|
|
if isinstance(value, dict):
|
|
if check_object(value):
|
|
return True
|
|
elif isinstance(value, list):
|
|
for item in value:
|
|
if isinstance(item, dict) and check_object(item):
|
|
return True
|
|
return False
|
|
|
|
return check_object(schema)
|
|
|
|
def schema_format(self, request: Request):
|
|
"""
|
|
format schema to backend specific format.
|
|
"""
|
|
if request.guided_json:
|
|
try:
|
|
if not isinstance(request.guided_json, str):
|
|
guided_json = json.dumps(request.guided_json)
|
|
else:
|
|
guided_json = request.guided_json
|
|
|
|
Grammar.from_json_schema(guided_json, any_whitespace=self.any_whitespace)
|
|
except RuntimeError as e:
|
|
err_msg = f"Invalid JSON format: {guided_json}, error message: {e!s}"
|
|
return request, err_msg
|
|
|
|
if self._unsupported_json_schema(guided_json):
|
|
err_msg = f"unsupported JSON schema: {guided_json}"
|
|
return request, err_msg
|
|
|
|
request.guided_json = guided_json
|
|
return request, None
|
|
elif request.guided_grammar:
|
|
# TODO: XGrammar only supports GBNF grammars, convert Lark to GBNF
|
|
guided_grammar = request.guided_grammar
|
|
try:
|
|
Grammar.from_ebnf(guided_grammar)
|
|
except RuntimeError as e:
|
|
err_msg = f"Invalid grammar format: {guided_grammar}, error message: {e!s}"
|
|
return request, err_msg
|
|
request.guided_grammar = guided_grammar
|
|
return request, None
|
|
elif request.guided_json_object:
|
|
request.guided_json = '{"type": "object"}'
|
|
return request, None
|
|
elif request.guided_choice:
|
|
try:
|
|
escaped_choices = (re.sub(r'(["\\])', r"\\\1", c) for c in request.guided_choice)
|
|
guided_choice = "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices)
|
|
|
|
Grammar.from_ebnf(guided_choice)
|
|
except RuntimeError as e:
|
|
err_msg = f"Invalid choice format: {guided_choice}, error message: {e!s}"
|
|
return request, err_msg
|
|
|
|
request.guided_grammar = guided_choice
|
|
return request, None
|
|
elif request.structural_tag:
|
|
try:
|
|
structural_tag = json.loads(request.structural_tag)
|
|
tags = [
|
|
StructuralTagItem(
|
|
begin=s["begin"],
|
|
schema=json.dumps(s["schema"]),
|
|
end=s["end"],
|
|
)
|
|
for s in structural_tag["structures"]
|
|
]
|
|
Grammar.from_structural_tag(tags, structural_tag["triggers"])
|
|
except RuntimeError as e:
|
|
err_msg = f"Invalid structural_tag format: {structural_tag}, error message: {e!s}"
|
|
return request, err_msg
|
|
return request, None
|
|
else:
|
|
# regex is not format
|
|
return request, None
|