Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[Feature] Guided Decoding add LLguidance backend (#5124)
* llguidance
* add requirements_guided_decoding.txt and doc
* fix test_guidance_*.py
* fix test_guidance_*.py && mv
* fix llguidance choice
* test_guidance_*
* rm lazy loader

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
@@ -50,6 +50,15 @@ def get_guided_backend(
             fd_config=fd_config,
             **kwargs,
         )
+    elif fd_config.structured_outputs_config.guided_decoding_backend.lower() == "guidance":
+        from fastdeploy.model_executor.guided_decoding.guidance_backend import (
+            LLGuidanceBackend,
+        )
+
+        return LLGuidanceBackend(
+            fd_config=fd_config,
+            **kwargs,
+        )
     else:
         raise ValueError(
             f"Get unsupported backend {fd_config.structured_outputs_config.guided_decoding_backend},"
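For orientation, a minimal usage sketch of the new branch. This assumes the offline `LLM` entry point forwards `guided_decoding_backend` into `FDConfig.structured_outputs_config`; the model path is a placeholder:

    # Hedged sketch, not part of the commit: opting into the llguidance
    # backend. The value is lower-cased by get_guided_backend, so any
    # casing of "guidance" dispatches to LLGuidanceBackend.
    from fastdeploy import LLM

    llm = LLM(
        model="baidu/ERNIE-4.5-0.3B-Paddle",  # placeholder model path
        guided_decoding_backend="guidance",
    )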
@@ -77,5 +86,11 @@ def schema_checker(backend_name: str, **kwargs):
         )

         return XGrammarChecker(**kwargs)
+    elif backend_name.lower() == "guidance":
+        from fastdeploy.model_executor.guided_decoding.guidance_backend import (
+            LLGuidanceChecker,
+        )
+
+        return LLGuidanceChecker(**kwargs)
     else:
         raise ValueError(f"Get unsupported backend {backend_name}, please check your configuration.")
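Likewise for the checker dispatch; `disable_any_whitespace` is the one keyword the new `LLGuidanceChecker` reads from `**kwargs` (see the new file below):

    # Hedged sketch: fetch the schema checker for the new backend.
    checker = schema_checker("guidance", disable_any_whitespace=False)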
@@ -294,7 +294,12 @@ class BackendBase:
         """
         try:
             architectures = self.fd_config.model_config.architectures
-            if not ErnieArchitectures.contains_ernie_arch(architectures):
+            is_guidance_backend = (
+                self.fd_config.structured_outputs_config is not None
+                and self.fd_config.structured_outputs_config.guided_decoding_backend is not None
+                and self.fd_config.structured_outputs_config.guided_decoding_backend == "guidance"
+            )
+            if not ErnieArchitectures.contains_ernie_arch(architectures) or is_guidance_backend:
                 from transformers import AutoTokenizer, PreTrainedTokenizerFast

                 tokenizer = AutoTokenizer.from_pretrained(
fastdeploy/model_executor/guided_decoding/guidance_backend.py (new file, 314 lines)
@@ -0,0 +1,314 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import copy
import json
import traceback
from typing import Any, Optional, Tuple, Union

import llguidance
import llguidance.hf
import llguidance.torch
import torch

from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.envs import FD_GUIDANCE_DISABLE_ADDITIONAL, FD_LLGUIDANCE_LOG_LEVEL
from fastdeploy.model_executor.guided_decoding import (
    BackendBase,
    BaseChecker,
    LogitsProcessorBase,
)
from fastdeploy.utils import llm_logger


class LLGuidanceProcessor(LogitsProcessorBase):
    """
    LLGuidance-specific implementation of LogitsProcessorBase.

    This processor enforces grammar constraints during token generation using llguidance.
    It manages the grammar matching state and applies token masks to logits.
    """

    def __init__(
        self,
        ll_matcher: llguidance.LLMatcher,
        ll_tokenizer: llguidance.LLTokenizer,
        serialized_grammar: str,
        vocab_size: int,
        batch_size: int,
        enable_thinking: bool = False,
    ):
        super().__init__(enable_reasoning=enable_thinking)
        self.matcher = ll_matcher
        self.ll_tokenizer = ll_tokenizer
        self.serialized_grammar = serialized_grammar
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.is_terminated: bool = False
        self._printed_error: bool = False

    def _check_error(self):
        """Checks for and logs any errors from the LLMatcher."""
        if not self._printed_error:
            err = self.matcher.get_error()
            if err:
                self._printed_error = True
                llm_logger.warning(f"LLGuidance Matcher error: {err}")

    def allocate_token_bitmask(self) -> torch.Tensor:
        """
        Allocate a token bitmask tensor for grammar constraints.
        """
        return llguidance.torch.allocate_token_bitmask(self.batch_size, self.vocab_size)

    def fill_token_bitmask(self, token_bitmask: torch.Tensor, idx: int) -> None:
        """
        Fill the token bitmask with allowed tokens for the given index.
        This will automatically provide an EOS mask if the matcher is stopped.
        """
        llguidance.torch.fill_next_token_bitmask(self.matcher, token_bitmask, idx)
        self._check_error()

    def reset(self) -> None:
        """
        Reset the grammar matcher state to initial conditions.
        """
        self.matcher.reset()
        self.is_terminated = False
        self._printed_error = False
        self._check_error()

    def accept_token(self, token: int) -> bool:
        """
        Validate and accept a generated token against the grammar constraints.
        Returns True if the token is accepted, False otherwise.
        """
        if self.is_terminated:
            return False
        if self.ll_tokenizer.eos_token == token:
            self.is_terminated = True
            return True

        result = self.matcher.consume_tokens([token])
        self._check_error()

        return result
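A sketch of the per-step loop a worker could drive against this processor, under two assumptions: `apply_token_bitmask_inplace` exists in `llguidance.torch` alongside the two helpers the class already wraps, and greedy argmax stands in for real sampling:

    import torch
    import llguidance.torch

    def constrained_step(processor: LLGuidanceProcessor, logits: torch.Tensor) -> int:
        """One masked decode step for a single row of logits (illustrative only)."""
        bitmask = processor.allocate_token_bitmask()  # (batch, ceil(vocab/32)) int32
        processor.fill_token_bitmask(bitmask, idx=0)  # allowed tokens for row 0
        # Assumed helper: writes -inf into logits of disallowed tokens, in place.
        llguidance.torch.apply_token_bitmask_inplace(logits, bitmask[:1])
        token = int(torch.argmax(logits[0]).item())
        if not processor.accept_token(token):  # advances the matcher state
            raise RuntimeError("grammar rejected the sampled token")
        return token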
class LLGuidanceBackend(BackendBase):
    """
    LLGuidance-specific implementation of BackendBase.

    This backend handles the compilation of various schema types (JSON, regex, etc.)
    into LLGuidance processors.
    """

    def __init__(self, fd_config: FDConfig, **kwargs):
        super().__init__(fd_config=fd_config)
        self.vocab_size = fd_config.model_config.vocab_size
        self.batch_size = fd_config.scheduler_config.max_num_seqs
        self.any_whitespace = not fd_config.structured_outputs_config.disable_any_whitespace

        llm_logger.info(f"LLGuidanceBackend vocab_size={self.vocab_size} batch_size={self.batch_size}")
        try:
            self.ll_tokenizer = llguidance.hf.from_tokenizer(self.hf_tokenizer, self.vocab_size)
        except Exception as e:
            raise RuntimeError(
                f"Failed to initialize llguidance tokenizer from HuggingFace tokenizer: {e} {traceback.format_exc()}"
            )

    def _create_processor(
        self,
        compiled_grammar: str,
        enable_thinking: bool = False,
    ) -> Optional[LLGuidanceProcessor]:
        """
        Create a logits processor instance for the given compiled grammar.
        """
        try:
            ll_matcher = llguidance.LLMatcher(
                self.ll_tokenizer,
                compiled_grammar,
                log_level=FD_LLGUIDANCE_LOG_LEVEL,
            )

            return LLGuidanceProcessor(
                ll_matcher=ll_matcher,
                ll_tokenizer=self.ll_tokenizer,
                serialized_grammar=compiled_grammar,
                vocab_size=self.vocab_size,
                batch_size=self.batch_size,
                enable_thinking=enable_thinking,
            )
        except Exception as e:
            llm_logger.error(f"Failed to create llguidance processor: {e}, {traceback.format_exc()}")
            return None

    def _json_processor(self, compiled_grammar: str, enable_thinking: bool = False) -> Optional[LLGuidanceProcessor]:
        return self._create_processor(compiled_grammar, enable_thinking)

    def _regex_processor(self, compiled_grammar: str, enable_thinking: bool = False) -> Optional[LLGuidanceProcessor]:
        return self._create_processor(compiled_grammar, enable_thinking)

    def _grammar_processor(
        self, compiled_grammar: str, enable_thinking: bool = False
    ) -> Optional[LLGuidanceProcessor]:
        return self._create_processor(compiled_grammar, enable_thinking)

    def _structural_tag_processor(
        self, compiled_grammar: str, enable_thinking: bool = False
    ) -> Optional[LLGuidanceProcessor]:
        return self._create_processor(compiled_grammar, enable_thinking)
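The backend above only wires llguidance into FastDeploy's plumbing; the library calls themselves can be exercised standalone. A sketch using the same APIs the class calls, with gpt2 as an arbitrary stand-in tokenizer and `None` assumed acceptable for the vocab-size argument:

    import llguidance
    import llguidance.hf
    from transformers import AutoTokenizer

    hf_tok = AutoTokenizer.from_pretrained("gpt2")       # arbitrary stand-in
    ll_tok = llguidance.hf.from_tokenizer(hf_tok, None)  # None: infer vocab size

    grammar = llguidance.grammar_from("regex", r"[0-9]{3}-[0-9]{4}")
    err = llguidance.LLMatcher.validate_grammar(grammar, None)
    assert not err, err

    matcher = llguidance.LLMatcher(ll_tok, grammar, log_level=1)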
def _walk_json_for_additional_properties(data: object):
    if isinstance(data, dict):
        for value in data.values():
            _walk_json_for_additional_properties(value)
        if "additionalProperties" not in data and ("properties" in data or "patternProperties" in data):
            data["additionalProperties"] = False
    elif isinstance(data, list):
        for item in data:
            _walk_json_for_additional_properties(item)


def process_for_additional_properties(guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
    if isinstance(guide_json, str):
        guide_json_obj = json.loads(guide_json)
    else:
        # copy for modifications
        guide_json_obj = copy.deepcopy(guide_json)
    _walk_json_for_additional_properties(guide_json_obj)
    return guide_json_obj
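A worked example of the helper pair with a hypothetical schema. Nested objects are closed off too, because the walk recurses into values before checking each dict, and the input is left untouched thanks to the deepcopy:

    schema = {
        "type": "object",
        "properties": {
            "user": {"type": "object", "properties": {"id": {"type": "integer"}}},
        },
    }
    closed = process_for_additional_properties(schema)

    assert closed["additionalProperties"] is False
    assert closed["properties"]["user"]["additionalProperties"] is False
    assert "additionalProperties" not in schema  # original is unmodified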
class LLGuidanceChecker(BaseChecker):
    """
    LLGuidance-specific implementation of BaseChecker.

    This checker validates various schema types for compatibility with the
    llguidance library before processing.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Although the backend handles serialization, we can perform a quick
        # static check here without a full tokenizer.
        self.any_whitespace = not kwargs.get("disable_any_whitespace", False)
        self.disable_additional_properties = FD_GUIDANCE_DISABLE_ADDITIONAL
        """If `True`, the `guidance` backend will not use `additionalProperties`
        in the JSON schema. This is only supported for the `guidance` backend and
        is used to better align its behaviour with `outlines` and `xgrammar`."""

    def serialize_guidance_grammar(self, request: Request):
        def _process_schema(
            grammar_spec: Union[str, dict[str, Any]],
        ) -> str:
            if self.disable_additional_properties:
                grammar_spec = process_for_additional_properties(grammar_spec)
            return llguidance.LLMatcher.grammar_from_json_schema(
                grammar_spec,
                defaults={
                    "whitespace_flexible": self.any_whitespace,
                },
            )

        if request.guided_json:
            if isinstance(request.guided_json, dict):
                guided_json = json.dumps(request.guided_json)
            else:
                guided_json = request.guided_json
            return _process_schema(guided_json)
        elif request.guided_json_object:
            return llguidance.LLMatcher.grammar_from_json_schema(
                '{"type": "object"}',
                defaults={
                    "whitespace_flexible": self.any_whitespace,
                },
            )

        if request.structural_tag:
            if isinstance(request.structural_tag, str):
                s_tag = json.loads(request.structural_tag)
            else:
                s_tag = request.structural_tag
            triggers: list[str] = s_tag["triggers"]
            tags: list[llguidance.StructTag] = []
            for s in s_tag["structures"]:
                begin: str = s["begin"]
                trig = next((t for t in triggers if begin.startswith(t)), None)
                if trig is None:
                    raise ValueError(f"Trigger {begin} not found in triggers {triggers}")
                tags.append(
                    llguidance.StructTag(
                        trigger=trig,
                        begin=s["begin"],
                        grammar=_process_schema(s["schema"]),
                        end=s["end"],
                    )
                )
            if not tags:
                raise ValueError("No structural tags found in the grammar spec.")
            return llguidance.StructTag.to_grammar(tags)

        if request.guided_regex:
            tp = "regex"
            grammar_spec = request.guided_regex
        elif request.guided_choice:
            tp = "choice"
            grammar_spec = request.guided_choice
        elif request.guided_grammar:
            tp = "grammar"
            grammar_spec = request.guided_grammar
        else:
            llm_logger.error("Validation should have already occurred. Please file an issue.")
            raise ValueError("Grammar is not of a valid supported type.")
        return llguidance.grammar_from(tp, grammar_spec)

    def schema_format(self, request: Request) -> Tuple[Request, Optional[str]]:
        """
        Validates and formats the schema for the LLGuidance backend.
        """
        try:
            guidance_grm = self.serialize_guidance_grammar(request)
            err = llguidance.LLMatcher.validate_grammar(guidance_grm, None)
            if err:
                raise ValueError(f"Grammar error: {err}")
            else:
                llm_logger.info(f"valid schema_format {guidance_grm} {request}")
            if request.guided_regex:
                request.guided_regex = guidance_grm
            elif request.guided_choice:
                request.guided_grammar = guidance_grm
                request.guided_choice = None
            elif request.guided_grammar:
                request.guided_grammar = guidance_grm
            elif request.guided_json:
                request.guided_json = guidance_grm

        except (ValueError, TypeError, json.JSONDecodeError) as e:
            err_msg = f"Invalid format for guided decoding: {e!s} request={request}"
            return request, err_msg

        except Exception as e:
            err_msg = f"An unexpected error occurred during schema validation: {e!s}"
            return request, err_msg

        return request, None
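Putting the checker together, a hedged round trip. The `Request` stand-in is hypothetical and carries only the attributes `schema_format` reads; the `guided_choice` list goes straight through `llguidance.grammar_from` and, on success, is moved onto `guided_grammar` exactly as the elif branch above does:

    from types import SimpleNamespace

    # Hypothetical stand-in for fastdeploy.engine.request.Request.
    req = SimpleNamespace(
        guided_json=None,
        guided_json_object=None,
        structural_tag=None,
        guided_regex=None,
        guided_choice=["yes", "no", "maybe"],
        guided_grammar=None,
    )

    req, err = LLGuidanceChecker().schema_format(req)
    assert err is None
    assert req.guided_choice is None and req.guided_grammar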
@@ -73,7 +73,6 @@ class XGrammarProcessor(LogitsProcessorBase):
        enable_thinking: bool = False,
    ):
        super().__init__(enable_reasoning=enable_thinking)
        self.max_rollback_tokens = 200
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.compiled_grammar = compiled_grammar
@@ -82,7 +81,6 @@ class XGrammarProcessor(LogitsProcessorBase):

        self.matcher = GrammarMatcher(
            compiled_grammar=compiled_grammar,
            max_rollback_tokens=self.max_rollback_tokens,
            terminate_without_stop_token=terminate_without_stop_token,
            override_stop_tokens=override_stop_tokens,
        )