mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.engine.request import Request
|
||||
from fastdeploy.utils import llm_logger
|
||||
|
||||
|
||||
class LogitsProcessorBase:
|
||||
"""
|
||||
Abstract base class for logits processors in guided decoding.
|
||||
|
||||
This class defines the interface for logits processors that modify token probabilities
|
||||
during generation to enforce schema constraints. Subclasses should implement all
|
||||
abstract methods to provide specific constraint enforcement logic.
|
||||
|
||||
Attributes:
|
||||
None (all state should be managed by subclasses)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def fill_token_bitmask(self, token_bitmask, idx):
|
||||
"""
|
||||
Fill the vocabulary mask.
|
||||
|
||||
Args:
|
||||
token_bitmask (tensor): The vocabulary mask tensor.
|
||||
idx (tensor): The tensor index.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_token_mask(self, logits, token_bitmask):
|
||||
"""
|
||||
Apply the vocabulary mask to logits.
|
||||
|
||||
Args:
|
||||
logits (tensor): The logits tensor.
|
||||
token_bitmask (tensor): The vocabulary mask tensor.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def allocate_token_bitmask(self, batch_size, vocab_size):
|
||||
"""
|
||||
Allocate a token bitmask for the given batch size and vocabulary size.
|
||||
|
||||
Args:
|
||||
batch_size (int): The batch size.
|
||||
vocab_size (int): The vocabulary size.
|
||||
|
||||
Returns:
|
||||
tensor: The allocated token bitmask.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def accept_token(self, token):
|
||||
"""
|
||||
Accept tokens based on the token bitmask
|
||||
|
||||
Args:
|
||||
token (int): The token id.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_terminated(self):
|
||||
"""
|
||||
Check if the processor has been terminated.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the matcher state.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def copy(self):
|
||||
"""
|
||||
Create a copy of the backend instance.
|
||||
|
||||
Returns:
|
||||
BackendBase: A copy of the backend instance.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class BackendBase:
|
||||
"""
|
||||
Abstract base class for guided decoding backends.
|
||||
|
||||
This class provides the core infrastructure for managing schema processors and
|
||||
their caching. It handles:
|
||||
- Processor creation and caching
|
||||
- Tokenizer initialization
|
||||
- Thread pool management for async operations
|
||||
|
||||
Attributes:
|
||||
cache (dict): Cache of schema processors
|
||||
fd_config (FDConfig): FastDeploy configuration
|
||||
executor (ThreadPoolExecutor): Thread pool for async operations
|
||||
max_cache_size (int): Maximum number of processors to cache
|
||||
hf_tokenizer: HuggingFace tokenizer instance
|
||||
"""
|
||||
|
||||
def __init__(self, fd_config: FDConfig):
|
||||
self.cache = {}
|
||||
self.fd_config = fd_config
|
||||
self.executor = ThreadPoolExecutor()
|
||||
self.max_cache_size = 2048
|
||||
|
||||
self.hf_tokenizer = self._get_tokenizer_hf()
|
||||
|
||||
def _create_processor(self):
|
||||
"""
|
||||
Create a specific logits processor instance.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _json_processor(self, schemata):
|
||||
"""
|
||||
Process JSON schemata.
|
||||
|
||||
Args:
|
||||
schemata (str): The schemata string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _regex_processor(self, schemata):
|
||||
"""
|
||||
Process regular expression schemata.
|
||||
|
||||
Args:
|
||||
schemata (str): The schemata string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _grammar_processor(self, schemata):
|
||||
"""
|
||||
Process grammar schemata.
|
||||
|
||||
Args:
|
||||
schemata (str): The schemata string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _structural_tag_processor(self, schemata):
|
||||
"""
|
||||
Process structural tag schemata.
|
||||
|
||||
Args:
|
||||
schemata (str): The schemata string.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method should be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _unsupported_processor_type(self, key_type, schemata):
|
||||
"""
|
||||
Process unsupported type.
|
||||
|
||||
Args:
|
||||
key_type (str): The key type string.
|
||||
schemata (str): The schemata string.
|
||||
"""
|
||||
raise Exception(f"Unsupported processor type {key_type}.")
|
||||
|
||||
def _init_logits_processor(
|
||||
self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
|
||||
"""
|
||||
init logits processor by type and schemata.
|
||||
|
||||
Args:
|
||||
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
|
||||
|
||||
Returns:
|
||||
LogitsProcessorBase: Initialized logits processor instance
|
||||
|
||||
Raises:
|
||||
ValueError: If processor type is not supported
|
||||
"""
|
||||
key_type, schemata = schemata_key
|
||||
if key_type == "json":
|
||||
return self._json_processor(schemata)
|
||||
elif key_type == "regex":
|
||||
return self._regex_processor(schemata)
|
||||
elif key_type == "grammar":
|
||||
return self._grammar_processor(schemata)
|
||||
elif key_type == "structural_tag":
|
||||
return self._structural_tag_processor(schemata)
|
||||
else:
|
||||
llm_logger.error(f"Unsupported processor type {key_type}.")
|
||||
return None
|
||||
|
||||
def get_logits_processor(
|
||||
self,
|
||||
schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
|
||||
"""
|
||||
get logits processor by key from cache or create new one.
|
||||
|
||||
Args:
|
||||
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
|
||||
|
||||
Returns:
|
||||
tuple[LogitsProcessorBase, bool]: Tuple containing:
|
||||
- LogitsProcessorBase: The logits processor instance
|
||||
- bool: True if processor was from cache, False if newly created
|
||||
"""
|
||||
value = self.cache.get(schemata_key, None)
|
||||
if value:
|
||||
return value.copy(), True
|
||||
value = self.executor.submit(self._init_logits_processor, schemata_key)
|
||||
return value, False
|
||||
|
||||
def _get_tokenizer_hf(self):
|
||||
"""
|
||||
Initialize and return a HuggingFace tokenizer instance.
|
||||
|
||||
This method handles special cases for Ernie models and falls back to standard
|
||||
AutoTokenizer for other models. It also ensures fast tokenizer is used when possible.
|
||||
|
||||
Returns:
|
||||
Tokenizer: Initialized HuggingFace tokenizer instance
|
||||
|
||||
Raises:
|
||||
Exception: If tokenizer initialization fails
|
||||
"""
|
||||
try:
|
||||
architectures = self.fd_config.model_config.architectures
|
||||
if "Ernie4_5_MoeForCausalLM" not in architectures \
|
||||
and "Ernie4_5_ForCausalLM" not in architectures:
|
||||
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerFast
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.fd_config.parallel_config.model_name_or_path,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
__slow_tokenizer=tokenizer)
|
||||
else:
|
||||
from fastdeploy.model_executor.guided_decoding.ernie_tokenizer import \
|
||||
ErnieBotTokenizer
|
||||
|
||||
vocab_file_names = [
|
||||
"tokenizer.model", "spm.model", "ernie_token_100k.model"
|
||||
]
|
||||
for i in range(len(vocab_file_names)):
|
||||
if os.path.exists(
|
||||
os.path.join(
|
||||
self.fd_config.parallel_config.
|
||||
model_name_or_path, vocab_file_names[i])):
|
||||
ErnieBotTokenizer.vocab_files_names[
|
||||
"vocab_file"] = vocab_file_names[i]
|
||||
break
|
||||
|
||||
tokenizer = ErnieBotTokenizer.from_pretrained(
|
||||
self.fd_config.parallel_config.model_name_or_path)
|
||||
|
||||
return tokenizer
|
||||
except Exception as e:
|
||||
raise Exception(f"Fail to initialize hf tokenizer: {e}")
|
||||
|
||||
def add_cache(self, schemata_key: tuple[str, str],
|
||||
processor: LogitsProcessorBase) -> None:
|
||||
"""
|
||||
add logits processor to cache.
|
||||
|
||||
Args:
|
||||
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
|
||||
processor (LogitsProcessorBase): Logits processor instance to cache
|
||||
|
||||
Returns:
|
||||
None: No return value
|
||||
"""
|
||||
if len(self.cache) >= self.max_cache_size:
|
||||
return
|
||||
self.cache[schemata_key] = processor.copy()
|
||||
|
||||
|
||||
class BaseChecker:
|
||||
"""
|
||||
Abstract base class for schema checkers.
|
||||
|
||||
This class defines the interface for validating and formatting schemas
|
||||
before they are used by logits processors. Subclasses should implement
|
||||
schema-specific validation and formatting logic.
|
||||
|
||||
Attributes:
|
||||
None (all state should be managed by subclasses)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def schema_format(self, request: Request):
|
||||
"""
|
||||
format schema to backend specific format.
|
||||
Args:
|
||||
request (Request): request object.
|
||||
|
||||
Returns:
|
||||
request (Request): request object with formatted schema.
|
||||
"""
|
||||
raise NotImplementedError()
|
Reference in New Issue
Block a user