mirror of https://github.com/PaddlePaddle/FastDeploy.git
polish code with new pre-commit rule (#2923)
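Note: the pre-commit configuration itself is not part of this diff, so the exact tools and settings are an assumption. What the hunks below consistently converge on is: double-quoted strings, parenthesized one-name-per-line imports with trailing commas, signatures and calls collapsed onto a single line where they fit, bare raise NotImplementedError, and {e!s} in place of str(e) inside f-strings. A tiny before/after sketch of that target style:

# Before: the style removed on the "-" lines below
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import \
    XGrammarBackend
__all__ = ['get_guided_backend', 'schema_checker']

# After: the style the new pre-commit rule enforces on the "+" lines
from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
    XGrammarBackend,
)
__all__ = ["get_guided_backend", "schema_checker"]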
@@ -16,7 +16,7 @@

# from fastdeploy.config import FDConfig

-__all__ = ['get_guided_backend', 'schema_checker']
+__all__ = ["get_guided_backend", "schema_checker"]


def get_guided_backend(
@@ -37,8 +37,10 @@ def get_guided_backend(
        ValueError: If the specified backend is not supported
    """
    if fd_config.parallel_config.guided_decoding_backend.lower() == "xgrammar":
-        from fastdeploy.model_executor.guided_decoding.xgrammar_backend import \
-            XGrammarBackend
+        from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
+            XGrammarBackend,
+        )
+
        return XGrammarBackend(
            fd_config=fd_config,
            **kwargs,
@@ -46,7 +48,8 @@ def get_guided_backend(
    else:
        raise ValueError(
            f"Get unsupported backend {fd_config.parallel_config.guided_decoding_backend},"
-            f" please check your configuration.")
+            f" please check your configuration."
+        )


def schema_checker(backend_name: str, **kwargs):
@@ -64,10 +67,10 @@ def schema_checker(backend_name: str, **kwargs):
        ValueError: If the specified backend is not supported
    """
    if backend_name.lower() == "xgrammar":
-        from fastdeploy.model_executor.guided_decoding.xgrammar_backend import \
-            XGrammarChecker
+        from fastdeploy.model_executor.guided_decoding.xgrammar_backend import (
+            XGrammarChecker,
+        )
+
        return XGrammarChecker(**kwargs)
    else:
-        raise ValueError(
-            f"Get unsupported backend {backend_name}, please check your configuration."
-        )
+        raise ValueError(f"Get unsupported backend {backend_name}, please check your configuration.")
@@ -17,7 +17,7 @@
import os
from concurrent.futures import ThreadPoolExecutor

-from fastdeploy.config import FDConfig, ErnieArchitectures
+from fastdeploy.config import ErnieArchitectures, FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.utils import llm_logger

@@ -48,7 +48,7 @@ class LogitsProcessorBase:
        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def apply_token_mask(self, logits, token_bitmask):
        """
@@ -61,7 +61,7 @@ class LogitsProcessorBase:
        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def allocate_token_bitmask(self, batch_size, vocab_size):
        """
@@ -74,7 +74,7 @@ class LogitsProcessorBase:
        Returns:
            tensor: The allocated token bitmask.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def accept_token(self, token):
        """
@@ -86,7 +86,7 @@ class LogitsProcessorBase:
        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def is_terminated(self):
        """
@@ -95,13 +95,13 @@ class LogitsProcessorBase:
        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def reset(self):
        """
        Reset the matcher state.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def copy(self):
        """
@@ -110,7 +110,7 @@ class LogitsProcessorBase:
        Returns:
            BackendBase: A copy of the backend instance.
        """
-        raise NotImplementedError()
+        raise NotImplementedError


class BackendBase:
@@ -146,7 +146,7 @@ class BackendBase:
        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def _json_processor(self, schemata):
        """
@@ -158,7 +158,7 @@ class BackendBase:
        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def _regex_processor(self, schemata):
        """
@@ -170,7 +170,7 @@ class BackendBase:
        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def _grammar_processor(self, schemata):
        """
@@ -182,7 +182,7 @@ class BackendBase:
        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def _structural_tag_processor(self, schemata):
        """
@@ -194,7 +194,7 @@ class BackendBase:
        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
-        raise NotImplementedError()
+        raise NotImplementedError

    def _unsupported_processor_type(self, key_type, schemata):
        """
@@ -206,8 +206,7 @@ class BackendBase:
        """
        raise Exception(f"Unsupported processor type {key_type}.")

-    def _init_logits_processor(
-            self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
+    def _init_logits_processor(self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
        """
        init logits processor by type and schemata.

@@ -233,9 +232,7 @@ class BackendBase:
            llm_logger.error(f"Unsupported processor type {key_type}.")
            return None

-    def get_logits_processor(
-            self,
-            schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
+    def get_logits_processor(self, schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
        """
        get logits processor by key from cache or create new one.

@@ -271,39 +268,41 @@ class BackendBase:
            if not ErnieArchitectures.contains_ernie_arch(architectures):

                from transformers import AutoTokenizer, PreTrainedTokenizerFast

                tokenizer = AutoTokenizer.from_pretrained(
                    self.fd_config.parallel_config.model_name_or_path,
                    use_fast=False,
                )

                if not isinstance(tokenizer, PreTrainedTokenizerFast):
-                    tokenizer = PreTrainedTokenizerFast(
-                        __slow_tokenizer=tokenizer)
+                    tokenizer = PreTrainedTokenizerFast(__slow_tokenizer=tokenizer)
            else:
-                from fastdeploy.model_executor.guided_decoding.ernie_tokenizer import \
-                    ErnieBotTokenizer
+                from fastdeploy.model_executor.guided_decoding.ernie_tokenizer import (
+                    ErnieBotTokenizer,
+                )
+
                vocab_file_names = [
-                    "tokenizer.model", "spm.model", "ernie_token_100k.model"
+                    "tokenizer.model",
+                    "spm.model",
+                    "ernie_token_100k.model",
                ]
                for i in range(len(vocab_file_names)):
                    if os.path.exists(
-                            os.path.join(
-                                self.fd_config.parallel_config.
-                                model_name_or_path, vocab_file_names[i])):
-                        ErnieBotTokenizer.vocab_files_names[
-                            "vocab_file"] = vocab_file_names[i]
+                        os.path.join(
+                            self.fd_config.parallel_config.model_name_or_path,
+                            vocab_file_names[i],
+                        )
+                    ):
+                        ErnieBotTokenizer.vocab_files_names["vocab_file"] = vocab_file_names[i]
                        break

-                tokenizer = ErnieBotTokenizer.from_pretrained(
-                    self.fd_config.parallel_config.model_name_or_path)
+                tokenizer = ErnieBotTokenizer.from_pretrained(self.fd_config.parallel_config.model_name_or_path)

            return tokenizer
        except Exception as e:
            raise Exception(f"Fail to initialize hf tokenizer: {e}")

-    def add_cache(self, schemata_key: tuple[str, str],
-                  processor: LogitsProcessorBase) -> None:
+    def add_cache(self, schemata_key: tuple[str, str], processor: LogitsProcessorBase) -> None:
        """
        add logits processor to cache.

@@ -343,4 +342,4 @@ class BaseChecker:
        Returns:
            request (Request): request object with formatted schema.
        """
-        raise NotImplementedError()
+        raise NotImplementedError
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
+
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
@@ -63,18 +64,10 @@ class ErnieBotTokenizer(PreTrainedTokenizer):
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

-        bos_token = AddedToken(bos_token,
-                               lstrip=False, rstrip=False) if isinstance(
-                                   bos_token, str) else bos_token
-        eos_token = AddedToken(eos_token,
-                               lstrip=False, rstrip=False) if isinstance(
-                                   eos_token, str) else eos_token
-        unk_token = AddedToken(unk_token,
-                               lstrip=False, rstrip=False) if isinstance(
-                                   unk_token, str) else unk_token
-        pad_token = AddedToken(pad_token,
-                               lstrip=False, rstrip=False) if isinstance(
-                                   pad_token, str) else pad_token
+        bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
+        eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
+        unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
+        pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
@@ -111,10 +104,7 @@ class ErnieBotTokenizer(PreTrainedTokenizer):

    def get_vocab(self):
        """Returns vocab as a dict"""
-        vocab = {
-            self.convert_ids_to_tokens(i): i
-            for i in range(self.vocab_size)
-        }
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

@@ -126,10 +116,12 @@ class ErnieBotTokenizer(PreTrainedTokenizer):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

-    def decode(self,
-               tokens,
-               skip_special_tokens=False,
-               clean_up_tokenization_spaces=False):
+    def decode(
+        self,
+        tokens,
+        skip_special_tokens=False,
+        clean_up_tokenization_spaces=False,
+    ):
        """Returns a tokenized string."""
        return self.sp_model.decode(tokens)

@@ -161,9 +153,7 @@ class ErnieBotTokenizer(PreTrainedTokenizer):
            out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

-    def save_vocabulary(self,
-                        save_directory,
-                        filename_prefix: Optional[str] = None) -> Tuple[str]:
+    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.
        Args:
@@ -176,18 +166,17 @@ class ErnieBotTokenizer(PreTrainedTokenizer):
            return
        out_vocab_file = os.path.join(
            save_directory,
-            (filename_prefix + "-" if filename_prefix else "") +
-            VOCAB_FILES_NAMES["vocab_file"])
+            (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
+        )

-        if os.path.abspath(self.vocab_file) != os.path.abspath(
-                out_vocab_file) and os.path.isfile(self.vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

-        return (out_vocab_file, )
+        return (out_vocab_file,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
@@ -204,10 +193,11 @@ class ErnieBotTokenizer(PreTrainedTokenizer):
        return output

    def get_special_tokens_mask(
-            self,
-            token_ids_0: List[int],
-            token_ids_1: Optional[List[int]] = None,
-            already_has_special_tokens: bool = False) -> List[int]:
+        self,
+        token_ids_0: List[int],
+        token_ids_1: Optional[List[int]] = None,
+        already_has_special_tokens: bool = False,
+    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.
@@ -225,20 +215,26 @@ class ErnieBotTokenizer(PreTrainedTokenizer):
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
-                already_has_special_tokens=True)
+                already_has_special_tokens=True,
+            )

        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []

        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
-        return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id +
-                bos_token_id + ([0] * len(token_ids_1)) + eos_token_id)
+        return (
+            bos_token_id
+            + ([0] * len(token_ids_0))
+            + eos_token_id
+            + bos_token_id
+            + ([0] * len(token_ids_1))
+            + eos_token_id
+        )

    def create_token_type_ids_from_sequences(
-            self,
-            token_ids_0: List[int],
-            token_ids_1: Optional[List[int]] = None) -> List[int]:
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:
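The reflowed return above only changes layout, not the mask it builds. A standalone sketch of the same construction, with made-up token id lists and flag values (not taken from the diff):

# Hypothetical inputs; the mask assembly mirrors the "+" lines above.
add_bos_token, add_eos_token = True, True
token_ids_0, token_ids_1 = [11, 12, 13], [21, 22]

bos_token_id = [1] if add_bos_token else []
eos_token_id = [1] if add_eos_token else []

mask = (
    bos_token_id
    + ([0] * len(token_ids_0))
    + eos_token_id
    + bos_token_id
    + ([0] * len(token_ids_1))
    + eos_token_id
)
print(mask)  # [1, 0, 0, 0, 1, 1, 0, 0, 1]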
@@ -24,16 +24,25 @@ import torch
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
-    BackendBase, BaseChecker, LogitsProcessorBase)
+    BackendBase,
+    BaseChecker,
+    LogitsProcessorBase,
+)
from fastdeploy.utils import llm_logger

try:
-    from xgrammar import (CompiledGrammar, Grammar, GrammarCompiler,
-                          GrammarMatcher, StructuralTagItem, TokenizerInfo,
-                          allocate_token_bitmask, apply_token_bitmask_inplace)
+    from xgrammar import (
+        CompiledGrammar,
+        Grammar,
+        GrammarCompiler,
+        GrammarMatcher,
+        StructuralTagItem,
+        TokenizerInfo,
+        allocate_token_bitmask,
+        apply_token_bitmask_inplace,
+    )
except Exception as e:
-    raise Exception(
-        f"import XGrammar failed, please check your environment:\n\t {e}")
+    raise Exception(f"import XGrammar failed, please check your environment:\n\t {e}")


class XGrammarProcessor(LogitsProcessorBase):
@@ -88,8 +97,7 @@ class XGrammarProcessor(LogitsProcessorBase):
        """
        return allocate_token_bitmask(self.batch_size, self.vocab_size)

-    def fill_token_bitmask(self, token_bitmask: torch.Tensor,
-                           idx: int) -> None:
+    def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
        """
        Fill the token bitmask with allowed tokens for the given index.

@@ -155,8 +163,7 @@ class XGrammarProcessor(LogitsProcessorBase):
        Raises:
            AssertionError: If token is not allowed by the grammar
        """
-        assert self.matcher.accept_token(
-            token), f"Failed to accept token {token}"
+        assert self.matcher.accept_token(token), f"Failed to accept token {token}"

    def is_terminated(self) -> bool:
        """
@@ -212,10 +219,8 @@ class XGrammarBackend(BackendBase):
        self.splitwise_role = fd_config.parallel_config.splitwise_role

        try:
-            tokenizer_info = TokenizerInfo.from_huggingface(
-                self.hf_tokenizer, vocab_size=self.vocab_size)
-            self.grammar_compiler = GrammarCompiler(
-                tokenizer_info=tokenizer_info)
+            tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
+            self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
        except Exception as e:
            raise Exception(f"Failed to load XGrammar tokenizer: {e}")

@@ -256,8 +261,7 @@ class XGrammarBackend(BackendBase):
            Optional[XGrammarProcessor]: Configured processor if successful, None on failure
        """
        try:
-            compiled_grammar = self.grammar_compiler.compile_json_schema(
-                schemata, any_whitespace=self.any_whitespace)
+            compiled_grammar = self.grammar_compiler.compile_json_schema(schemata, any_whitespace=self.any_whitespace)
        except Exception as e:
            llm_logger.error(f"Failed to compile json schema: {e}")
            return None
@@ -297,8 +301,7 @@ class XGrammarBackend(BackendBase):
            return None
        return self._create_processor(compiled_grammar)

-    def _structural_tag_processor(
-            self, schemata: str) -> Optional[XGrammarProcessor]:
+    def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
        """
        Compile structural tags into a grammar processor.

@@ -315,11 +318,11 @@ class XGrammarBackend(BackendBase):
                    begin=structure["begin"],
                    schema=json.dumps(structure["schema"]),
                    end=structure["end"],
-                ) for structure in structural_tag["structures"]
+                )
+                for structure in structural_tag["structures"]
            ]

-            compiled_grammar = self.grammar_compiler.compile_structural_tag(
-                tags, structural_tag["triggers"])
+            compiled_grammar = self.grammar_compiler.compile_structural_tag(tags, structural_tag["triggers"])
        except Exception as e:
            llm_logger.error(f"Failed to compile structural tags schema: {e}")
            return None
@@ -357,22 +360,32 @@ class XGrammarChecker(BaseChecker):
        if not isinstance(obj, dict):
            return False

-        if obj.get("type") in ("integer", "number") and ("multipleOf"
-                                                         in obj):
+        if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
            return True

        if obj.get("type") == "array" and any(
-                key in obj for key in ("uniqueItems", "contains",
-                                       "minContains", "maxContains")):
+            key in obj
+            for key in (
+                "uniqueItems",
+                "contains",
+                "minContains",
+                "maxContains",
+            )
+        ):
            return True

        if obj.get("type") == "string" and "format" in obj:
            return True

        if obj.get("type") == "object" and any(
-                key in obj
-                for key in ("minProperties", "maxProperties",
-                            "propertyNames", "patternProperties")):
+            key in obj
+            for key in (
+                "minProperties",
+                "maxProperties",
+                "propertyNames",
+                "patternProperties",
+            )
+        ):
            return True

        for value in obj.values():
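For reference, the conditions above flag JSON Schema features the checker treats as unsupported. A few made-up schema fragments, each tripping one branch of the reformatted checks (illustrative values only, not taken from FastDeploy tests):

# Hypothetical schema fragments; each one matches one of the conditions above.
unsupported_examples = [
    {"type": "integer", "multipleOf": 5},                                  # multipleOf on a number
    {"type": "array", "items": {"type": "string"}, "uniqueItems": True},   # array constraints
    {"type": "string", "format": "date-time"},                             # string format
    {"type": "object", "minProperties": 2},                                # object property constraints
]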
@@ -398,10 +411,9 @@ class XGrammarChecker(BaseChecker):
            else:
                guided_json = request.guided_json

-            Grammar.from_json_schema(guided_json,
-                                     any_whitespace=self.any_whitespace)
+            Grammar.from_json_schema(guided_json, any_whitespace=self.any_whitespace)
        except RuntimeError as e:
-            err_msg = f"Invalid JSON format: {guided_json}, error message: {str(e)}"
+            err_msg = f"Invalid JSON format: {guided_json}, error message: {e!s}"
            return request, err_msg

        if self._unsupported_json_schema(guided_json):
@@ -416,7 +428,7 @@ class XGrammarChecker(BaseChecker):
        try:
            Grammar.from_ebnf(guided_grammar)
        except RuntimeError as e:
-            err_msg = f"Invalid grammar format: {guided_grammar}, error message: {str(e)}"
+            err_msg = f"Invalid grammar format: {guided_grammar}, error message: {e!s}"
            return request, err_msg
        request.guided_grammar = guided_grammar
        return request, None
@@ -425,14 +437,12 @@ class XGrammarChecker(BaseChecker):
            return request, None
        elif request.guided_choice:
            try:
-                escaped_choices = (re.sub(r'(["\\])', r'\\\1', c)
-                                   for c in request.guided_choice)
-                guided_choice = ('root ::= ' +
-                                 ' | '.join(f'"{c}"' for c in escaped_choices))
+                escaped_choices = (re.sub(r'(["\\])', r"\\\1", c) for c in request.guided_choice)
+                guided_choice = "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices)

                Grammar.from_ebnf(guided_choice)
            except RuntimeError as e:
-                err_msg = f"Invalid choice format: {guided_choice}, error message: {str(e)}"
+                err_msg = f"Invalid choice format: {guided_choice}, error message: {e!s}"
                return request, err_msg

            request.guided_grammar = guided_choice
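The reformatted lines above keep the same behaviour: every choice string is escaped and joined into a one-rule EBNF grammar. A self-contained sketch with made-up choices (the printed grammar mirrors what the code builds, not output captured from FastDeploy):

import re

# Hypothetical choices; the escaping and join mirror the two "+" lines above.
choices = ["yes", "no", 'say "maybe"']
escaped_choices = (re.sub(r'(["\\])', r"\\\1", c) for c in choices)
guided_choice = "root ::= " + " | ".join(f'"{c}"' for c in escaped_choices)
print(guided_choice)  # root ::= "yes" | "no" | "say \"maybe\""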
@@ -445,11 +455,12 @@ class XGrammarChecker(BaseChecker):
                    begin=s["begin"],
                    schema=json.dumps(s["schema"]),
                    end=s["end"],
-                ) for s in structural_tag["structures"]
+                )
+                for s in structural_tag["structures"]
            ]
            Grammar.from_structural_tag(tags, structural_tag["triggers"])
        except RuntimeError as e:
-            err_msg = f"Invalid structural_tag format: {structural_tag}, error message: {str(e)}"
+            err_msg = f"Invalid structural_tag format: {structural_tag}, error message: {e!s}"
            return request, err_msg
            return request, None
        else: