mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-27 02:20:31 +08:00 
			
		
		
		
	 67298cf4c0
			
		
	
	67298cf4c0
	
	
		
			
	
		
	
	
		
			Some checks failed
		
		
	
	Deploy GitHub Pages / deploy (push) Has been cancelled
				
			* add error traceback info * update error msg * update code --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
		
			
				
	
	
		
			470 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			470 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| # Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| """
 | |
| 
 | |
| import json
 | |
| import re
 | |
| import traceback
 | |
| from typing import Any, List, Optional
 | |
| 
 | |
| import paddle
 | |
| import torch
 | |
| 
 | |
| from fastdeploy.config import FDConfig
 | |
| from fastdeploy.engine.request import Request
 | |
| from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
 | |
|     BackendBase,
 | |
|     BaseChecker,
 | |
|     LogitsProcessorBase,
 | |
| )
 | |
| from fastdeploy.utils import llm_logger
 | |
| 
 | |
try:
    from xgrammar import (
        CompiledGrammar,
        Grammar,
        GrammarCompiler,
        GrammarMatcher,
        StructuralTagItem,
        TokenizerInfo,
        allocate_token_bitmask,
        apply_token_bitmask_inplace,
    )
except Exception as e:
    # Chain the original exception (`from e`) so the real cause of the import
    # failure (missing wheel, ABI mismatch, ...) stays visible in the traceback
    # instead of being flattened into the message string only.
    raise Exception(f"import XGrammar failed, please check your environment:\n\t {e}") from e
 | |
| 
 | |
| 
 | |
class XGrammarProcessor(LogitsProcessorBase):
    """
    Grammar-constrained logits processor backed by XGrammar.

    Wraps an XGrammar ``GrammarMatcher`` and exposes the operations needed
    during constrained decoding: allocating and filling token bitmasks,
    masking logits, validating generated tokens, and tracking termination.

    Attributes:
        max_rollback_tokens (int): Maximum tokens the matcher may roll back.
        vocab_size (Optional[int]): Vocabulary size used for bitmask shape.
        batch_size (Optional[int]): Batch size used for bitmask shape.
        splitwise_role (str): Role used for splitwise processing.
        compiled_grammar (CompiledGrammar): Compiled grammar rules.
        terminate_without_stop_token (bool): Terminate without a stop token.
        override_stop_tokens (Optional[List[int]]): Custom stop tokens.
        matcher (GrammarMatcher): Underlying grammar matching engine.
    """

    def __init__(
        self,
        compiled_grammar: CompiledGrammar,
        terminate_without_stop_token: bool = False,
        override_stop_tokens: Optional[List[int]] = None,
        vocab_size: Optional[int] = None,
        batch_size: Optional[int] = None,
        splitwise_role: str = "mixed",
    ):
        super().__init__()
        # Every constructor argument is kept on the instance so copy() can
        # rebuild an identical processor with a fresh matcher state.
        self.compiled_grammar = compiled_grammar
        self.terminate_without_stop_token = terminate_without_stop_token
        self.override_stop_tokens = override_stop_tokens
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.splitwise_role = splitwise_role
        self.max_rollback_tokens = 200
        self.matcher = GrammarMatcher(
            compiled_grammar=compiled_grammar,
            max_rollback_tokens=self.max_rollback_tokens,
            terminate_without_stop_token=terminate_without_stop_token,
            override_stop_tokens=override_stop_tokens,
        )

    def allocate_token_bitmask(self) -> torch.Tensor:
        """
        Allocate a token bitmask for grammar constraints.

        Returns:
            torch.Tensor: A (batch_size, vocab_size)-shaped bitmask tensor.
        """
        return allocate_token_bitmask(self.batch_size, self.vocab_size)

    def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
        """
        Write the allowed-token mask for one batch entry into the bitmask.

        Args:
            token_bitmask (torch.Tensor): Bitmask tensor, modified in place.
            idx (int): Batch index whose next-token mask is computed.

        Returns:
            None: Modifies ``token_bitmask`` in place.
        """
        self.matcher.fill_next_token_bitmask(token_bitmask, idx)

    def apply_token_mask(
        self,
        logits: paddle.Tensor,
        token_bitmask: torch.Tensor,
        indices: Optional[List[int]] = None,
    ) -> paddle.Tensor:
        """
        Mask grammar-invalid tokens out of the logits.

        The paddle tensor is round-tripped through a CPU torch tensor because
        ``apply_token_bitmask_inplace`` operates on torch tensors.

        Args:
            logits (paddle.Tensor): Logits tensor to constrain.
            token_bitmask (torch.Tensor): Bitmask of allowed tokens.
            indices (Optional[List[int]]): Batch rows to mask; all when None.

        Returns:
            paddle.Tensor: Masked logits with the original dtype and place.
        """
        src_place = logits.place
        src_dtype = logits.dtype
        # paddle -> numpy -> torch, promoted to float32 for the CPU kernel.
        torch_logits = torch.from_numpy(logits.numpy()).float()
        apply_token_bitmask_inplace(
            logits=torch_logits,
            bitmask=token_bitmask.to(torch_logits.device, non_blocking=True),
            indices=indices,
        )
        # Convert back, restoring the original dtype and device placement.
        return paddle.to_tensor(
            torch_logits.numpy(),
            dtype=src_dtype,
            place=src_place,
        )

    def reset(self) -> None:
        """
        Reset the grammar matcher to its initial state.

        Returns:
            None: No return value.
        """
        self.matcher.reset()

    def accept_token(self, token: int) -> None:
        """
        Feed a generated token to the matcher, asserting it is grammar-legal.

        Args:
            token (int): Token ID to validate and accept.

        Raises:
            AssertionError: If the grammar rejects the token.
        """
        assert self.matcher.accept_token(token), f"Failed to accept token {token}"

    def is_terminated(self) -> bool:
        """
        Check whether grammar matching has reached a terminal state.

        Returns:
            bool: True once matching has terminated, False otherwise.
        """
        return self.matcher.is_terminated()

    def copy(self) -> "XGrammarProcessor":
        """
        Build a new processor from this one's configuration.

        Returns:
            XGrammarProcessor: A fresh processor with identical settings.
        """
        return XGrammarProcessor(
            compiled_grammar=self.compiled_grammar,
            terminate_without_stop_token=self.terminate_without_stop_token,
            override_stop_tokens=self.override_stop_tokens,
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
            splitwise_role=self.splitwise_role,
        )
 | |
| 
 | |
| 
 | |
class XGrammarBackend(BackendBase):
    """
    XGrammar-specific implementation of BackendBase.

    Compiles the supported schema kinds (JSON schema, regex, EBNF grammar,
    structural tags) into ``XGrammarProcessor`` instances through a shared
    ``GrammarCompiler`` built from the HuggingFace tokenizer.

    Attributes:
        vocab_size (int): Vocabulary size from the model config.
        batch_size (int): Maximum number of sequences from the parallel config.
        any_whitespace (bool): Whether to allow any whitespace in JSON.
        splitwise_role (str): Role used for splitwise processing.
        grammar_compiler (GrammarCompiler): Shared grammar compilation engine.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        **kwargs,
    ):
        super().__init__(fd_config=fd_config)
        self.vocab_size = fd_config.model_config.vocab_size
        self.batch_size = fd_config.parallel_config.max_num_seqs
        self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
        self.splitwise_role = fd_config.parallel_config.splitwise_role

        try:
            info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
            self.grammar_compiler = GrammarCompiler(tokenizer_info=info)
        except Exception as e:
            raise Exception(f"Failed to load XGrammar tokenizer: {e}")

    def _create_processor(
        self,
        compiled_grammar: CompiledGrammar,
        terminate_without_stop_token: bool = False,
        override_stop_tokens: Optional[List[int]] = None,
    ) -> XGrammarProcessor:
        """
        Build a logits processor for an already-compiled grammar.

        Args:
            compiled_grammar (CompiledGrammar): Compiled grammar rules.
            terminate_without_stop_token (bool): Terminate without a stop token.
            override_stop_tokens (Optional[List[int]]): Custom stop tokens.

        Returns:
            XGrammarProcessor: Processor configured with this backend's sizes.
        """
        return XGrammarProcessor(
            compiled_grammar=compiled_grammar,
            terminate_without_stop_token=terminate_without_stop_token,
            override_stop_tokens=override_stop_tokens,
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
            splitwise_role=self.splitwise_role,
        )

    def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile a JSON schema string into a grammar processor.

        Args:
            schemata (str): JSON schema string to compile.

        Returns:
            Optional[XGrammarProcessor]: Processor on success, None on failure.
        """
        try:
            grammar = self.grammar_compiler.compile_json_schema(schemata, any_whitespace=self.any_whitespace)
        except Exception as e:
            llm_logger.error(f"Failed to compile json schema: {e}, {str(traceback.format_exc())}")
            return None
        else:
            return self._create_processor(grammar)

    def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile a regex pattern into a grammar processor.

        Args:
            schemata (str): Regex pattern string to compile.

        Returns:
            Optional[XGrammarProcessor]: Processor on success, None on failure.
        """
        try:
            grammar = self.grammar_compiler.compile_regex(schemata)
        except Exception as e:
            llm_logger.error(f"Failed to compile regex schema: {e}, {str(traceback.format_exc())}")
            return None
        else:
            return self._create_processor(grammar)

    def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile an EBNF grammar into a grammar processor.

        Args:
            schemata (str): Grammar string in EBNF format.

        Returns:
            Optional[XGrammarProcessor]: Processor on success, None on failure.
        """
        try:
            grammar = self.grammar_compiler.compile_grammar(schemata)
        except Exception as e:
            llm_logger.error(f"Failed to compile ebnf schema: {e}, {str(traceback.format_exc())}")
            return None
        else:
            return self._create_processor(grammar)

    def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile structural tag definitions into a grammar processor.

        Args:
            schemata (str): JSON string containing structural tag definitions.

        Returns:
            Optional[XGrammarProcessor]: Processor on success, None on failure.
        """
        try:
            spec = json.loads(schemata)
            tags = []
            for structure in spec["structures"]:
                tags.append(
                    StructuralTagItem(
                        begin=structure["begin"],
                        schema=json.dumps(structure["schema"]),
                        end=structure["end"],
                    )
                )
            grammar = self.grammar_compiler.compile_structural_tag(tags, spec["triggers"])
        except Exception as e:
            llm_logger.error(f"Failed to compile structural tags schema: {e}, {str(traceback.format_exc())}")
            return None
        else:
            return self._create_processor(grammar)
 | |
| 
 | |
| 
 | |
class XGrammarChecker(BaseChecker):
    """
    XGrammar-specific implementation of BaseChecker.

    Validates and normalizes the schema types carried on a request (JSON
    schema, EBNF grammar, JSON object, choice list, structural tags) before
    they are handed to the XGrammar backend for compilation.

    Attributes:
        any_whitespace (bool): Whether to allow any whitespace in JSON.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # NOTE: defaults to disable_any_whitespace=True when the kwarg is
        # absent, i.e. any_whitespace is False by default.
        self.any_whitespace = not kwargs.get("disable_any_whitespace", True)

    def _unsupported_json_schema(self, schema: dict[str, Any]) -> bool:
        """
        Check whether a JSON schema uses features XGrammar does not support.

        Args:
            schema (dict[str, Any]): Parsed JSON schema to inspect.

        Returns:
            bool: True if the schema contains unsupported features.
        """

        def check_object(obj: dict[str, Any]) -> bool:
            # Non-dict nodes (strings, numbers, ...) carry no schema keywords.
            if not isinstance(obj, dict):
                return False

            # Numeric constraint not supported by XGrammar.
            if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
                return True

            # Unsupported array-level constraints.
            if obj.get("type") == "array" and any(
                key in obj
                for key in (
                    "uniqueItems",
                    "contains",
                    "minContains",
                    "maxContains",
                )
            ):
                return True

            # String "format" (date-time, uuid, ...) is not enforced.
            if obj.get("type") == "string" and "format" in obj:
                return True

            # Unsupported object-level constraints.
            if obj.get("type") == "object" and any(
                key in obj
                for key in (
                    "minProperties",
                    "maxProperties",
                    "propertyNames",
                    "patternProperties",
                )
            ):
                return True

            # Recurse into nested schemas (dict values and lists of dicts).
            for value in obj.values():
                if isinstance(value, dict):
                    if check_object(value):
                        return True
                elif isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict) and check_object(item):
                            return True
            return False

        return check_object(schema)

    def schema_format(self, request: Request):
        """
        Validate and normalize the request's guided-decoding schema in place.

        Exactly one guided field is handled per call, in priority order:
        guided_json, guided_grammar, guided_json_object, guided_choice,
        structural_tag; guided_regex needs no formatting.

        Args:
            request (Request): Request carrying the guided-decoding fields.

        Returns:
            tuple[Request, Optional[str]]: The (possibly updated) request and
            an error message, or None when validation succeeded.
        """
        if request.guided_json:
            try:
                if not isinstance(request.guided_json, str):
                    schema_obj = request.guided_json
                    guided_json = json.dumps(schema_obj)
                else:
                    guided_json = request.guided_json
                    # Parse the string form so the unsupported-feature check
                    # below runs on a dict (it is a no-op on strings).
                    schema_obj = json.loads(guided_json)

                Grammar.from_json_schema(guided_json, any_whitespace=self.any_whitespace)
            except (RuntimeError, json.JSONDecodeError) as e:
                err_msg = f"Invalid JSON format: {guided_json}, error message: {e!s}"
                return request, err_msg

            # BUGFIX: previously the JSON *string* was passed here, so the
            # dict-based walker never detected unsupported features.
            if self._unsupported_json_schema(schema_obj):
                err_msg = f"unsupported JSON schema: {guided_json}"
                return request, err_msg

            request.guided_json = guided_json
            return request, None
        elif request.guided_grammar:
            # TODO: XGrammar only supports GBNF grammars, convert Lark to GBNF
            guided_grammar = request.guided_grammar
            try:
                Grammar.from_ebnf(guided_grammar)
            except RuntimeError as e:
                err_msg = f"Invalid grammar format: {guided_grammar}, error message: {e!s}"
                return request, err_msg
            request.guided_grammar = guided_grammar
            return request, None
        elif request.guided_json_object:
            # Any JSON object is acceptable; express it as a minimal schema.
            request.guided_json = '{"type": "object"}'
            return request, None
        elif request.guided_choice:
            try:
                # Escape backslashes and double quotes so each choice is a
                # valid GBNF string literal, then build a one-rule grammar.
                escaped_choices = (re.sub(r'(["\\])', r"\\\1", c) for c in request.guided_choice)
                guided_choice = "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices)

                Grammar.from_ebnf(guided_choice)
            except RuntimeError as e:
                err_msg = f"Invalid choice format: {guided_choice}, error message: {e!s}"
                return request, err_msg

            request.guided_grammar = guided_choice
            return request, None
        elif request.structural_tag:
            try:
                structural_tag = json.loads(request.structural_tag)
                tags = [
                    StructuralTagItem(
                        begin=s["begin"],
                        schema=json.dumps(s["schema"]),
                        end=s["end"],
                    )
                    for s in structural_tag["structures"]
                ]
                Grammar.from_structural_tag(tags, structural_tag["triggers"])
            # Also catch JSONDecodeError: malformed structural_tag input used
            # to escape as an unhandled exception instead of an error return.
            except (RuntimeError, json.JSONDecodeError) as e:
                err_msg = f"Invalid structural_tag format: {request.structural_tag}, error message: {e!s}"
                return request, err_msg
            return request, None
        else:
            # regex is not format
            return request, None
 |