mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -24,16 +24,25 @@ import torch
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
|
||||
BackendBase, BaseChecker, LogitsProcessorBase)
|
||||
BackendBase,
|
||||
BaseChecker,
|
||||
LogitsProcessorBase,
|
||||
)
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
try:
|
||||
from xgrammar import (CompiledGrammar, Grammar, GrammarCompiler,
|
||||
GrammarMatcher, StructuralTagItem, TokenizerInfo,
|
||||
allocate_token_bitmask, apply_token_bitmask_inplace)
|
||||
from xgrammar import (
|
||||
CompiledGrammar,
|
||||
Grammar,
|
||||
GrammarCompiler,
|
||||
GrammarMatcher,
|
||||
StructuralTagItem,
|
||||
TokenizerInfo,
|
||||
allocate_token_bitmask,
|
||||
apply_token_bitmask_inplace,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"import XGrammar failed, please check your environment:\n\t {e}")
|
||||
raise Exception(f"import XGrammar failed, please check your environment:\n\t {e}")
|
||||
|
||||
|
||||
class XGrammarProcessor(LogitsProcessorBase):
|
||||
@@ -88,8 +97,7 @@ class XGrammarProcessor(LogitsProcessorBase):
|
||||
"""
|
||||
return allocate_token_bitmask(self.batch_size, self.vocab_size)
|
||||
|
||||
def fill_token_bitmask(self, token_bitmask: torch.Tensor,
|
||||
idx: int) -> None:
|
||||
def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
|
||||
"""
|
||||
Fill the token bitmask with allowed tokens for the given index.
|
||||
|
||||
@@ -155,8 +163,7 @@ class XGrammarProcessor(LogitsProcessorBase):
|
||||
Raises:
|
||||
AssertionError: If token is not allowed by the grammar
|
||||
"""
|
||||
assert self.matcher.accept_token(
|
||||
token), f"Failed to accept token {token}"
|
||||
assert self.matcher.accept_token(token), f"Failed to accept token {token}"
|
||||
|
||||
def is_terminated(self) -> bool:
|
||||
"""
|
||||
@@ -212,10 +219,8 @@ class XGrammarBackend(BackendBase):
|
||||
self.splitwise_role = fd_config.parallel_config.splitwise_role
|
||||
|
||||
try:
|
||||
tokenizer_info = TokenizerInfo.from_huggingface(
|
||||
self.hf_tokenizer, vocab_size=self.vocab_size)
|
||||
self.grammar_compiler = GrammarCompiler(
|
||||
tokenizer_info=tokenizer_info)
|
||||
tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
|
||||
self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load XGrammar tokenizer: {e}")
|
||||
|
||||
@@ -256,8 +261,7 @@ class XGrammarBackend(BackendBase):
|
||||
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
|
||||
"""
|
||||
try:
|
||||
compiled_grammar = self.grammar_compiler.compile_json_schema(
|
||||
schemata, any_whitespace=self.any_whitespace)
|
||||
compiled_grammar = self.grammar_compiler.compile_json_schema(schemata, any_whitespace=self.any_whitespace)
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile json schema: {e}")
|
||||
return None
|
||||
@@ -297,8 +301,7 @@ class XGrammarBackend(BackendBase):
|
||||
return None
|
||||
return self._create_processor(compiled_grammar)
|
||||
|
||||
def _structural_tag_processor(
|
||||
self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
|
||||
"""
|
||||
Compile structural tags into a grammar processor.
|
||||
|
||||
@@ -315,11 +318,11 @@ class XGrammarBackend(BackendBase):
|
||||
begin=structure["begin"],
|
||||
schema=json.dumps(structure["schema"]),
|
||||
end=structure["end"],
|
||||
) for structure in structural_tag["structures"]
|
||||
)
|
||||
for structure in structural_tag["structures"]
|
||||
]
|
||||
|
||||
compiled_grammar = self.grammar_compiler.compile_structural_tag(
|
||||
tags, structural_tag["triggers"])
|
||||
compiled_grammar = self.grammar_compiler.compile_structural_tag(tags, structural_tag["triggers"])
|
||||
except Exception as e:
|
||||
llm_logger.error(f"Failed to compile structural tags schema: {e}")
|
||||
return None
|
||||
@@ -357,22 +360,32 @@ class XGrammarChecker(BaseChecker):
|
||||
if not isinstance(obj, dict):
|
||||
return False
|
||||
|
||||
if obj.get("type") in ("integer", "number") and ("multipleOf"
|
||||
in obj):
|
||||
if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
|
||||
return True
|
||||
|
||||
if obj.get("type") == "array" and any(
|
||||
key in obj for key in ("uniqueItems", "contains",
|
||||
"minContains", "maxContains")):
|
||||
key in obj
|
||||
for key in (
|
||||
"uniqueItems",
|
||||
"contains",
|
||||
"minContains",
|
||||
"maxContains",
|
||||
)
|
||||
):
|
||||
return True
|
||||
|
||||
if obj.get("type") == "string" and "format" in obj:
|
||||
return True
|
||||
|
||||
if obj.get("type") == "object" and any(
|
||||
key in obj
|
||||
for key in ("minProperties", "maxProperties",
|
||||
"propertyNames", "patternProperties")):
|
||||
key in obj
|
||||
for key in (
|
||||
"minProperties",
|
||||
"maxProperties",
|
||||
"propertyNames",
|
||||
"patternProperties",
|
||||
)
|
||||
):
|
||||
return True
|
||||
|
||||
for value in obj.values():
|
||||
@@ -398,10 +411,9 @@ class XGrammarChecker(BaseChecker):
|
||||
else:
|
||||
guided_json = request.guided_json
|
||||
|
||||
Grammar.from_json_schema(guided_json,
|
||||
any_whitespace=self.any_whitespace)
|
||||
Grammar.from_json_schema(guided_json, any_whitespace=self.any_whitespace)
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid JSON format: {guided_json}, error message: {str(e)}"
|
||||
err_msg = f"Invalid JSON format: {guided_json}, error message: {e!s}"
|
||||
return request, err_msg
|
||||
|
||||
if self._unsupported_json_schema(guided_json):
|
||||
@@ -416,7 +428,7 @@ class XGrammarChecker(BaseChecker):
|
||||
try:
|
||||
Grammar.from_ebnf(guided_grammar)
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid grammar format: {guided_grammar}, error message: {str(e)}"
|
||||
err_msg = f"Invalid grammar format: {guided_grammar}, error message: {e!s}"
|
||||
return request, err_msg
|
||||
request.guided_grammar = guided_grammar
|
||||
return request, None
|
||||
@@ -425,14 +437,12 @@ class XGrammarChecker(BaseChecker):
|
||||
return request, None
|
||||
elif request.guided_choice:
|
||||
try:
|
||||
escaped_choices = (re.sub(r'(["\\])', r'\\\1', c)
|
||||
for c in request.guided_choice)
|
||||
guided_choice = ('root ::= ' +
|
||||
' | '.join(f'"{c}"' for c in escaped_choices))
|
||||
escaped_choices = (re.sub(r'(["\\])', r"\\\1", c) for c in request.guided_choice)
|
||||
guided_choice = "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices)
|
||||
|
||||
Grammar.from_ebnf(guided_choice)
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid choice format: {guided_choice}, error message: {str(e)}"
|
||||
err_msg = f"Invalid choice format: {guided_choice}, error message: {e!s}"
|
||||
return request, err_msg
|
||||
|
||||
request.guided_grammar = guided_choice
|
||||
@@ -445,11 +455,12 @@ class XGrammarChecker(BaseChecker):
|
||||
begin=s["begin"],
|
||||
schema=json.dumps(s["schema"]),
|
||||
end=s["end"],
|
||||
) for s in structural_tag["structures"]
|
||||
)
|
||||
for s in structural_tag["structures"]
|
||||
]
|
||||
Grammar.from_structural_tag(tags, structural_tag["triggers"])
|
||||
except RuntimeError as e:
|
||||
err_msg = f"Invalid structural_tag format: {structural_tag}, error message: {str(e)}"
|
||||
err_msg = f"Invalid structural_tag format: {structural_tag}, error message: {e!s}"
|
||||
return request, err_msg
|
||||
return request, None
|
||||
else:
|
||||
|
Reference in New Issue
Block a user