diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md
index 40e177c1c..f7ee424cb 100644
--- a/docs/features/structured_outputs.md
+++ b/docs/features/structured_outputs.md
@@ -330,3 +330,65 @@ ParsedChatCompletionMessage[Info](content='{"addr": "No.1 Century Avenue, Pudong
Address: No.1 Century Avenue, Pudong New Area, Shanghai
Height: 468
```
+
+### Offline Inference
+
+Offline inference allows restricting the model's output format by pre-specified constraints. In `FastDeploy`, constraints can be specified through the `GuidedDecodingParams` class in `SamplingParams`. `GuidedDecodingParams` supports the following constraint types, with usage similar to online inference:
+
+```python
+json: Optional[Union[str, dict]] = None
+regex: Optional[str] = None
+choice: Optional[List[str]] = None
+grammar: Optional[str] = None
+json_object: Optional[bool] = None
+structural_tag: Optional[str] = None
+```
+
+The following example demonstrates how to use offline inference to generate a structured json:
+
+```python
+from fastdeploy import LLM, SamplingParams
+from fastdeploy.engine.sampling_params import GuidedDecodingParams
+from pydantic import BaseModel
+from enum import Enum
+
+class BookType(str, Enum):
+ romance = "Romance"
+ historical = "Historical"
+ adventure = "Adventure"
+ mystery = "Mystery"
+ dystopian = "Dystopian"
+
+class BookDescription(BaseModel):
+ author: str
+ title: str
+ genre: BookType
+
+# Constrained decoding parameters
+guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema())
+
+# Sampling parameters
+sampling_params = SamplingParams(
+ top_p=0.95,
+ max_tokens=6400,
+ guided_decoding=guided_decoding_params,
+)
+
+# Load model
+llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")
+
+outputs = llm.generate(
+ prompts="Generate a JSON describing a literary work, including author, title and book type.",
+ sampling_params=sampling_params,
+)
+
+# Output results
+for output in outputs:
+ print(output.outputs.text)
+```
+
+Output:
+
+```
+{"author": "George Orwell", "title": "1984", "genre": "Dystopian"}
+```
diff --git a/docs/zh/features/structured_outputs.md b/docs/zh/features/structured_outputs.md
index ce33f1232..cafda804c 100644
--- a/docs/zh/features/structured_outputs.md
+++ b/docs/zh/features/structured_outputs.md
@@ -330,3 +330,67 @@ ParsedChatCompletionMessage[Info](content='{"addr": "上海市浦东新区世纪
地址: 上海市浦东新区世纪大道1号
高度: 468
```
+
+### 离线推理
+
+离线推理允许通过预先指定约束条件,限制模型输出格式。在 `FastDeploy` 中,支持通过 `SamplingParams` 中的 `GuidedDecodingParams` 类指定相关约束条件。`GuidedDecodingParams` 支持以下几种约束条件,使用方式可以参考在线推理:
+
+```python
+json: Optional[Union[str, dict]] = None
+regex: Optional[str] = None
+choice: Optional[List[str]] = None
+grammar: Optional[str] = None
+json_object: Optional[bool] = None
+structural_tag: Optional[str] = None
+```
+
+以下示例展示了如何使用离线推理生成一个结构化的 json :
+
+```python
+
+from fastdeploy import LLM, SamplingParams
+from fastdeploy.engine.sampling_params import GuidedDecodingParams
+from pydantic import BaseModel
+from enum import Enum
+
+class BookType(str, Enum):
+ romance = "Romance"
+ historical = "Historical"
+ adventure = "Adventure"
+ mystery = "Mystery"
+ dystopian = "Dystopian"
+
+class BookDescription(BaseModel):
+ author: str
+ title: str
+ genre: BookType
+
+# Constrained decoding parameters
+guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema())
+
+# Sampling parameters
+sampling_params = SamplingParams(
+ top_p=0.95,
+ max_tokens=6400,
+ guided_decoding=guided_decoding_params,
+)
+
+# Load model
+llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")
+
+outputs = llm.generate(
+ prompts="生成一个JSON,描述一本中国的著作,要包含作者、标题和书籍类型。",
+ sampling_params=sampling_params,
+)
+
+# Output results
+for output in outputs:
+ print(output.outputs.text)
+
+```
+
+输出
+
+```
+{"author": "曹雪芹", "title": "红楼梦", "genre": "Historical"}
+```
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index c247b7880..c52b2530b 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -127,12 +127,13 @@ class ModelConfig:
self.redundant_experts_num = 0
self.seed = 0
self.quantization = None
+ self.reasoning_parser = None
self.pad_token_id: int = -1
self.eos_tokens_lens: int = 2
self.lm_head_fp32: bool = False
self.model_format = "auto"
for key, value in args.items():
- if hasattr(self, key):
+ if hasattr(self, key) and value != "None":
setattr(self, key, value)
assert self.model != ""
@@ -1249,7 +1250,8 @@ class FDConfig:
self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)
if self.guided_decoding_backend == "auto":
- if self.model_config.enable_mm:
+ if current_platform.is_xpu() or self.speculative_config.method is not None:
+ logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.")
self.guided_decoding_backend = "off"
else:
self.guided_decoding_backend = "xgrammar"
@@ -1319,12 +1321,10 @@ class FDConfig:
], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
if self.guided_decoding_backend != "off":
- # TODO: mm support guided_decoding
- assert (
- self.model_config.enable_mm is False
- ), "Multimodal model currently do not support guided_decoding"
-
# TODO: speculative decoding support guided_decoding
+ assert (
+ self.speculative_config.method is None
+ ), "speculative decoding currently do not support guided_decoding"
# TODO: xpu support guided_decoding
assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding"
@@ -1335,6 +1335,7 @@ class FDConfig:
raise Exception(
f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
)
+
if self.scheduler_config is not None:
self.scheduler_config.check()
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index c508f4ee5..a71a2df61 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -178,6 +178,22 @@ class LLMEngine:
# _insert_task_to_worker moved to CommonEngine
+ def _has_guided_input(self, request):
+ """
+ Check if the request has any guided input.
+ """
+ return any(
+ x is not None
+ for x in (
+ request.guided_json,
+ request.guided_regex,
+ request.guided_choice,
+ request.structural_tag,
+ request.guided_grammar,
+ request.guided_json_object,
+ )
+ )
+
def add_requests(self, task, sampling_params=None, **kwargs):
"""
Add a new request to the queue.
@@ -249,8 +265,15 @@ class LLMEngine:
llm_logger.error(error_msg)
raise EngineError(error_msg, error_code=400)
- if self.engine.guided_decoding_checker is not None:
- request, err_msg = self.engine.guided_decoding_checker.schema_format(request)
+ if self._has_guided_input(request):
+ err_msg = None
+ if self.guided_decoding_checker is None:
+ err_msg = (
+ "guided_backend is None, use --guided-decoding-backend to specify the backend at server startup."
+ )
+ else:
+ request, err_msg = self.guided_decoding_checker.schema_format(request)
+
if err_msg is not None:
llm_logger.error(err_msg)
raise EngineError(err_msg, error_code=400)
@@ -469,6 +492,7 @@ class LLMEngine:
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
f" --load_strategy {self.cfg.load_config.load_strategy}"
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
+ f" --reasoning_parser {self.cfg.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}"
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
f" --ips {ips}"
diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 04a2276af..c1431c42f 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -263,13 +263,11 @@ class Request:
setattr(self, key, value)
def __repr__(self) -> str:
- return (
- f"Request(request_id={self.request_id}, "
- f"prompt={self.prompt!r}, "
- f"prompt_token_ids={self.prompt_token_ids}, "
- f"draft_token_ids={self.draft_token_ids}, "
- f"sampling_params={self.sampling_params})"
- )
+ non_none_fields = []
+ for attr, value in vars(self).items():
+ if value is not None and not attr.startswith("_"):
+ non_none_fields.append(f"{attr}={value!r}")
+ return f"Request({', '.join(non_none_fields)})"
@dataclass(slots=True)
diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py
index 423434857..73231b3be 100644
--- a/fastdeploy/engine/sampling_params.py
+++ b/fastdeploy/engine/sampling_params.py
@@ -100,6 +100,7 @@ class SamplingParams:
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
bad_words: Optional[List[str]] = None
+ guided_decoding: Optional[GuidedDecodingParams] = None
bad_words_token_ids: Optional[List[int]] = None
@classmethod
@@ -132,6 +133,7 @@ class SamplingParams:
min_tokens=1,
logprobs=None,
bad_words=None,
+ guided_decoding=None,
bad_words_token_ids=None,
) -> SamplingParams:
"""Create instance from command line arguments"""
@@ -153,6 +155,7 @@ class SamplingParams:
min_tokens=min_tokens,
logprobs=logprobs,
bad_words=bad_words,
+ guided_decoding=guided_decoding,
bad_words_token_ids=bad_words_token_ids,
)
@@ -217,3 +220,51 @@ class BeamSearchParams:
temperature: float = 0.0
length_penalty: float = 1.0
include_stop_str_in_output: bool = False
+
+
+@dataclass
+class GuidedDecodingParams:
+ """Guided decoding parameters for text generation."""
+
+ json: Optional[Union[str, dict]] = None
+ regex: Optional[str] = None
+ choice: Optional[List[str]] = None
+ grammar: Optional[str] = None
+ json_object: Optional[bool] = None
+ structural_tag: Optional[str] = None
+
+ def to_dict(self):
+ """convert to dict"""
+ key_dict = {
+ "guided_json": self.json,
+ "guided_regex": self.regex,
+ "guided_choice": self.choice,
+ "guided_grammar": self.grammar,
+ "structural_tag": self.structural_tag,
+ "guided_json_object": self.json_object,
+ }
+
+ guided_dict = {}
+ for key, value in key_dict.items():
+ if value is not None:
+ guided_dict[key] = value
+ return guided_dict
+
+ def __post_init__(self):
+ """Verify the arguments."""
+ guided_count = sum(
+ [
+ self.json is not None,
+ self.regex is not None,
+ self.choice is not None,
+ self.grammar is not None,
+ self.json_object is not None,
+ self.structural_tag is not None,
+ ]
+ )
+
+ if guided_count > 1:
+ raise ValueError(
+ "You can only use one kind of guided decoding "
+ "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
+ )
diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py
index f9537e557..d69068b6f 100644
--- a/fastdeploy/entrypoints/llm.py
+++ b/fastdeploy/entrypoints/llm.py
@@ -295,6 +295,9 @@ class LLM:
current_sampling_params = sampling_params[i]
else:
current_sampling_params = sampling_params
+ if current_sampling_params.guided_decoding is not None:
+ guided_decoding_dict = current_sampling_params.guided_decoding.to_dict()
+ tasks.update(guided_decoding_dict)
self.llm_engine.add_requests(tasks, current_sampling_params, **kwargs)
return req_ids
diff --git a/fastdeploy/model_executor/guided_decoding/__init__.py b/fastdeploy/model_executor/guided_decoding/__init__.py
index d6ee61199..9336f4a04 100644
--- a/fastdeploy/model_executor/guided_decoding/__init__.py
+++ b/fastdeploy/model_executor/guided_decoding/__init__.py
@@ -15,8 +15,13 @@
"""
# from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+ BackendBase,
+ BaseChecker,
+ LogitsProcessorBase,
+)
-__all__ = ["get_guided_backend", "schema_checker"]
+__all__ = ["get_guided_backend", "schema_checker", "LogitsProcessorBase", "BackendBase", "BaseChecker"]
def get_guided_backend(
diff --git a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py
index ea18fbe8b..b9a879e32 100644
--- a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py
+++ b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py
@@ -20,6 +20,7 @@ from concurrent.futures import ThreadPoolExecutor
from fastdeploy.config import ErnieArchitectures, FDConfig
from fastdeploy.engine.request import Request
+from fastdeploy.reasoning import ReasoningParserManager
from fastdeploy.utils import llm_logger
@@ -35,8 +36,9 @@ class LogitsProcessorBase:
None (all state should be managed by subclasses)
"""
- def __init__(self):
- pass
+ def __init__(self, enable_reasoning):
+ self.reasoning_ended = False
+ self.enable_reasoning = enable_reasoning
def fill_token_bitmask(self, token_bitmask, idx):
"""
@@ -137,8 +139,14 @@ class BackendBase:
self.fd_config = fd_config
self.executor = ThreadPoolExecutor()
self.max_cache_size = 2048
+ self.reasoning_parser = None
self.hf_tokenizer = self._get_tokenizer_hf()
+ if self.fd_config.model_config.reasoning_parser:
+ reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(
+ self.fd_config.model_config.reasoning_parser
+ )
+ self.reasoning_parser = reasoning_parser_obj(self.hf_tokenizer)
def _create_processor(self):
"""
@@ -149,70 +157,88 @@ class BackendBase:
"""
raise NotImplementedError
- def _json_processor(self, schemata):
+ def _json_processor(self, schemata, enable_thinking=False):
"""
Process JSON schemata.
Args:
schemata (str): The schemata string.
+ enable_thinking (bool): Whether to enable thinking mode.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError
- def _regex_processor(self, schemata):
+ def _regex_processor(self, schemata, enable_thinking=False):
"""
Process regular expression schemata.
Args:
schemata (str): The schemata string.
+ enable_thinking (bool): Whether to enable thinking mode.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError
- def _grammar_processor(self, schemata):
+ def _grammar_processor(self, schemata, enable_thinking=False):
"""
Process grammar schemata.
Args:
schemata (str): The schemata string.
+ enable_thinking (bool): Whether to enable thinking mode.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError
- def _structural_tag_processor(self, schemata):
+ def _structural_tag_processor(self, schemata, enable_thinking=False):
"""
Process structural tag schemata.
Args:
schemata (str): The schemata string.
+ enable_thinking (bool): Whether to enable thinking mode.
Raises:
NotImplementedError: This method should be implemented in subclasses.
"""
raise NotImplementedError
- def _unsupported_processor_type(self, key_type, schemata):
+ def _unsupported_processor_type(self, key_type, schemata, enable_thinking=False):
"""
Process unsupported type.
Args:
key_type (str): The key type string.
schemata (str): The schemata string.
+ enable_thinking (bool): Whether to enable thinking mode.
"""
raise Exception(f"Unsupported processor type {key_type}.")
- def _init_logits_processor(self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
+ def get_reasoning_parser(self):
+ """
+ Get reasoning parser object.
+ Returns:
+ ReasoningParser: Reasoning parser object or None
+ """
+ return self.reasoning_parser
+
+ def _init_logits_processor(
+ self,
+ schemata_key: tuple[str, str],
+ enable_thinking: bool = False,
+ ) -> LogitsProcessorBase:
"""
init logits processor by type and schemata.
Args:
schemata_key (tuple[str, str]): Tuple containing processor type and schema string
+ enable_thinking (bool): Whether to enable thinking step
Returns:
LogitsProcessorBase: Initialized logits processor instance
@@ -222,18 +248,22 @@ class BackendBase:
"""
key_type, schemata = schemata_key
if key_type == "json":
- return self._json_processor(schemata)
+ return self._json_processor(schemata, enable_thinking)
elif key_type == "regex":
- return self._regex_processor(schemata)
+ return self._regex_processor(schemata, enable_thinking)
elif key_type == "grammar":
- return self._grammar_processor(schemata)
+ return self._grammar_processor(schemata, enable_thinking)
elif key_type == "structural_tag":
- return self._structural_tag_processor(schemata)
+ return self._structural_tag_processor(schemata, enable_thinking)
else:
llm_logger.error(f"Unsupported processor type {key_type}.")
return None
- def get_logits_processor(self, schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
+ def get_logits_processor(
+ self,
+ schemata_key: tuple[str, str],
+ enable_thinking: bool = False,
+ ) -> tuple[LogitsProcessorBase, bool]:
"""
get logits processor by key from cache or create new one.
@@ -247,8 +277,10 @@ class BackendBase:
"""
value = self.cache.get(schemata_key, None)
if value:
- return value.copy(), True
- value = self.executor.submit(self._init_logits_processor, schemata_key)
+ value_copy = value.copy()
+ value_copy.enable_reasoning = enable_thinking
+ return value_copy, True
+ value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
return value, False
def _get_tokenizer_hf(self):
@@ -267,7 +299,6 @@ class BackendBase:
try:
architectures = self.fd_config.model_config.architectures
if not ErnieArchitectures.contains_ernie_arch(architectures):
-
from transformers import AutoTokenizer, PreTrainedTokenizerFast
tokenizer = AutoTokenizer.from_pretrained(
diff --git a/fastdeploy/model_executor/guided_decoding/kernels/xgrammar_apply_token_bitmask.py b/fastdeploy/model_executor/guided_decoding/kernels/xgrammar_apply_token_bitmask.py
new file mode 100644
index 000000000..f0ba737a6
--- /dev/null
+++ b/fastdeploy/model_executor/guided_decoding/kernels/xgrammar_apply_token_bitmask.py
@@ -0,0 +1,118 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+# refer to https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
+
+from typing import List, Optional
+
+import paddle
+
+try:
+ import triton
+ import triton.language as tl
+except ImportError as err:
+ raise ImportError("Triton is not installed") from err
+
+
+@triton.jit
+def apply_token_bitmask_inplace_kernel(
+ logits_ptr,
+ bitmask_ptr,
+ indices_ptr,
+ num_rows,
+ vocab_size,
+ logits_strides,
+ bitmask_strides,
+ NUM_SMS: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ """Triton kernel for in-place logits masking using bitwise compression.
+
+ Processes logits tensor in blocks, applying bitmask to restrict vocabulary access.
+ Masked positions are set to -inf to ensure zero probability during sampling.
+
+ Note:
+ - Bitmask uses 32:1 compression (1 bit per vocabulary token)
+ - Optimized for GPU parallel processing with configurable block size
+ """
+ pid = tl.program_id(0)
+ num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
+ for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
+ row_id = work_id // num_blocks
+ block_offset = (work_id % num_blocks) * BLOCK_SIZE
+ batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
+ offsets = block_offset + tl.arange(0, BLOCK_SIZE)
+ bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
+ vocab_mask = offsets < vocab_size
+ packed_bitmask_mask = bitmask_offsets < bitmask_strides
+ packed_bitmask = tl.load(bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, packed_bitmask_mask)
+ bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
+ bitmask = bitmask.reshape(BLOCK_SIZE)
+
+ tl.store(logits_ptr + batch_id * logits_strides + offsets, -float("inf"), vocab_mask & bitmask)
+
+
+def apply_token_bitmask_inplace_triton(
+ logits: paddle.Tensor,
+ bitmask: paddle.Tensor,
+ vocab_size: Optional[int] = None,
+ indices: Optional[List[int]] = None,
+):
+ """Applies vocabulary mask to logits tensor using Triton GPU kernel.
+
+ Args:
+ logits: Input logits tensor of shape [batch_size, vocab_size]
+ bitmask: Compressed mask tensor (int32) where each bit represents a token
+ vocab_size: Optional explicit vocabulary size (defaults to auto-detected)
+ indices: Optional list of batch indices to apply mask to
+
+ Note:
+ Requires CUDA GPU with Triton support
+ Bitmask must be int32 tensor with shape [batch_size, ceil(vocab_size/32)]
+ """
+ NUM_SMS = paddle.device.cuda.get_device_properties().multi_processor_count
+ BLOCK_SIZE = 4096
+
+ assert bitmask.dtype == paddle.int32, "bitmask must be of type int32"
+
+ detected_vocab_size = min(logits.shape[-1], bitmask.shape[-1] * 32)
+ if vocab_size is None:
+ vocab_size = detected_vocab_size
+ else:
+ assert (
+ vocab_size <= detected_vocab_size
+ ), f"vocab_size {vocab_size} is larger than the detected vocab_size {detected_vocab_size}"
+
+ num_rows = len(indices) if indices is not None else logits.shape[0] if logits.ndim == 2 else 1
+
+ if indices is not None:
+ indices = paddle.to_tensor(indices, dtype=paddle.int32, place=logits.place)
+
+ grid = (NUM_SMS,)
+
+ apply_token_bitmask_inplace_kernel[grid](
+ logits,
+ bitmask,
+ indices,
+ num_rows,
+ vocab_size,
+ logits.shape[-1],
+ bitmask.shape[-1],
+ NUM_SMS,
+ BLOCK_SIZE,
+ num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
+ num_stages=3,
+ )
diff --git a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
index 0d448d429..d32d57f3c 100644
--- a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
+++ b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py
@@ -24,7 +24,7 @@ import torch
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request
-from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+from fastdeploy.model_executor.guided_decoding import (
BackendBase,
BaseChecker,
LogitsProcessorBase,
@@ -57,7 +57,6 @@ class XGrammarProcessor(LogitsProcessorBase):
max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch
vocab_size (int): Size of the vocabulary
batch_size (int): Batch size for processing
- splitwise_role (str): Role for splitwise processing
compiled_grammar (CompiledGrammar): Compiled grammar rules
terminate_without_stop_token (bool): Whether to terminate without stop token
override_stop_tokens (Optional[List[int]]): Custom stop tokens
@@ -71,13 +70,12 @@ class XGrammarProcessor(LogitsProcessorBase):
override_stop_tokens: Optional[List[int]] = None,
vocab_size: Optional[int] = None,
batch_size: Optional[int] = None,
- splitwise_role: str = "mixed",
+ enable_thinking: bool = False,
):
- super().__init__()
+ super().__init__(enable_reasoning=enable_thinking)
self.max_rollback_tokens = 200
self.vocab_size = vocab_size
self.batch_size = batch_size
- self.splitwise_role = splitwise_role
self.compiled_grammar = compiled_grammar
self.terminate_without_stop_token = terminate_without_stop_token
self.override_stop_tokens = override_stop_tokens
@@ -188,7 +186,6 @@ class XGrammarProcessor(LogitsProcessorBase):
override_stop_tokens=self.override_stop_tokens,
vocab_size=self.vocab_size,
batch_size=self.batch_size,
- splitwise_role=self.splitwise_role,
)
@@ -203,7 +200,6 @@ class XGrammarBackend(BackendBase):
vocab_size (int): Size of the vocabulary from config
batch_size (int): Maximum batch size from config
any_whitespace (bool): Whether to allow any whitespace in JSON
- splitwise_role (str): Role for splitwise processing
grammar_compiler (GrammarCompiler): Grammar compilation engine
"""
@@ -217,7 +213,6 @@ class XGrammarBackend(BackendBase):
self.batch_size = fd_config.parallel_config.max_num_seqs
self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
- self.splitwise_role = fd_config.parallel_config.splitwise_role
try:
tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
@@ -230,6 +225,7 @@ class XGrammarBackend(BackendBase):
compiled_grammar: CompiledGrammar,
terminate_without_stop_token: bool = False,
override_stop_tokens: Optional[List[int]] = None,
+ enable_thinking: bool = False,
) -> XGrammarProcessor:
"""
Create a logits processor instance for the given compiled grammar.
@@ -238,6 +234,7 @@ class XGrammarBackend(BackendBase):
compiled_grammar (CompiledGrammar): Compiled grammar rules
terminate_without_stop_token (bool): Whether to terminate without stop token
override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults
+ enable_thinking (bool): Whether to enable thinking mode
Returns:
XGrammarProcessor: Configured grammar processor instance
@@ -248,15 +245,16 @@ class XGrammarBackend(BackendBase):
override_stop_tokens=override_stop_tokens,
vocab_size=self.vocab_size,
batch_size=self.batch_size,
- splitwise_role=self.splitwise_role,
+ enable_thinking=enable_thinking,
)
- def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+ def _json_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
"""
Compile JSON schema into a grammar processor.
Args:
schemata (str): JSON schema string to compile
+ enable_thinking (bool): Whether to enable thinking mode
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -266,14 +264,15 @@ class XGrammarBackend(BackendBase):
except Exception as e:
llm_logger.error(f"Failed to compile json schema: {e}, {str(traceback.format_exc())}")
return None
- return self._create_processor(compiled_grammar)
+ return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
- def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+ def _regex_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
"""
Compile regex pattern into a grammar processor.
Args:
schemata (str): Regex pattern string to compile
+ enable_thinking (bool): Whether to enable thinking mode
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -283,14 +282,15 @@ class XGrammarBackend(BackendBase):
except Exception as e:
llm_logger.error(f"Failed to compile regex schema: {e}, {str(traceback.format_exc())}")
return None
- return self._create_processor(compiled_grammar)
+ return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
- def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+ def _grammar_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
"""
Compile grammar (EBNF) into a grammar processor.
Args:
schemata (str): Grammar string in EBNF format
+ enable_thinking (bool): Whether to enable thinking mode
Returns:
Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -300,9 +300,9 @@ class XGrammarBackend(BackendBase):
except Exception as e:
llm_logger.error(f"Failed to compile ebnf schema: {e}, {str(traceback.format_exc())}")
return None
- return self._create_processor(compiled_grammar)
+ return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
- def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+ def _structural_tag_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
"""
Compile structural tags into a grammar processor.
@@ -327,7 +327,7 @@ class XGrammarBackend(BackendBase):
except Exception as e:
llm_logger.error(f"Failed to compile structural tags schema: {e}, {str(traceback.format_exc())}")
return None
- return self._create_processor(compiled_grammar)
+ return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
class XGrammarChecker(BaseChecker):
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 5aecfa1f9..f8fd1755a 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -23,9 +23,7 @@ import paddle.nn.functional as F
from paddle import nn
from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
- LogitsProcessorBase,
-)
+from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
from fastdeploy.model_executor.layers.sample.early_stopper import (
get_early_stopper_cls_from_stragegy,
)
@@ -37,6 +35,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
top_k_top_p_sampling,
)
from fastdeploy.platforms import current_platform
+from fastdeploy.reasoning import ReasoningParser
from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
@@ -63,6 +62,10 @@ class SamplerProcessor:
self.logits_processor: Dict[int, Optional[Any]] = dict()
self.executor = ThreadPoolExecutor()
self.logits_lock = threading.Lock()
+ self.reasoning_parser = None
+
+ def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+ self.reasoning_parser = reasoning_parser
def add_logits_processor(
self,
@@ -139,9 +142,14 @@ class SamplerProcessor:
if available_processors is None:
return logits
- indices = list(self.logits_processor.keys())
- mask_idx = [i for i in indices if i not in skip_idx_list]
- return available_processors.apply_token_mask(logits, self.token_bitmask, indices=mask_idx)
+ indices = []
+ for idx, processor in self.logits_processor.items():
+ if processor is None or idx in skip_idx_list:
+ continue
+ if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended:
+ indices.append(idx)
+
+ return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices)
def _accept_token(self, idx: int, token: int):
"""accept token"""
@@ -151,6 +159,15 @@ class SamplerProcessor:
if self.logits_processor[idx].is_terminated():
return
+ if (
+ self.reasoning_parser is not None
+ and self.logits_processor[idx].enable_reasoning
+ and not self.logits_processor[idx].reasoning_ended
+ ):
+ reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
+ self.logits_processor[idx].reasoning_ended = reasoning_ended
+ return
+
self.logits_processor[idx].accept_token(token)
def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
@@ -206,12 +223,11 @@ class Sampler(nn.Layer):
self.early_stopper = early_stopper_cls()
self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
- def apply_logits_processor(
- self,
- ids: int,
- future: Optional[Any] = None,
- prefill_tokens: List[int] = [],
- ):
+ def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+ """set reasoning parser"""
+ self.processor.apply_reasoning_parser(reasoning_parser)
+
+ def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
"""apply logits processor to sampler"""
self.processor.add_logits_processor(ids, future, prefill_tokens)
@@ -219,6 +235,10 @@ class Sampler(nn.Layer):
"""pre process before running"""
self.processor.pre_process(skip_idx_list)
+ def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+ """post process after running"""
+ self.processor.update_output_tokens(next_tokens, skip_idx_list)
+
def compute_logprobs(
self,
logits: paddle.Tensor,
@@ -307,12 +327,12 @@ class Sampler(nn.Layer):
skip_idx_list: List[int] = [],
) -> SamplerOutput:
""" """
+ logits = self.processor.apply_token_mask(logits, skip_idx_list)
+
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
- logits = self.processor.apply_token_mask(logits, skip_idx_list)
-
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,
@@ -347,8 +367,6 @@ class Sampler(nn.Layer):
assert sampling_metadata.stop_flags is not None, "need stop_flags for eary stop"
self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)
- self.processor.update_output_tokens(next_tokens, skip_idx_list)
-
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
@@ -380,12 +398,15 @@ class SpeculativeSampler(nn.Layer):
"""pre process before running"""
pass
- def apply_logits_processor(
- self,
- ids: int,
- future: Optional[Any] = None,
- prefill_tokens: List[int] = [],
- ):
+ def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+ """set reasoning parser"""
+ pass
+
+ def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+ """post process after running"""
+ pass
+
+ def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
"""apply logits processor to sampler"""
pass
@@ -480,6 +501,14 @@ class MTPSampler(nn.Layer):
"""apply logits processor to sampler"""
pass
+ def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+ """set reasoning parser"""
+ pass
+
+ def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+ """post process after running"""
+ pass
+
def forward_cuda(
self,
logits: paddle.Tensor,
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 6b042fbbe..42848c380 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -29,9 +29,9 @@ from fastdeploy.model_executor.graph_optimization.utils import (
profile_run_guard,
sot_warmup_guard,
)
-from fastdeploy.model_executor.guided_decoding import get_guided_backend
-from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+from fastdeploy.model_executor.guided_decoding import (
LogitsProcessorBase,
+ get_guided_backend,
)
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
@@ -97,10 +97,6 @@ class GPUModelRunner(ModelRunnerBase):
self.enable_logprob = fd_config.model_config.enable_logprob
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
- self.guided_backend = None
- if self.fd_config.parallel_config.guided_decoding_backend != "off":
- self.guided_backend = get_guided_backend(fd_config=self.fd_config)
-
# VL model config:
if self.enable_mm:
if "ernie" in self.fd_config.model_config.model_type:
@@ -129,6 +125,11 @@ class GPUModelRunner(ModelRunnerBase):
else:
self.sampler = SpeculativeSampler(fd_config)
+ self.guided_backend = None
+ if self.fd_config.parallel_config.guided_decoding_backend != "off":
+ self.guided_backend = get_guided_backend(fd_config=self.fd_config)
+ self.sampler.set_reasoning_parser(self.guided_backend.get_reasoning_parser())
+
# Lazy initialize kv cache after model loading
# self.kv_caches: list[paddle.Tensor] = []
@@ -206,7 +207,16 @@ class GPUModelRunner(ModelRunnerBase):
elif request.structural_tag is not None:
schemata_key = ("structural_tag", request.structural_tag)
- return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key
+ enable_thinking = request.get("enable_thinking", True)
+ enable_thinking = enable_thinking if enable_thinking is not None else True
+
+ return (
+ self.guided_backend.get_logits_processor(
+ schemata_key=schemata_key,
+ enable_thinking=enable_thinking,
+ ),
+ schemata_key,
+ )
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
"""
@@ -1336,10 +1346,10 @@ class GPUModelRunner(ModelRunnerBase):
Returns:
A list of indices corresponding to the requests that need to be skipped.
"""
- skip_idx_list = []
- if not self.cache_config.enable_chunked_prefill or self.guided_backend is None:
- return skip_idx_list
+ if not self.cache_config.enable_chunked_prefill or self.guided_backend is None or model_forward_batch is None:
+ return []
+ skip_idx_list = []
for task in model_forward_batch:
if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
continue
@@ -1505,6 +1515,8 @@ class GPUModelRunner(ModelRunnerBase):
speculative_decoding=self.speculative_decoding,
skip_save_output=skip_save_output,
)
+ if self.guided_backend is not None and sampler_output is not None:
+ self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
# 6. Speculative decode
if self.speculative_decoding:
@@ -1538,7 +1550,7 @@ class GPUModelRunner(ModelRunnerBase):
"""
Add cache for guided decoding.
"""
- if self.guided_backend is None:
+ if self.guided_backend is None or model_forward_batch is None:
return
for request in model_forward_batch:
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index a57195391..28b883662 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -590,6 +590,12 @@ def parse_args():
action="store_true",
help="Enable output of token-level log probabilities.",
)
+ parser.add_argument(
+ "--reasoning_parser",
+ type=str,
+ default=None,
+ help="Flag specifies the reasoning parser to use for extracting reasoning content from the model output",
+ )
parser.add_argument(
"--early_stop_config",
type=json.loads,
diff --git a/scripts/run_pre_ce.sh b/scripts/run_pre_ce.sh
index 67b06736e..ab36dac96 100644
--- a/scripts/run_pre_ce.sh
+++ b/scripts/run_pre_ce.sh
@@ -7,6 +7,7 @@ python -m pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/p
python -m pip install -r requirements.txt
python -m pip install jsonschema aistudio_sdk==0.3.5
+python -m pip install xgrammar==0.1.19 torch==2.6.0
failed_files=()
run_path="$DIR/../tests/ci_use/"
diff --git a/tests/ci_use/EB_Lite/test_EB_Lite_serving.py b/tests/ci_use/EB_Lite/test_EB_Lite_serving.py
index a5688c866..8cba35eb1 100644
--- a/tests/ci_use/EB_Lite/test_EB_Lite_serving.py
+++ b/tests/ci_use/EB_Lite/test_EB_Lite_serving.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import os
import re
import shutil
@@ -110,6 +111,8 @@ def setup_and_run_server():
"--use-cudagraph",
"--graph-optimization-config",
'{"cudagraph_capture_sizes": [1]}',
+ "--guided-decoding-backend",
+ "auto",
]
# Start subprocess in new process group
@@ -1142,3 +1145,336 @@ def test_profile_reset_block_num():
f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内"
f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
)
+
+
+def streaming_chat_base(openai_client, chat_param):
+ """
+ Test streaming chat base functionality with the local service
+ """
+ assert isinstance(chat_param, dict), f"{chat_param} should be a dict"
+ assert "messages" in chat_param, f"{chat_param} should contain messages"
+
+ response = openai_client.chat.completions.create(
+ model="default",
+ stream=True,
+ **chat_param,
+ )
+
+ output = []
+ for chunk in response:
+ if hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"):
+ output.append(chunk.choices[0].delta.content)
+ assert len(output) > 2
+ return "".join(output)
+
+
+def non_streaming_chat_base(openai_client, chat_param):
+ """
+ Test non streaming chat base functionality with the local service
+ """
+ assert isinstance(chat_param, dict), f"{chat_param} should be a dict"
+ assert "messages" in chat_param, f"{chat_param} should contain messages"
+
+ response = openai_client.chat.completions.create(
+ model="default",
+ stream=False,
+ **chat_param,
+ )
+
+ assert hasattr(response, "choices")
+ assert len(response.choices) > 0
+ assert hasattr(response.choices[0], "message")
+ assert hasattr(response.choices[0].message, "content")
+ return response.choices[0].message.content
+
+
+def test_structured_outputs_json_schema(openai_client):
+ """
+ Test structured outputs json_schema functionality with the local service
+ """
+ chat_param = {
+ "temperature": 1,
+ "max_tokens": 1024,
+ }
+
+ # json_object
+ json_chat_param = {
+ "messages": [
+ {
+ "role": "user",
+ "content": "Generate a JSON object containing: names of China's Four Great Inventions, their dynasties of origin, and brief descriptions (each under 50 characters)",
+ }
+ ],
+ "response_format": {"type": "json_object"},
+ }
+ json_chat_param.update(chat_param)
+
+ response = streaming_chat_base(openai_client, json_chat_param)
+ try:
+ json.loads(response)
+ is_valid = True
+ except ValueError:
+ is_valid = False
+
+ assert is_valid, f"json_schema streaming response: {response} is not a valid json"
+
+ response = non_streaming_chat_base(openai_client, json_chat_param)
+ try:
+ json.loads(response)
+ is_valid = True
+ except ValueError:
+ is_valid = False
+
+ assert is_valid, f"json_schema non_streaming response: {response} is not a valid json"
+
+ # json_schema
+ from enum import Enum
+
+ from pydantic import BaseModel
+
+ class BookType(str, Enum):
+ romance = "Romance"
+ historical = "Historical"
+ adventure = "Adventure"
+ mystery = "Mystery"
+ dystopian = "Dystopian"
+
+ class BookDescription(BaseModel):
+ author: str
+ title: str
+ genre: BookType
+
+ json_schema_param = {
+ "messages": [
+ {
+ "role": "user",
+ "content": "Generate a JSON describing a literary work, including author, title and book type.",
+ }
+ ],
+ "response_format": {
+ "type": "json_schema",
+ "json_schema": {"name": "book-description", "schema": BookDescription.model_json_schema()},
+ },
+ }
+ json_schema_param.update(chat_param)
+ response = streaming_chat_base(openai_client, json_schema_param)
+ try:
+ json_schema_response = json.loads(response)
+ is_valid = True
+ except ValueError:
+ is_valid = False
+
+ assert is_valid, f"json_schema streaming response: {response} is not a valid json"
+ assert (
+ "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response
+ ), f"json_schema streaming response: {response} is not a valid book-description"
+ assert json_schema_response["genre"] in {
+ genre.value for genre in BookType
+ }, f"json_schema streaming response: {json_schema_response['genre']} is not a valid book-type"
+
+ response = non_streaming_chat_base(openai_client, json_schema_param)
+ try:
+ json_schema_response = json.loads(response)
+ is_valid = True
+ except ValueError:
+ is_valid = False
+
+ assert is_valid, f"json_schema non_streaming response: {response} is not a valid json"
+ assert (
+ "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response
+ ), f"json_schema non_streaming response: {response} is not a valid book-description"
+ assert json_schema_response["genre"] in {
+ genre.value for genre in BookType
+ }, f"json_schema non_streaming response: {json_schema_response['genre']} is not a valid book-type"
+
+
+def test_structured_outputs_structural_tag(openai_client):
+ """
+ Test structured outputs structural_tag functionality with the local service
+ """
+ content_str = """
+ You have the following function available:
+
+ {
+ "name": "get_current_date",
+ "description": "Get current date and time for given timezone",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "timezone": {
+ "type": "string",
+ "description": "Timezone to get current date/time, e.g.: Asia/Shanghai",
+ }
+ },
+ "required": ["timezone"],
+ }
+ }
+
+ If you choose to call only this function, reply in this format:
+ <{start_tag}={function_name}>{parameters}{end_tag}
+ where:
+
+ start_tag => ` JSON dictionary with parameter names as keys
+ end_tag => ``
+
+ Example:
+ {"param": "value"}
+
+ Note:
+ - Function call must follow specified format
+ - Required parameters must be specified
+ - Only one function can be called at a time
+ - Place entire function call response on a single line
+
+ You are an AI assistant. Answer the following question.
+ """
+
+ structural_tag_param = {
+ "temperature": 1,
+ "max_tokens": 1024,
+ "messages": [
+ {
+ "role": "system",
+ "content": content_str,
+ },
+ {
+ "role": "user",
+ "content": "You're traveling to Shanghai today",
+ },
+ ],
+ "response_format": {
+ "type": "structural_tag",
+ "structures": [
+ {
+ "begin": "",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "timezone": {
+ "type": "string",
+ "description": "Timezone to get current date/time, e.g.: Asia/Shanghai",
+ }
+ },
+ "required": ["timezone"],
+ },
+ "end": "",
+ }
+ ],
+ "triggers": ["" text ""
+
+ style_attribute ::= " style=" dq style_value dq
+
+ style_value ::= (font_style ("; " font_weight)?) | (font_weight ("; " font_style)?)
+
+ font_style ::= "font-family: '" font_name "'"
+
+ font_weight ::= "font-weight: " weight_value
+
+ font_name ::= "Arial" | "Times New Roman" | "Courier New"
+
+ weight_value ::= "normal" | "bold"
+
+ text ::= [A-Za-z0-9 ]+
+
+ dq ::= ["]
+ """
+
+ grammar_param = {
+ "temperature": 1,
+ "max_tokens": 1024,
+ "messages": [
+ {
+ "role": "user",
+ "content": "Generate HTML code for this heading in bold Times New Roman font: ERNIE Bot",
+ }
+ ],
+ "extra_body": {"guided_grammar": html_h1_grammar},
+ }
+
+ import re
+
+ pattern = r'^[A-Za-z0-9 ]+
$'
+ response = streaming_chat_base(openai_client, grammar_param)
+ assert re.fullmatch(pattern, response), f"grammar streaming response: {response} is not as expected"
+ response = non_streaming_chat_base(openai_client, grammar_param)
+ assert re.fullmatch(pattern, response), f"grammar non_streaming response: {response} is not as expected"
diff --git a/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py b/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py
index 6eb78345d..fed911861 100644
--- a/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py
+++ b/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py
@@ -119,6 +119,8 @@ def setup_and_run_server():
"wint4",
"--reasoning-parser",
"ernie-45-vl",
+ "--guided-decoding-backend",
+ "auto",
]
# Start subprocess in new process group
@@ -540,6 +542,348 @@ def test_chat_with_thinking(openai_client, capsys):
assert reasoning_tokens <= reasoning_max_tokens
+def streaming_chat_base(openai_client, chat_param):
+ """
+ Test streaming chat base functionality with the local service
+ """
+ assert isinstance(chat_param, dict), f"{chat_param} should be a dict"
+ assert "messages" in chat_param, f"{chat_param} should contain messages"
+
+ response = openai_client.chat.completions.create(
+ model="default",
+ stream=True,
+ **chat_param,
+ )
+
+ output = []
+ for chunk in response:
+ if hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"):
+ output.append(chunk.choices[0].delta.content)
+ assert len(output) > 2
+ return "".join(output)
+
+
+def non_streaming_chat_base(openai_client, chat_param):
+ """
+ Test non streaming chat base functionality with the local service
+ """
+ assert isinstance(chat_param, dict), f"{chat_param} should be a dict"
+ assert "messages" in chat_param, f"{chat_param} should contain messages"
+
+ response = openai_client.chat.completions.create(
+ model="default",
+ stream=False,
+ **chat_param,
+ )
+
+ assert hasattr(response, "choices")
+ assert len(response.choices) > 0
+ assert hasattr(response.choices[0], "message")
+ assert hasattr(response.choices[0].message, "content")
+ return response.choices[0].message.content
+
+
+def test_structured_outputs_json_schema(openai_client):
+ """
+ Test structured outputs json_schema functionality with the local service
+ """
+ chat_param = {
+ "temperature": 1,
+ "max_tokens": 1024,
+ }
+
+ # json_object
+ json_chat_param = {
+ "messages": [
+ {"role": "system", "content": "You are a helpful AI assistant."},
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
+ "detail": "high",
+ },
+ },
+ {"type": "text", "text": "请描述图片内容,使用json格式输出结果"},
+ ],
+ },
+ ],
+ "response_format": {"type": "json_object"},
+ }
+ json_chat_param.update(chat_param)
+
+ outputs = []
+ outputs.append(streaming_chat_base(openai_client, json_chat_param))
+ outputs.append(non_streaming_chat_base(openai_client, json_chat_param))
+
+ json_chat_param["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}
+ outputs.append(streaming_chat_base(openai_client, json_chat_param))
+ outputs.append(non_streaming_chat_base(openai_client, json_chat_param))
+
+ for response in outputs:
+ try:
+ json.loads(response)
+ is_valid = True
+ except ValueError:
+ is_valid = False
+
+ assert is_valid, f"json_object response: {response} is not a valid json"
+
+ # json_schema
+ from enum import Enum
+
+ from pydantic import BaseModel
+
+ class BookType(str, Enum):
+ romance = "Romance"
+ historical = "Historical"
+ adventure = "Adventure"
+ mystery = "Mystery"
+ dystopian = "Dystopian"
+
+ class BookDescription(BaseModel):
+ author: str
+ title: str
+ genre: BookType
+
+ json_schema_param = {
+ "messages": [
+ {
+ "role": "user",
+ "content": "Generate a JSON describing a literary work, including author, title and book type.",
+ }
+ ],
+ "response_format": {
+ "type": "json_schema",
+ "json_schema": {"name": "book-description", "schema": BookDescription.model_json_schema()},
+ },
+ }
+ json_schema_param.update(chat_param)
+ response = streaming_chat_base(openai_client, json_schema_param)
+ try:
+ json_schema_response = json.loads(response)
+ is_valid = True
+ except ValueError:
+ is_valid = False
+
+ assert is_valid, f"json_schema streaming response: {response} is not a valid json"
+ assert (
+ "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response
+ ), f"json_schema streaming response: {response} is not a valid book-description"
+ assert json_schema_response["genre"] in {
+ genre.value for genre in BookType
+ }, f"json_schema streaming response: {json_schema_response['genre']} is not a valid book-type"
+
+ response = non_streaming_chat_base(openai_client, json_schema_param)
+ try:
+ json_schema_response = json.loads(response)
+ is_valid = True
+ except ValueError:
+ is_valid = False
+
+ assert is_valid, f"json_schema non_streaming response: {response} is not a valid json"
+ assert (
+ "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response
+ ), f"json_schema non_streaming response: {response} is not a valid book-description"
+ assert json_schema_response["genre"] in {
+ genre.value for genre in BookType
+ }, f"json_schema non_streaming response: {json_schema_response['genre']} is not a valid book-type"
+
+
+def test_structured_outputs_structural_tag(openai_client):
+ """
+ Test structured outputs structural_tag functionality with the local service
+ """
+ content_str = """
+ You have the following function available:
+
+ {
+ "name": "get_current_date",
+ "description": "Get current date and time for given timezone",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "timezone": {
+ "type": "string",
+ "description": "Timezone to get current date/time, e.g.: Asia/Shanghai",
+ }
+ },
+ "required": ["timezone"],
+ }
+ }
+
+ If you choose to call only this function, reply in this format:
+ <{start_tag}={function_name}>{parameters}{end_tag}
+ where:
+
+ start_tag => ` JSON dictionary with parameter names as keys
+ end_tag => ``
+
+ Example:
+ {"param": "value"}
+
+ Note:
+ - Function call must follow specified format
+ - Required parameters must be specified
+ - Only one function can be called at a time
+ - Place entire function call response on a single line
+
+ You are an AI assistant. Answer the following question.
+ """
+
+ structural_tag_param = {
+ "temperature": 1,
+ "max_tokens": 1024,
+ "messages": [
+ {
+ "role": "system",
+ "content": content_str,
+ },
+ {
+ "role": "user",
+ "content": "You're traveling to Shanghai today",
+ },
+ ],
+ "response_format": {
+ "type": "structural_tag",
+ "structures": [
+ {
+ "begin": "",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "timezone": {
+ "type": "string",
+ "description": "Timezone to get current date/time, e.g.: Asia/Shanghai",
+ }
+ },
+ "required": ["timezone"],
+ },
+ "end": "",
+ }
+ ],
+ "triggers": ["" text ""
+
+ style_attribute ::= " style=" dq style_value dq
+
+ style_value ::= (font_style ("; " font_weight)?) | (font_weight ("; " font_style)?)
+
+ font_style ::= "font-family: '" font_name "'"
+
+ font_weight ::= "font-weight: " weight_value
+
+ font_name ::= "Arial" | "Times New Roman" | "Courier New"
+
+ weight_value ::= "normal" | "bold"
+
+ text ::= [A-Za-z0-9 ]+
+
+ dq ::= ["]
+ """
+
+ grammar_param = {
+ "temperature": 1,
+ "max_tokens": 1024,
+ "messages": [
+ {
+ "role": "user",
+ "content": "Generate HTML code for this heading in bold Times New Roman font: ERNIE Bot",
+ }
+ ],
+ "extra_body": {"guided_grammar": html_h1_grammar},
+ }
+
+ import re
+
+ pattern = r'^[A-Za-z0-9 ]+
$'
+ response = streaming_chat_base(openai_client, grammar_param)
+ assert re.fullmatch(pattern, response), f"grammar streaming response: {response} is not as expected"
+ response = non_streaming_chat_base(openai_client, grammar_param)
+ assert re.fullmatch(pattern, response), f"grammar non_streaming response: {response} is not as expected"
+
+
def test_profile_reset_block_num():
"""测试profile reset_block_num功能,与baseline diff不能超过5%"""
log_file = "./log/config.log"