Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[Feature] mm and thinking model support structured output (#2749)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* mm support structured output
* update code
* update code
* update format
* update code
* update code
* add enable_thinking default
* update code
* add structured_outputs test case
* add ci install xgrammar
* add ci timeout time
* update test for structured_outputs
* update code
* add error traceback info
* update error msg
* update structured output code
* update code
* update code
* update config
* update torch version

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -15,8 +15,13 @@
"""

# from fastdeploy.config import FDConfig
from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
    BackendBase,
    BaseChecker,
    LogitsProcessorBase,
)

-__all__ = ["get_guided_backend", "schema_checker"]
+__all__ = ["get_guided_backend", "schema_checker", "LogitsProcessorBase", "BackendBase", "BaseChecker"]


def get_guided_backend(
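The widened `__all__` is what lets downstream modules import the guided-decoding base classes from the package root rather than from `base_guided_decoding`; the xgrammar backend hunk further down switches to exactly that import. A one-line illustration:

from fastdeploy.model_executor.guided_decoding import BackendBase, BaseChecker, LogitsProcessorBase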
@@ -20,6 +20,7 @@ from concurrent.futures import ThreadPoolExecutor

from fastdeploy.config import ErnieArchitectures, FDConfig
from fastdeploy.engine.request import Request
from fastdeploy.reasoning import ReasoningParserManager
from fastdeploy.utils import llm_logger

@@ -35,8 +36,9 @@ class LogitsProcessorBase:
        None (all state should be managed by subclasses)
    """

-    def __init__(self):
-        pass
+    def __init__(self, enable_reasoning):
+        self.reasoning_ended = False
+        self.enable_reasoning = enable_reasoning

    def fill_token_bitmask(self, token_bitmask, idx):
        """
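The new constructor state is what lets grammar enforcement coexist with thinking models: while `enable_reasoning` is set and `reasoning_ended` is still False, a processor can leave the logits untouched, and only start masking once the reasoning segment has finished. A minimal sketch of a subclass using that state (illustrative only; the `accept_token` hook and the end-of-thinking token id are assumptions, not part of this diff):

class ThinkingAwareProcessor(LogitsProcessorBase):
    """Sketch: constrain decoding only after the thinking segment ends."""

    def __init__(self, enable_reasoning, think_end_token_id):
        super().__init__(enable_reasoning)
        self.think_end_token_id = think_end_token_id  # hypothetical "</think>" token id

    def fill_token_bitmask(self, token_bitmask, idx):
        # While the model is still reasoning, keep the full vocabulary available.
        if self.enable_reasoning and not self.reasoning_ended:
            return
        # ... otherwise fill token_bitmask from the compiled grammar ...

    def accept_token(self, token_id):
        # Assumed hook: flip the flag once the end-of-thinking token appears.
        if self.enable_reasoning and token_id == self.think_end_token_id:
            self.reasoning_ended = True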
@@ -137,8 +139,14 @@ class BackendBase:
        self.fd_config = fd_config
        self.executor = ThreadPoolExecutor()
        self.max_cache_size = 2048
+        self.reasoning_parser = None

        self.hf_tokenizer = self._get_tokenizer_hf()
+        if self.fd_config.model_config.reasoning_parser:
+            reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(
+                self.fd_config.model_config.reasoning_parser
+            )
+            self.reasoning_parser = reasoning_parser_obj(self.hf_tokenizer)

    def _create_processor(self):
        """
@@ -149,70 +157,88 @@ class BackendBase:
        """
        raise NotImplementedError

-    def _json_processor(self, schemata):
+    def _json_processor(self, schemata, enable_thinking=False):
        """
        Process JSON schemata.

        Args:
            schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
        raise NotImplementedError

-    def _regex_processor(self, schemata):
+    def _regex_processor(self, schemata, enable_thinking=False):
        """
        Process regular expression schemata.

        Args:
            schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
        raise NotImplementedError

-    def _grammar_processor(self, schemata):
+    def _grammar_processor(self, schemata, enable_thinking=False):
        """
        Process grammar schemata.

        Args:
            schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
        raise NotImplementedError

-    def _structural_tag_processor(self, schemata):
+    def _structural_tag_processor(self, schemata, enable_thinking=False):
        """
        Process structural tag schemata.

        Args:
            schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.

        Raises:
            NotImplementedError: This method should be implemented in subclasses.
        """
        raise NotImplementedError

-    def _unsupported_processor_type(self, key_type, schemata):
+    def _unsupported_processor_type(self, key_type, schemata, enable_thinking=False):
        """
        Process unsupported type.

        Args:
            key_type (str): The key type string.
            schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.
        """
        raise Exception(f"Unsupported processor type {key_type}.")

-    def _init_logits_processor(self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
+    def get_reasoning_parser(self):
+        """
+        Get reasoning parser object.
+        Returns:
+            ReasoningParser: Reasoning parser object or None
+        """
+        return self.reasoning_parser
+
+    def _init_logits_processor(
+        self,
+        schemata_key: tuple[str, str],
+        enable_thinking: bool = False,
+    ) -> LogitsProcessorBase:
        """
        init logits processor by type and schemata.

        Args:
            schemata_key (tuple[str, str]): Tuple containing processor type and schema string
+            enable_thinking (bool): Whether to enable thinking step

        Returns:
            LogitsProcessorBase: Initialized logits processor instance
@@ -222,18 +248,22 @@ class BackendBase:
        """
        key_type, schemata = schemata_key
        if key_type == "json":
-            return self._json_processor(schemata)
+            return self._json_processor(schemata, enable_thinking)
        elif key_type == "regex":
-            return self._regex_processor(schemata)
+            return self._regex_processor(schemata, enable_thinking)
        elif key_type == "grammar":
-            return self._grammar_processor(schemata)
+            return self._grammar_processor(schemata, enable_thinking)
        elif key_type == "structural_tag":
-            return self._structural_tag_processor(schemata)
+            return self._structural_tag_processor(schemata, enable_thinking)
        else:
            llm_logger.error(f"Unsupported processor type {key_type}.")
            return None

-    def get_logits_processor(self, schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
+    def get_logits_processor(
+        self,
+        schemata_key: tuple[str, str],
+        enable_thinking: bool = False,
+    ) -> tuple[LogitsProcessorBase, bool]:
        """
        get logits processor by key from cache or create new one.

@@ -247,8 +277,10 @@ class BackendBase:
        """
        value = self.cache.get(schemata_key, None)
        if value:
-            return value.copy(), True
-        value = self.executor.submit(self._init_logits_processor, schemata_key)
+            value_copy = value.copy()
+            value_copy.enable_reasoning = enable_thinking
+            return value_copy, True
+        value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
        return value, False

    def _get_tokenizer_hf(self):
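For context on how callers are expected to consume this: on a cache hit the backend hands back a ready-to-use processor copy with `enable_reasoning` already set, while on a miss it hands back the `concurrent.futures.Future` from the ThreadPoolExecutor together with `False`, so the caller decides whether to block. A hedged sketch of that calling pattern (variable names are illustrative, not from this diff):

processor, cached = backend.get_logits_processor(("json", schema_str), enable_thinking=True)
if not cached:
    # Cache miss: grammar compilation was submitted to the backend's executor,
    # so wait for the Future before using (and optionally caching) the result.
    processor = processor.result()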
@@ -267,7 +299,6 @@ class BackendBase:
        try:
            architectures = self.fd_config.model_config.architectures
            if not ErnieArchitectures.contains_ernie_arch(architectures):

                from transformers import AutoTokenizer, PreTrainedTokenizerFast

                tokenizer = AutoTokenizer.from_pretrained(
@@ -0,0 +1,118 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

# refer to https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py

from typing import List, Optional

import paddle

try:
    import triton
    import triton.language as tl
except ImportError as err:
    raise ImportError("Triton is not installed") from err


@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel for in-place logits masking using bitwise compression.

    Processes logits tensor in blocks, applying bitmask to restrict vocabulary access.
    Masked positions are set to -inf to ensure zero probability during sampling.

    Note:
        - Bitmask uses 32:1 compression (1 bit per vocabulary token)
        - Optimized for GPU parallel processing with configurable block size
    """
    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, packed_bitmask_mask)
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        bitmask = bitmask.reshape(BLOCK_SIZE)

        tl.store(logits_ptr + batch_id * logits_strides + offsets, -float("inf"), vocab_mask & bitmask)


def apply_token_bitmask_inplace_triton(
    logits: paddle.Tensor,
    bitmask: paddle.Tensor,
    vocab_size: Optional[int] = None,
    indices: Optional[List[int]] = None,
):
    """Applies vocabulary mask to logits tensor using Triton GPU kernel.

    Args:
        logits: Input logits tensor of shape [batch_size, vocab_size]
        bitmask: Compressed mask tensor (int32) where each bit represents a token
        vocab_size: Optional explicit vocabulary size (defaults to auto-detected)
        indices: Optional list of batch indices to apply mask to

    Note:
        Requires CUDA GPU with Triton support
        Bitmask must be int32 tensor with shape [batch_size, ceil(vocab_size/32)]
    """
    NUM_SMS = paddle.device.cuda.get_device_properties().multi_processor_count
    BLOCK_SIZE = 4096

    assert bitmask.dtype == paddle.int32, "bitmask must be of type int32"

    detected_vocab_size = min(logits.shape[-1], bitmask.shape[-1] * 32)
    if vocab_size is None:
        vocab_size = detected_vocab_size
    else:
        assert (
            vocab_size <= detected_vocab_size
        ), f"vocab_size {vocab_size} is larger than the detected vocab_size {detected_vocab_size}"

    num_rows = len(indices) if indices is not None else logits.shape[0] if logits.ndim == 2 else 1

    if indices is not None:
        indices = paddle.to_tensor(indices, dtype=paddle.int32, place=logits.place)

    grid = (NUM_SMS,)

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits.shape[-1],
        bitmask.shape[-1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )
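A small usage sketch for the new kernel wrapper, assuming a CUDA device with Triton available and that the function is importable from the new module (exact module path is not shown in this hunk). Bit i of `bitmask[row, i // 32]` keeps token i; cleared bits are pushed to -inf in place:

import math

import numpy as np
import paddle

# from <new kernel module> import apply_token_bitmask_inplace_triton  # import path assumed

batch_size, vocab_size = 2, 1000
logits = paddle.randn([batch_size, vocab_size], dtype="float32")

# Pack the allowed-token sets into the 32:1 bitmask (uint32 math, then reinterpret as int32).
allowed = [[1, 7, 42], [0, 999]]
packed = np.zeros((batch_size, math.ceil(vocab_size / 32)), dtype=np.uint32)
for row, tokens in enumerate(allowed):
    for tok in tokens:
        packed[row, tok // 32] |= np.uint32(1) << np.uint32(tok % 32)
bitmask = paddle.to_tensor(packed.view(np.int32))

apply_token_bitmask_inplace_triton(logits, bitmask)
# Every token whose bit is 0 now holds logit -inf, so sampling can only choose
# tokens the grammar currently allows.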
@@ -24,7 +24,7 @@ import torch

from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
-from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+from fastdeploy.model_executor.guided_decoding import (
    BackendBase,
    BaseChecker,
    LogitsProcessorBase,
@@ -57,7 +57,6 @@ class XGrammarProcessor(LogitsProcessorBase):
        max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch
        vocab_size (int): Size of the vocabulary
        batch_size (int): Batch size for processing
        splitwise_role (str): Role for splitwise processing
        compiled_grammar (CompiledGrammar): Compiled grammar rules
        terminate_without_stop_token (bool): Whether to terminate without stop token
        override_stop_tokens (Optional[List[int]]): Custom stop tokens
@@ -71,13 +70,12 @@ class XGrammarProcessor(LogitsProcessorBase):
        override_stop_tokens: Optional[List[int]] = None,
        vocab_size: Optional[int] = None,
        batch_size: Optional[int] = None,
        splitwise_role: str = "mixed",
+        enable_thinking: bool = False,
    ):
-        super().__init__()
+        super().__init__(enable_reasoning=enable_thinking)
        self.max_rollback_tokens = 200
        self.vocab_size = vocab_size
        self.batch_size = batch_size
        self.splitwise_role = splitwise_role
        self.compiled_grammar = compiled_grammar
        self.terminate_without_stop_token = terminate_without_stop_token
        self.override_stop_tokens = override_stop_tokens
@@ -188,7 +186,6 @@ class XGrammarProcessor(LogitsProcessorBase):
            override_stop_tokens=self.override_stop_tokens,
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
            splitwise_role=self.splitwise_role,
        )


@@ -203,7 +200,6 @@ class XGrammarBackend(BackendBase):
        vocab_size (int): Size of the vocabulary from config
        batch_size (int): Maximum batch size from config
        any_whitespace (bool): Whether to allow any whitespace in JSON
        splitwise_role (str): Role for splitwise processing
        grammar_compiler (GrammarCompiler): Grammar compilation engine
    """

@@ -217,7 +213,6 @@ class XGrammarBackend(BackendBase):
        self.batch_size = fd_config.parallel_config.max_num_seqs

        self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
        self.splitwise_role = fd_config.parallel_config.splitwise_role

        try:
            tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
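The `TokenizerInfo.from_huggingface(...)` call feeds xgrammar's grammar compiler, which is what the `_json_processor` / `_regex_processor` / `_grammar_processor` overrides below rely on. A hedged sketch of the compilation step this sets up (the `GrammarCompiler` construction itself sits in elided lines; argument names follow xgrammar's public API as I understand it, and the schema string is illustrative):

from xgrammar import GrammarCompiler, TokenizerInfo

tokenizer_info = TokenizerInfo.from_huggingface(hf_tokenizer, vocab_size=vocab_size)
grammar_compiler = GrammarCompiler(tokenizer_info)

# A JSON-schema compile of the kind _json_processor performs; any_whitespace mirrors
# the backend's self.any_whitespace flag.
compiled = grammar_compiler.compile_json_schema(
    '{"type": "object", "properties": {"answer": {"type": "string"}}}',
    any_whitespace=True,
)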
@@ -230,6 +225,7 @@ class XGrammarBackend(BackendBase):
        compiled_grammar: CompiledGrammar,
        terminate_without_stop_token: bool = False,
        override_stop_tokens: Optional[List[int]] = None,
+        enable_thinking: bool = False,
    ) -> XGrammarProcessor:
        """
        Create a logits processor instance for the given compiled grammar.
@@ -238,6 +234,7 @@ class XGrammarBackend(BackendBase):
            compiled_grammar (CompiledGrammar): Compiled grammar rules
            terminate_without_stop_token (bool): Whether to terminate without stop token
            override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults
+            enable_thinking (bool): Whether to enable thinking mode

        Returns:
            XGrammarProcessor: Configured grammar processor instance
@@ -248,15 +245,16 @@ class XGrammarBackend(BackendBase):
            override_stop_tokens=override_stop_tokens,
            vocab_size=self.vocab_size,
            batch_size=self.batch_size,
            splitwise_role=self.splitwise_role,
+            enable_thinking=enable_thinking,
        )

-    def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+    def _json_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
        """
        Compile JSON schema into a grammar processor.

        Args:
            schemata (str): JSON schema string to compile
+            enable_thinking (bool): Whether to enable thinking mode

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -266,14 +264,15 @@ class XGrammarBackend(BackendBase):
        except Exception as e:
            llm_logger.error(f"Failed to compile json schema: {e}, {str(traceback.format_exc())}")
            return None
-        return self._create_processor(compiled_grammar)
+        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)

-    def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+    def _regex_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
        """
        Compile regex pattern into a grammar processor.

        Args:
            schemata (str): Regex pattern string to compile
+            enable_thinking (bool): Whether to enable thinking mode

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -283,14 +282,15 @@ class XGrammarBackend(BackendBase):
        except Exception as e:
            llm_logger.error(f"Failed to compile regex schema: {e}, {str(traceback.format_exc())}")
            return None
-        return self._create_processor(compiled_grammar)
+        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)

-    def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+    def _grammar_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
        """
        Compile grammar (EBNF) into a grammar processor.

        Args:
            schemata (str): Grammar string in EBNF format
+            enable_thinking (bool): Whether to enable thinking mode

        Returns:
            Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -300,9 +300,9 @@ class XGrammarBackend(BackendBase):
        except Exception as e:
            llm_logger.error(f"Failed to compile ebnf schema: {e}, {str(traceback.format_exc())}")
            return None
-        return self._create_processor(compiled_grammar)
+        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)

-    def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+    def _structural_tag_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
        """
        Compile structural tags into a grammar processor.

@@ -327,7 +327,7 @@ class XGrammarBackend(BackendBase):
        except Exception as e:
            llm_logger.error(f"Failed to compile structural tags schema: {e}, {str(traceback.format_exc())}")
            return None
-        return self._create_processor(compiled_grammar)
+        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)


class XGrammarChecker(BaseChecker):