From 190846554232bfc3bec9cdff728e41437621bde1 Mon Sep 17 00:00:00 2001
From: kevin
Date: Tue, 2 Sep 2025 16:21:09 +0800
Subject: [PATCH] [Feature] mm and thinking model support structured output
 (#2749)

* mm support structured output
* update code
* update code
* update format
* update code
* update code
* add enable_thinking default
* update code
* add structured_outputs test case
* add ci install xgrammar
* add ci timeout time
* update test for structured_outputs
* update code
* add error traceback info
* update error msg
* update structured output code
* update code
* update code
* update config
* update torch version

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
---
 docs/features/structured_outputs.md           |  62 ++++
 docs/zh/features/structured_outputs.md        |  64 ++++
 fastdeploy/config.py                          |  15 +-
 fastdeploy/engine/engine.py                   |  28 +-
 fastdeploy/engine/request.py                  |  12 +-
 fastdeploy/engine/sampling_params.py          |  51 +++
 fastdeploy/entrypoints/llm.py                 |   3 +
 .../guided_decoding/__init__.py               |   7 +-
 .../guided_decoding/base_guided_decoding.py   |  63 +++-
 .../kernels/xgrammar_apply_token_bitmask.py   | 118 ++++++
 .../guided_decoding/xgrammar_backend.py       |  34 +-
 .../model_executor/layers/sample/sampler.py   |  73 ++--
 fastdeploy/worker/gpu_model_runner.py         |  34 +-
 fastdeploy/worker/worker_process.py           |   6 +
 scripts/run_pre_ce.sh                         |   1 +
 tests/ci_use/EB_Lite/test_EB_Lite_serving.py  | 336 +++++++++++++++++
 .../EB_VL_Lite/test_EB_VL_Lite_serving.py     | 344 ++++++++++++++++++
 17 files changed, 1168 insertions(+), 83 deletions(-)
 create mode 100644 fastdeploy/model_executor/guided_decoding/kernels/xgrammar_apply_token_bitmask.py

diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md
index 40e177c1c..f7ee424cb 100644
--- a/docs/features/structured_outputs.md
+++ b/docs/features/structured_outputs.md
@@ -330,3 +330,65 @@ ParsedChatCompletionMessage[Info](content='{"addr": "No.1 Century Avenue, Pudong
 Address: No.1 Century Avenue, Pudong New Area, Shanghai
 Height: 468
 ```
+
+### Offline Inference
+
+Offline inference allows you to restrict the model's output format with pre-specified constraints. In `FastDeploy`, constraints can be specified through the `GuidedDecodingParams` class in `SamplingParams`. 
`GuidedDecodingParams` supports the following constraint types, with usage similar to online inference: + +```python +json: Optional[Union[str, dict]] = None +regex: Optional[str] = None +choice: Optional[List[str]] = None +grammar: Optional[str] = None +json_object: Optional[bool] = None +structural_tag: Optional[str] = None +``` + +The following example demonstrates how to use offline inference to generate a structured json: + +```python +from fastdeploy import LLM, SamplingParams +from fastdeploy.engine.sampling_params import GuidedDecodingParams +from pydantic import BaseModel +from enum import Enum + +class BookType(str, Enum): + romance = "Romance" + historical = "Historical" + adventure = "Adventure" + mystery = "Mystery" + dystopian = "Dystopian" + +class BookDescription(BaseModel): + author: str + title: str + genre: BookType + +# Constrained decoding parameters +guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema()) + +# Sampling parameters +sampling_params = SamplingParams( + top_p=0.95, + max_tokens=6400, + guided_decoding=guided_decoding_params, +) + +# Load model +llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto") + +outputs = llm.generate( + prompts="Generate a JSON describing a literary work, including author, title and book type.", + sampling_params=sampling_params, +) + +# Output results +for output in outputs: + print(output.outputs.text) +``` + +Output: + +``` +{"author": "George Orwell", "title": "1984", "genre": "Dystopian"} +``` diff --git a/docs/zh/features/structured_outputs.md b/docs/zh/features/structured_outputs.md index ce33f1232..cafda804c 100644 --- a/docs/zh/features/structured_outputs.md +++ b/docs/zh/features/structured_outputs.md @@ -330,3 +330,67 @@ ParsedChatCompletionMessage[Info](content='{"addr": "上海市浦东新区世纪 地址: 上海市浦东新区世纪大道1号 高度: 468 ``` + +### 离线推理 + +离线推理允许通过预先指定约束条件,限制模型输出格式。在 `FastDeploy` 中,支持通过 `SamplingParams` 中的 `GuidedDecodingParams` 类指定相关约束条件。`GuidedDecodingParams` 支持以下几种约束条件,使用方式可以参考在线推理: + +```python +json: Optional[Union[str, dict]] = None +regex: Optional[str] = None +choice: Optional[List[str]] = None +grammar: Optional[str] = None +json_object: Optional[bool] = None +structural_tag: Optional[str] = None +``` + +以下示例展示了如何使用离线推理生成一个结构化的 json : + +```python + +from fastdeploy import LLM, SamplingParams +from fastdeploy.engine.sampling_params import GuidedDecodingParams +from pydantic import BaseModel +from enum import Enum + +class BookType(str, Enum): + romance = "Romance" + historical = "Historical" + adventure = "Adventure" + mystery = "Mystery" + dystopian = "Dystopian" + +class BookDescription(BaseModel): + author: str + title: str + genre: BookType + +# Constrained decoding parameters +guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema()) + +# Sampling parameters +sampling_params = SamplingParams( + top_p=0.95, + max_tokens=6400, + guided_decoding=guided_decoding_params, +) + +# Load model +llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto") + +outputs = llm.generate( + prompts="生成一个JSON,描述一本中国的著作,要包含作者、标题和书籍类型。", + sampling_params=sampling_params, +) + +# Output results +for output in outputs: + print(output.outputs.text) + +``` + +输出 + +``` +{"author": "曹雪芹", "title": "红楼梦", "genre": "Historical"} +``` diff --git a/fastdeploy/config.py b/fastdeploy/config.py index c247b7880..c52b2530b 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ 
-127,12 +127,13 @@ class ModelConfig: self.redundant_experts_num = 0 self.seed = 0 self.quantization = None + self.reasoning_parser = None self.pad_token_id: int = -1 self.eos_tokens_lens: int = 2 self.lm_head_fp32: bool = False self.model_format = "auto" for key, value in args.items(): - if hasattr(self, key): + if hasattr(self, key) and value != "None": setattr(self, key, value) assert self.model != "" @@ -1249,7 +1250,8 @@ class FDConfig: self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size) if self.guided_decoding_backend == "auto": - if self.model_config.enable_mm: + if current_platform.is_xpu() or self.speculative_config.method is not None: + logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.") self.guided_decoding_backend = "off" else: self.guided_decoding_backend = "xgrammar" @@ -1319,12 +1321,10 @@ class FDConfig: ], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}." if self.guided_decoding_backend != "off": - # TODO: mm support guided_decoding - assert ( - self.model_config.enable_mm is False - ), "Multimodal model currently do not support guided_decoding" - # TODO: speculative decoding support guided_decoding + assert ( + self.speculative_config.method is None + ), "speculative decoding currently do not support guided_decoding" # TODO: xpu support guided_decoding assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding" @@ -1335,6 +1335,7 @@ class FDConfig: raise Exception( f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}" ) + if self.scheduler_config is not None: self.scheduler_config.check() diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index c508f4ee5..a71a2df61 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -178,6 +178,22 @@ class LLMEngine: # _insert_task_to_worker moved to CommonEngine + def _has_guided_input(self, request): + """ + Check if the request has any guided input. + """ + return any( + x is not None + for x in ( + request.guided_json, + request.guided_regex, + request.guided_choice, + request.structural_tag, + request.guided_grammar, + request.guided_json_object, + ) + ) + def add_requests(self, task, sampling_params=None, **kwargs): """ Add a new request to the queue. @@ -249,8 +265,15 @@ class LLMEngine: llm_logger.error(error_msg) raise EngineError(error_msg, error_code=400) - if self.engine.guided_decoding_checker is not None: - request, err_msg = self.engine.guided_decoding_checker.schema_format(request) + if self._has_guided_input(request): + err_msg = None + if self.guided_decoding_checker is None: + err_msg = ( + "guided_backend is None, use --guided-decoding-backend to specify the backend at server startup." 
+ ) + else: + request, err_msg = self.guided_decoding_checker.schema_format(request) + if err_msg is not None: llm_logger.error(err_msg) raise EngineError(err_msg, error_code=400) @@ -469,6 +492,7 @@ class LLMEngine: f" --guided_decoding_backend {self.cfg.guided_decoding_backend}" f" --load_strategy {self.cfg.load_config.load_strategy}" f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" + f" --reasoning_parser {self.cfg.reasoning_parser}" f" --load_choices {self.cfg.load_config.load_choices}" f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'" f" --ips {ips}" diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index 04a2276af..c1431c42f 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -263,13 +263,11 @@ class Request: setattr(self, key, value) def __repr__(self) -> str: - return ( - f"Request(request_id={self.request_id}, " - f"prompt={self.prompt!r}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"draft_token_ids={self.draft_token_ids}, " - f"sampling_params={self.sampling_params})" - ) + non_none_fields = [] + for attr, value in vars(self).items(): + if value is not None and not attr.startswith("_"): + non_none_fields.append(f"{attr}={value!r}") + return f"Request({', '.join(non_none_fields)})" @dataclass(slots=True) diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index 423434857..73231b3be 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -100,6 +100,7 @@ class SamplingParams: temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False bad_words: Optional[List[str]] = None + guided_decoding: Optional[GuidedDecodingParams] = None bad_words_token_ids: Optional[List[int]] = None @classmethod @@ -132,6 +133,7 @@ class SamplingParams: min_tokens=1, logprobs=None, bad_words=None, + guided_decoding=None, bad_words_token_ids=None, ) -> SamplingParams: """Create instance from command line arguments""" @@ -153,6 +155,7 @@ class SamplingParams: min_tokens=min_tokens, logprobs=logprobs, bad_words=bad_words, + guided_decoding=guided_decoding, bad_words_token_ids=bad_words_token_ids, ) @@ -217,3 +220,51 @@ class BeamSearchParams: temperature: float = 0.0 length_penalty: float = 1.0 include_stop_str_in_output: bool = False + + +@dataclass +class GuidedDecodingParams: + """Guided decoding parameters for text generation.""" + + json: Optional[Union[str, dict]] = None + regex: Optional[str] = None + choice: Optional[List[str]] = None + grammar: Optional[str] = None + json_object: Optional[bool] = None + structural_tag: Optional[str] = None + + def to_dict(self): + """convert to dict""" + key_dict = { + "guided_json": self.json, + "guided_regex": self.regex, + "guided_choice": self.choice, + "guided_grammar": self.grammar, + "structural_tag": self.structural_tag, + "guided_json_object": self.json_object, + } + + guided_dict = {} + for key, value in key_dict.items(): + if value is not None: + guided_dict[key] = value + return guided_dict + + def __post_init__(self): + """Verify the arguments.""" + guided_count = sum( + [ + self.json is not None, + self.regex is not None, + self.choice is not None, + self.grammar is not None, + self.json_object is not None, + self.structural_tag is not None, + ] + ) + + if guided_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')." 
+ ) diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index f9537e557..d69068b6f 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -295,6 +295,9 @@ class LLM: current_sampling_params = sampling_params[i] else: current_sampling_params = sampling_params + if current_sampling_params.guided_decoding is not None: + guided_decoding_dict = current_sampling_params.guided_decoding.to_dict() + tasks.update(guided_decoding_dict) self.llm_engine.add_requests(tasks, current_sampling_params, **kwargs) return req_ids diff --git a/fastdeploy/model_executor/guided_decoding/__init__.py b/fastdeploy/model_executor/guided_decoding/__init__.py index d6ee61199..9336f4a04 100644 --- a/fastdeploy/model_executor/guided_decoding/__init__.py +++ b/fastdeploy/model_executor/guided_decoding/__init__.py @@ -15,8 +15,13 @@ """ # from fastdeploy.config import FDConfig +from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( + BackendBase, + BaseChecker, + LogitsProcessorBase, +) -__all__ = ["get_guided_backend", "schema_checker"] +__all__ = ["get_guided_backend", "schema_checker", "LogitsProcessorBase", "BackendBase", "BaseChecker"] def get_guided_backend( diff --git a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py index ea18fbe8b..b9a879e32 100644 --- a/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py +++ b/fastdeploy/model_executor/guided_decoding/base_guided_decoding.py @@ -20,6 +20,7 @@ from concurrent.futures import ThreadPoolExecutor from fastdeploy.config import ErnieArchitectures, FDConfig from fastdeploy.engine.request import Request +from fastdeploy.reasoning import ReasoningParserManager from fastdeploy.utils import llm_logger @@ -35,8 +36,9 @@ class LogitsProcessorBase: None (all state should be managed by subclasses) """ - def __init__(self): - pass + def __init__(self, enable_reasoning): + self.reasoning_ended = False + self.enable_reasoning = enable_reasoning def fill_token_bitmask(self, token_bitmask, idx): """ @@ -137,8 +139,14 @@ class BackendBase: self.fd_config = fd_config self.executor = ThreadPoolExecutor() self.max_cache_size = 2048 + self.reasoning_parser = None self.hf_tokenizer = self._get_tokenizer_hf() + if self.fd_config.model_config.reasoning_parser: + reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser( + self.fd_config.model_config.reasoning_parser + ) + self.reasoning_parser = reasoning_parser_obj(self.hf_tokenizer) def _create_processor(self): """ @@ -149,70 +157,88 @@ class BackendBase: """ raise NotImplementedError - def _json_processor(self, schemata): + def _json_processor(self, schemata, enable_thinking=False): """ Process JSON schemata. Args: schemata (str): The schemata string. + enable_thinking (bool): Whether to enable thinking mode. Raises: NotImplementedError: This method should be implemented in subclasses. """ raise NotImplementedError - def _regex_processor(self, schemata): + def _regex_processor(self, schemata, enable_thinking=False): """ Process regular expression schemata. Args: schemata (str): The schemata string. + enable_thinking (bool): Whether to enable thinking mode. Raises: NotImplementedError: This method should be implemented in subclasses. """ raise NotImplementedError - def _grammar_processor(self, schemata): + def _grammar_processor(self, schemata, enable_thinking=False): """ Process grammar schemata. Args: schemata (str): The schemata string. 
+ enable_thinking (bool): Whether to enable thinking mode. Raises: NotImplementedError: This method should be implemented in subclasses. """ raise NotImplementedError - def _structural_tag_processor(self, schemata): + def _structural_tag_processor(self, schemata, enable_thinking=False): """ Process structural tag schemata. Args: schemata (str): The schemata string. + enable_thinking (bool): Whether to enable thinking mode. Raises: NotImplementedError: This method should be implemented in subclasses. """ raise NotImplementedError - def _unsupported_processor_type(self, key_type, schemata): + def _unsupported_processor_type(self, key_type, schemata, enable_thinking=False): """ Process unsupported type. Args: key_type (str): The key type string. schemata (str): The schemata string. + enable_thinking (bool): Whether to enable thinking mode. """ raise Exception(f"Unsupported processor type {key_type}.") - def _init_logits_processor(self, schemata_key: tuple[str, str]) -> LogitsProcessorBase: + def get_reasoning_parser(self): + """ + Get reasoning parser object. + Returns: + ReasoningParser: Reasoning parser object or None + """ + return self.reasoning_parser + + def _init_logits_processor( + self, + schemata_key: tuple[str, str], + enable_thinking: bool = False, + ) -> LogitsProcessorBase: """ init logits processor by type and schemata. Args: schemata_key (tuple[str, str]): Tuple containing processor type and schema string + enable_thinking (bool): Whether to enable thinking step Returns: LogitsProcessorBase: Initialized logits processor instance @@ -222,18 +248,22 @@ class BackendBase: """ key_type, schemata = schemata_key if key_type == "json": - return self._json_processor(schemata) + return self._json_processor(schemata, enable_thinking) elif key_type == "regex": - return self._regex_processor(schemata) + return self._regex_processor(schemata, enable_thinking) elif key_type == "grammar": - return self._grammar_processor(schemata) + return self._grammar_processor(schemata, enable_thinking) elif key_type == "structural_tag": - return self._structural_tag_processor(schemata) + return self._structural_tag_processor(schemata, enable_thinking) else: llm_logger.error(f"Unsupported processor type {key_type}.") return None - def get_logits_processor(self, schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]: + def get_logits_processor( + self, + schemata_key: tuple[str, str], + enable_thinking: bool = False, + ) -> tuple[LogitsProcessorBase, bool]: """ get logits processor by key from cache or create new one. 
@@ -247,8 +277,10 @@ class BackendBase: """ value = self.cache.get(schemata_key, None) if value: - return value.copy(), True - value = self.executor.submit(self._init_logits_processor, schemata_key) + value_copy = value.copy() + value_copy.enable_reasoning = enable_thinking + return value_copy, True + value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking) return value, False def _get_tokenizer_hf(self): @@ -267,7 +299,6 @@ class BackendBase: try: architectures = self.fd_config.model_config.architectures if not ErnieArchitectures.contains_ernie_arch(architectures): - from transformers import AutoTokenizer, PreTrainedTokenizerFast tokenizer = AutoTokenizer.from_pretrained( diff --git a/fastdeploy/model_executor/guided_decoding/kernels/xgrammar_apply_token_bitmask.py b/fastdeploy/model_executor/guided_decoding/kernels/xgrammar_apply_token_bitmask.py new file mode 100644 index 000000000..f0ba737a6 --- /dev/null +++ b/fastdeploy/model_executor/guided_decoding/kernels/xgrammar_apply_token_bitmask.py @@ -0,0 +1,118 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +# refer to https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py + +from typing import List, Optional + +import paddle + +try: + import triton + import triton.language as tl +except ImportError as err: + raise ImportError("Triton is not installed") from err + + +@triton.jit +def apply_token_bitmask_inplace_kernel( + logits_ptr, + bitmask_ptr, + indices_ptr, + num_rows, + vocab_size, + logits_strides, + bitmask_strides, + NUM_SMS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Triton kernel for in-place logits masking using bitwise compression. + + Processes logits tensor in blocks, applying bitmask to restrict vocabulary access. + Masked positions are set to -inf to ensure zero probability during sampling. 
+ + Note: + - Bitmask uses 32:1 compression (1 bit per vocabulary token) + - Optimized for GPU parallel processing with configurable block size + """ + pid = tl.program_id(0) + num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE) + for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS): + row_id = work_id // num_blocks + block_offset = (work_id % num_blocks) * BLOCK_SIZE + batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id) + offsets = block_offset + tl.arange(0, BLOCK_SIZE) + bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32) + vocab_mask = offsets < vocab_size + packed_bitmask_mask = bitmask_offsets < bitmask_strides + packed_bitmask = tl.load(bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, packed_bitmask_mask) + bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0 + bitmask = bitmask.reshape(BLOCK_SIZE) + + tl.store(logits_ptr + batch_id * logits_strides + offsets, -float("inf"), vocab_mask & bitmask) + + +def apply_token_bitmask_inplace_triton( + logits: paddle.Tensor, + bitmask: paddle.Tensor, + vocab_size: Optional[int] = None, + indices: Optional[List[int]] = None, +): + """Applies vocabulary mask to logits tensor using Triton GPU kernel. + + Args: + logits: Input logits tensor of shape [batch_size, vocab_size] + bitmask: Compressed mask tensor (int32) where each bit represents a token + vocab_size: Optional explicit vocabulary size (defaults to auto-detected) + indices: Optional list of batch indices to apply mask to + + Note: + Requires CUDA GPU with Triton support + Bitmask must be int32 tensor with shape [batch_size, ceil(vocab_size/32)] + """ + NUM_SMS = paddle.device.cuda.get_device_properties().multi_processor_count + BLOCK_SIZE = 4096 + + assert bitmask.dtype == paddle.int32, "bitmask must be of type int32" + + detected_vocab_size = min(logits.shape[-1], bitmask.shape[-1] * 32) + if vocab_size is None: + vocab_size = detected_vocab_size + else: + assert ( + vocab_size <= detected_vocab_size + ), f"vocab_size {vocab_size} is larger than the detected vocab_size {detected_vocab_size}" + + num_rows = len(indices) if indices is not None else logits.shape[0] if logits.ndim == 2 else 1 + + if indices is not None: + indices = paddle.to_tensor(indices, dtype=paddle.int32, place=logits.place) + + grid = (NUM_SMS,) + + apply_token_bitmask_inplace_kernel[grid]( + logits, + bitmask, + indices, + num_rows, + vocab_size, + logits.shape[-1], + bitmask.shape[-1], + NUM_SMS, + BLOCK_SIZE, + num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()), + num_stages=3, + ) diff --git a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py index 0d448d429..d32d57f3c 100644 --- a/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py +++ b/fastdeploy/model_executor/guided_decoding/xgrammar_backend.py @@ -24,7 +24,7 @@ import torch from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request -from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( +from fastdeploy.model_executor.guided_decoding import ( BackendBase, BaseChecker, LogitsProcessorBase, @@ -57,7 +57,6 @@ class XGrammarProcessor(LogitsProcessorBase): max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch vocab_size (int): Size of the vocabulary batch_size (int): Batch size for processing - splitwise_role (str): Role for splitwise processing compiled_grammar (CompiledGrammar): Compiled grammar rules 
terminate_without_stop_token (bool): Whether to terminate without stop token override_stop_tokens (Optional[List[int]]): Custom stop tokens @@ -71,13 +70,12 @@ class XGrammarProcessor(LogitsProcessorBase): override_stop_tokens: Optional[List[int]] = None, vocab_size: Optional[int] = None, batch_size: Optional[int] = None, - splitwise_role: str = "mixed", + enable_thinking: bool = False, ): - super().__init__() + super().__init__(enable_reasoning=enable_thinking) self.max_rollback_tokens = 200 self.vocab_size = vocab_size self.batch_size = batch_size - self.splitwise_role = splitwise_role self.compiled_grammar = compiled_grammar self.terminate_without_stop_token = terminate_without_stop_token self.override_stop_tokens = override_stop_tokens @@ -188,7 +186,6 @@ class XGrammarProcessor(LogitsProcessorBase): override_stop_tokens=self.override_stop_tokens, vocab_size=self.vocab_size, batch_size=self.batch_size, - splitwise_role=self.splitwise_role, ) @@ -203,7 +200,6 @@ class XGrammarBackend(BackendBase): vocab_size (int): Size of the vocabulary from config batch_size (int): Maximum batch size from config any_whitespace (bool): Whether to allow any whitespace in JSON - splitwise_role (str): Role for splitwise processing grammar_compiler (GrammarCompiler): Grammar compilation engine """ @@ -217,7 +213,6 @@ class XGrammarBackend(BackendBase): self.batch_size = fd_config.parallel_config.max_num_seqs self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace - self.splitwise_role = fd_config.parallel_config.splitwise_role try: tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size) @@ -230,6 +225,7 @@ class XGrammarBackend(BackendBase): compiled_grammar: CompiledGrammar, terminate_without_stop_token: bool = False, override_stop_tokens: Optional[List[int]] = None, + enable_thinking: bool = False, ) -> XGrammarProcessor: """ Create a logits processor instance for the given compiled grammar. @@ -238,6 +234,7 @@ class XGrammarBackend(BackendBase): compiled_grammar (CompiledGrammar): Compiled grammar rules terminate_without_stop_token (bool): Whether to terminate without stop token override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults + enable_thinking (bool): Whether to enable thinking mode Returns: XGrammarProcessor: Configured grammar processor instance @@ -248,15 +245,16 @@ class XGrammarBackend(BackendBase): override_stop_tokens=override_stop_tokens, vocab_size=self.vocab_size, batch_size=self.batch_size, - splitwise_role=self.splitwise_role, + enable_thinking=enable_thinking, ) - def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]: + def _json_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]: """ Compile JSON schema into a grammar processor. 
Args: schemata (str): JSON schema string to compile + enable_thinking (bool): Whether to enable thinking mode Returns: Optional[XGrammarProcessor]: Configured processor if successful, None on failure @@ -266,14 +264,15 @@ class XGrammarBackend(BackendBase): except Exception as e: llm_logger.error(f"Failed to compile json schema: {e}, {str(traceback.format_exc())}") return None - return self._create_processor(compiled_grammar) + return self._create_processor(compiled_grammar, enable_thinking=enable_thinking) - def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]: + def _regex_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]: """ Compile regex pattern into a grammar processor. Args: schemata (str): Regex pattern string to compile + enable_thinking (bool): Whether to enable thinking mode Returns: Optional[XGrammarProcessor]: Configured processor if successful, None on failure @@ -283,14 +282,15 @@ class XGrammarBackend(BackendBase): except Exception as e: llm_logger.error(f"Failed to compile regex schema: {e}, {str(traceback.format_exc())}") return None - return self._create_processor(compiled_grammar) + return self._create_processor(compiled_grammar, enable_thinking=enable_thinking) - def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]: + def _grammar_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]: """ Compile grammar (EBNF) into a grammar processor. Args: schemata (str): Grammar string in EBNF format + enable_thinking (bool): Whether to enable thinking mode Returns: Optional[XGrammarProcessor]: Configured processor if successful, None on failure @@ -300,9 +300,9 @@ class XGrammarBackend(BackendBase): except Exception as e: llm_logger.error(f"Failed to compile ebnf schema: {e}, {str(traceback.format_exc())}") return None - return self._create_processor(compiled_grammar) + return self._create_processor(compiled_grammar, enable_thinking=enable_thinking) - def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]: + def _structural_tag_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]: """ Compile structural tags into a grammar processor. 
@@ -327,7 +327,7 @@ class XGrammarBackend(BackendBase): except Exception as e: llm_logger.error(f"Failed to compile structural tags schema: {e}, {str(traceback.format_exc())}") return None - return self._create_processor(compiled_grammar) + return self._create_processor(compiled_grammar, enable_thinking=enable_thinking) class XGrammarChecker(BaseChecker): diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 5aecfa1f9..f8fd1755a 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -23,9 +23,7 @@ import paddle.nn.functional as F from paddle import nn from fastdeploy.config import FDConfig -from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( - LogitsProcessorBase, -) +from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase from fastdeploy.model_executor.layers.sample.early_stopper import ( get_early_stopper_cls_from_stragegy, ) @@ -37,6 +35,7 @@ from fastdeploy.model_executor.layers.sample.ops import ( top_k_top_p_sampling, ) from fastdeploy.platforms import current_platform +from fastdeploy.reasoning import ReasoningParser from fastdeploy.worker.output import LogprobsTensors, SamplerOutput @@ -63,6 +62,10 @@ class SamplerProcessor: self.logits_processor: Dict[int, Optional[Any]] = dict() self.executor = ThreadPoolExecutor() self.logits_lock = threading.Lock() + self.reasoning_parser = None + + def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): + self.reasoning_parser = reasoning_parser def add_logits_processor( self, @@ -139,9 +142,14 @@ class SamplerProcessor: if available_processors is None: return logits - indices = list(self.logits_processor.keys()) - mask_idx = [i for i in indices if i not in skip_idx_list] - return available_processors.apply_token_mask(logits, self.token_bitmask, indices=mask_idx) + indices = [] + for idx, processor in self.logits_processor.items(): + if processor is None or idx in skip_idx_list: + continue + if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended: + indices.append(idx) + + return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices) def _accept_token(self, idx: int, token: int): """accept token""" @@ -151,6 +159,15 @@ class SamplerProcessor: if self.logits_processor[idx].is_terminated(): return + if ( + self.reasoning_parser is not None + and self.logits_processor[idx].enable_reasoning + and not self.logits_processor[idx].reasoning_ended + ): + reasoning_ended = self.reasoning_parser.is_reasoning_end([token]) + self.logits_processor[idx].reasoning_ended = reasoning_ended + return + self.logits_processor[idx].accept_token(token) def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): @@ -206,12 +223,11 @@ class Sampler(nn.Layer): self.early_stopper = early_stopper_cls() self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config) - def apply_logits_processor( - self, - ids: int, - future: Optional[Any] = None, - prefill_tokens: List[int] = [], - ): + def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): + """set reasoning parser""" + self.processor.apply_reasoning_parser(reasoning_parser) + + def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []): """apply logits processor to sampler""" 
self.processor.add_logits_processor(ids, future, prefill_tokens) @@ -219,6 +235,10 @@ class Sampler(nn.Layer): """pre process before running""" self.processor.pre_process(skip_idx_list) + def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): + """post process after running""" + self.processor.update_output_tokens(next_tokens, skip_idx_list) + def compute_logprobs( self, logits: paddle.Tensor, @@ -307,12 +327,12 @@ class Sampler(nn.Layer): skip_idx_list: List[int] = [], ) -> SamplerOutput: """ """ + logits = self.processor.apply_token_mask(logits, skip_idx_list) + num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: raw_logprobs = self.compute_logprobs(logits, sampling_metadata) - logits = self.processor.apply_token_mask(logits, skip_idx_list) - logits = apply_penalty_multi_scores( sampling_metadata.pre_token_ids, sampling_metadata.prompt_ids, @@ -347,8 +367,6 @@ class Sampler(nn.Layer): assert sampling_metadata.stop_flags is not None, "need stop_flags for eary stop" self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags) - self.processor.update_output_tokens(next_tokens, skip_idx_list) - sampler_output = SamplerOutput( # The sampled tokens are expanded to 2D tensor with shape # [num_requests, 1], where each row represents one generated @@ -380,12 +398,15 @@ class SpeculativeSampler(nn.Layer): """pre process before running""" pass - def apply_logits_processor( - self, - ids: int, - future: Optional[Any] = None, - prefill_tokens: List[int] = [], - ): + def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): + """set reasoning parser""" + pass + + def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): + """post process after running""" + pass + + def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []): """apply logits processor to sampler""" pass @@ -480,6 +501,14 @@ class MTPSampler(nn.Layer): """apply logits processor to sampler""" pass + def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None): + """set reasoning parser""" + pass + + def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []): + """post process after running""" + pass + def forward_cuda( self, logits: paddle.Tensor, diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 6b042fbbe..42848c380 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -29,9 +29,9 @@ from fastdeploy.model_executor.graph_optimization.utils import ( profile_run_guard, sot_warmup_guard, ) -from fastdeploy.model_executor.guided_decoding import get_guided_backend -from fastdeploy.model_executor.guided_decoding.base_guided_decoding import ( +from fastdeploy.model_executor.guided_decoding import ( LogitsProcessorBase, + get_guided_backend, ) from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( @@ -97,10 +97,6 @@ class GPUModelRunner(ModelRunnerBase): self.enable_logprob = fd_config.model_config.enable_logprob self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop - self.guided_backend = None - if self.fd_config.parallel_config.guided_decoding_backend != "off": - self.guided_backend = get_guided_backend(fd_config=self.fd_config) - # VL model config: if self.enable_mm: if "ernie" in self.fd_config.model_config.model_type: @@ 
-129,6 +125,11 @@ class GPUModelRunner(ModelRunnerBase): else: self.sampler = SpeculativeSampler(fd_config) + self.guided_backend = None + if self.fd_config.parallel_config.guided_decoding_backend != "off": + self.guided_backend = get_guided_backend(fd_config=self.fd_config) + self.sampler.set_reasoning_parser(self.guided_backend.get_reasoning_parser()) + # Lazy initialize kv cache after model loading # self.kv_caches: list[paddle.Tensor] = [] @@ -206,7 +207,16 @@ class GPUModelRunner(ModelRunnerBase): elif request.structural_tag is not None: schemata_key = ("structural_tag", request.structural_tag) - return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key + enable_thinking = request.get("enable_thinking", True) + enable_thinking = enable_thinking if enable_thinking is not None else True + + return ( + self.guided_backend.get_logits_processor( + schemata_key=schemata_key, + enable_thinking=enable_thinking, + ), + schemata_key, + ) def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None): """ @@ -1336,10 +1346,10 @@ class GPUModelRunner(ModelRunnerBase): Returns: A list of indices corresponding to the requests that need to be skipped. """ - skip_idx_list = [] - if not self.cache_config.enable_chunked_prefill or self.guided_backend is None: - return skip_idx_list + if not self.cache_config.enable_chunked_prefill or self.guided_backend is None or model_forward_batch is None: + return [] + skip_idx_list = [] for task in model_forward_batch: if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info): continue @@ -1505,6 +1515,8 @@ class GPUModelRunner(ModelRunnerBase): speculative_decoding=self.speculative_decoding, skip_save_output=skip_save_output, ) + if self.guided_backend is not None and sampler_output is not None: + self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list) # 6. Speculative decode if self.speculative_decoding: @@ -1538,7 +1550,7 @@ class GPUModelRunner(ModelRunnerBase): """ Add cache for guided decoding. 
""" - if self.guided_backend is None: + if self.guided_backend is None or model_forward_batch is None: return for request in model_forward_batch: diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index a57195391..28b883662 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -590,6 +590,12 @@ def parse_args(): action="store_true", help="Enable output of token-level log probabilities.", ) + parser.add_argument( + "--reasoning_parser", + type=str, + default=None, + help="Flag specifies the reasoning parser to use for extracting reasoning content from the model output", + ) parser.add_argument( "--early_stop_config", type=json.loads, diff --git a/scripts/run_pre_ce.sh b/scripts/run_pre_ce.sh index 67b06736e..ab36dac96 100644 --- a/scripts/run_pre_ce.sh +++ b/scripts/run_pre_ce.sh @@ -7,6 +7,7 @@ python -m pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/p python -m pip install -r requirements.txt python -m pip install jsonschema aistudio_sdk==0.3.5 +python -m pip install xgrammar==0.1.19 torch==2.6.0 failed_files=() run_path="$DIR/../tests/ci_use/" diff --git a/tests/ci_use/EB_Lite/test_EB_Lite_serving.py b/tests/ci_use/EB_Lite/test_EB_Lite_serving.py index a5688c866..8cba35eb1 100644 --- a/tests/ci_use/EB_Lite/test_EB_Lite_serving.py +++ b/tests/ci_use/EB_Lite/test_EB_Lite_serving.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import re import shutil @@ -110,6 +111,8 @@ def setup_and_run_server(): "--use-cudagraph", "--graph-optimization-config", '{"cudagraph_capture_sizes": [1]}', + "--guided-decoding-backend", + "auto", ] # Start subprocess in new process group @@ -1142,3 +1145,336 @@ def test_profile_reset_block_num(): f"Reset total_block_num {actual_value} 与 baseline {baseline} diff需要在5%以内" f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]" ) + + +def streaming_chat_base(openai_client, chat_param): + """ + Test streaming chat base functionality with the local service + """ + assert isinstance(chat_param, dict), f"{chat_param} should be a dict" + assert "messages" in chat_param, f"{chat_param} should contain messages" + + response = openai_client.chat.completions.create( + model="default", + stream=True, + **chat_param, + ) + + output = [] + for chunk in response: + if hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"): + output.append(chunk.choices[0].delta.content) + assert len(output) > 2 + return "".join(output) + + +def non_streaming_chat_base(openai_client, chat_param): + """ + Test non streaming chat base functionality with the local service + """ + assert isinstance(chat_param, dict), f"{chat_param} should be a dict" + assert "messages" in chat_param, f"{chat_param} should contain messages" + + response = openai_client.chat.completions.create( + model="default", + stream=False, + **chat_param, + ) + + assert hasattr(response, "choices") + assert len(response.choices) > 0 + assert hasattr(response.choices[0], "message") + assert hasattr(response.choices[0].message, "content") + return response.choices[0].message.content + + +def test_structured_outputs_json_schema(openai_client): + """ + Test structured outputs json_schema functionality with the local service + """ + chat_param = { + "temperature": 1, + "max_tokens": 1024, + } + + # json_object + json_chat_param = { + "messages": [ + { + "role": "user", + "content": "Generate a JSON object containing: 
names of China's Four Great Inventions, their dynasties of origin, and brief descriptions (each under 50 characters)", + } + ], + "response_format": {"type": "json_object"}, + } + json_chat_param.update(chat_param) + + response = streaming_chat_base(openai_client, json_chat_param) + try: + json.loads(response) + is_valid = True + except ValueError: + is_valid = False + + assert is_valid, f"json_schema streaming response: {response} is not a valid json" + + response = non_streaming_chat_base(openai_client, json_chat_param) + try: + json.loads(response) + is_valid = True + except ValueError: + is_valid = False + + assert is_valid, f"json_schema non_streaming response: {response} is not a valid json" + + # json_schema + from enum import Enum + + from pydantic import BaseModel + + class BookType(str, Enum): + romance = "Romance" + historical = "Historical" + adventure = "Adventure" + mystery = "Mystery" + dystopian = "Dystopian" + + class BookDescription(BaseModel): + author: str + title: str + genre: BookType + + json_schema_param = { + "messages": [ + { + "role": "user", + "content": "Generate a JSON describing a literary work, including author, title and book type.", + } + ], + "response_format": { + "type": "json_schema", + "json_schema": {"name": "book-description", "schema": BookDescription.model_json_schema()}, + }, + } + json_schema_param.update(chat_param) + response = streaming_chat_base(openai_client, json_schema_param) + try: + json_schema_response = json.loads(response) + is_valid = True + except ValueError: + is_valid = False + + assert is_valid, f"json_schema streaming response: {response} is not a valid json" + assert ( + "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response + ), f"json_schema streaming response: {response} is not a valid book-description" + assert json_schema_response["genre"] in { + genre.value for genre in BookType + }, f"json_schema streaming response: {json_schema_response['genre']} is not a valid book-type" + + response = non_streaming_chat_base(openai_client, json_schema_param) + try: + json_schema_response = json.loads(response) + is_valid = True + except ValueError: + is_valid = False + + assert is_valid, f"json_schema non_streaming response: {response} is not a valid json" + assert ( + "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response + ), f"json_schema non_streaming response: {response} is not a valid book-description" + assert json_schema_response["genre"] in { + genre.value for genre in BookType + }, f"json_schema non_streaming response: {json_schema_response['genre']} is not a valid book-type" + + +def test_structured_outputs_structural_tag(openai_client): + """ + Test structured outputs structural_tag functionality with the local service + """ + content_str = """ + You have the following function available: + + { + "name": "get_current_date", + "description": "Get current date and time for given timezone", + "parameters": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "Timezone to get current date/time, e.g.: Asia/Shanghai", + } + }, + "required": ["timezone"], + } + } + + If you choose to call only this function, reply in this format: + <{start_tag}={function_name}>{parameters}{end_tag} + where: + + start_tag => ` JSON dictionary with parameter names as keys + end_tag => `` + + Example: + {"param": "value"} + + Note: + - Function call must follow specified format + - Required parameters must be 
specified + - Only one function can be called at a time + - Place entire function call response on a single line + + You are an AI assistant. Answer the following question. + """ + + structural_tag_param = { + "temperature": 1, + "max_tokens": 1024, + "messages": [ + { + "role": "system", + "content": content_str, + }, + { + "role": "user", + "content": "You're traveling to Shanghai today", + }, + ], + "response_format": { + "type": "structural_tag", + "structures": [ + { + "begin": "", + "schema": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "Timezone to get current date/time, e.g.: Asia/Shanghai", + } + }, + "required": ["timezone"], + }, + "end": "", + } + ], + "triggers": ["" text "" + + style_attribute ::= " style=" dq style_value dq + + style_value ::= (font_style ("; " font_weight)?) | (font_weight ("; " font_style)?) + + font_style ::= "font-family: '" font_name "'" + + font_weight ::= "font-weight: " weight_value + + font_name ::= "Arial" | "Times New Roman" | "Courier New" + + weight_value ::= "normal" | "bold" + + text ::= [A-Za-z0-9 ]+ + + dq ::= ["] + """ + + grammar_param = { + "temperature": 1, + "max_tokens": 1024, + "messages": [ + { + "role": "user", + "content": "Generate HTML code for this heading in bold Times New Roman font: ERNIE Bot", + } + ], + "extra_body": {"guided_grammar": html_h1_grammar}, + } + + import re + + pattern = r'^[A-Za-z0-9 ]+$' + response = streaming_chat_base(openai_client, grammar_param) + assert re.fullmatch(pattern, response), f"grammar streaming response: {response} is not as expected" + response = non_streaming_chat_base(openai_client, grammar_param) + assert re.fullmatch(pattern, response), f"grammar non_streaming response: {response} is not as expected" diff --git a/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py b/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py index 6eb78345d..fed911861 100644 --- a/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py +++ b/tests/ci_use/EB_VL_Lite/test_EB_VL_Lite_serving.py @@ -119,6 +119,8 @@ def setup_and_run_server(): "wint4", "--reasoning-parser", "ernie-45-vl", + "--guided-decoding-backend", + "auto", ] # Start subprocess in new process group @@ -540,6 +542,348 @@ def test_chat_with_thinking(openai_client, capsys): assert reasoning_tokens <= reasoning_max_tokens +def streaming_chat_base(openai_client, chat_param): + """ + Test streaming chat base functionality with the local service + """ + assert isinstance(chat_param, dict), f"{chat_param} should be a dict" + assert "messages" in chat_param, f"{chat_param} should contain messages" + + response = openai_client.chat.completions.create( + model="default", + stream=True, + **chat_param, + ) + + output = [] + for chunk in response: + if hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"): + output.append(chunk.choices[0].delta.content) + assert len(output) > 2 + return "".join(output) + + +def non_streaming_chat_base(openai_client, chat_param): + """ + Test non streaming chat base functionality with the local service + """ + assert isinstance(chat_param, dict), f"{chat_param} should be a dict" + assert "messages" in chat_param, f"{chat_param} should contain messages" + + response = openai_client.chat.completions.create( + model="default", + stream=False, + **chat_param, + ) + + assert hasattr(response, "choices") + assert len(response.choices) > 0 + assert hasattr(response.choices[0], "message") + assert hasattr(response.choices[0].message, "content") + return 
response.choices[0].message.content + + +def test_structured_outputs_json_schema(openai_client): + """ + Test structured outputs json_schema functionality with the local service + """ + chat_param = { + "temperature": 1, + "max_tokens": 1024, + } + + # json_object + json_chat_param = { + "messages": [ + {"role": "system", "content": "You are a helpful AI assistant."}, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg", + "detail": "high", + }, + }, + {"type": "text", "text": "请描述图片内容,使用json格式输出结果"}, + ], + }, + ], + "response_format": {"type": "json_object"}, + } + json_chat_param.update(chat_param) + + outputs = [] + outputs.append(streaming_chat_base(openai_client, json_chat_param)) + outputs.append(non_streaming_chat_base(openai_client, json_chat_param)) + + json_chat_param["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}} + outputs.append(streaming_chat_base(openai_client, json_chat_param)) + outputs.append(non_streaming_chat_base(openai_client, json_chat_param)) + + for response in outputs: + try: + json.loads(response) + is_valid = True + except ValueError: + is_valid = False + + assert is_valid, f"json_object response: {response} is not a valid json" + + # json_schema + from enum import Enum + + from pydantic import BaseModel + + class BookType(str, Enum): + romance = "Romance" + historical = "Historical" + adventure = "Adventure" + mystery = "Mystery" + dystopian = "Dystopian" + + class BookDescription(BaseModel): + author: str + title: str + genre: BookType + + json_schema_param = { + "messages": [ + { + "role": "user", + "content": "Generate a JSON describing a literary work, including author, title and book type.", + } + ], + "response_format": { + "type": "json_schema", + "json_schema": {"name": "book-description", "schema": BookDescription.model_json_schema()}, + }, + } + json_schema_param.update(chat_param) + response = streaming_chat_base(openai_client, json_schema_param) + try: + json_schema_response = json.loads(response) + is_valid = True + except ValueError: + is_valid = False + + assert is_valid, f"json_schema streaming response: {response} is not a valid json" + assert ( + "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response + ), f"json_schema streaming response: {response} is not a valid book-description" + assert json_schema_response["genre"] in { + genre.value for genre in BookType + }, f"json_schema streaming response: {json_schema_response['genre']} is not a valid book-type" + + response = non_streaming_chat_base(openai_client, json_schema_param) + try: + json_schema_response = json.loads(response) + is_valid = True + except ValueError: + is_valid = False + + assert is_valid, f"json_schema non_streaming response: {response} is not a valid json" + assert ( + "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response + ), f"json_schema non_streaming response: {response} is not a valid book-description" + assert json_schema_response["genre"] in { + genre.value for genre in BookType + }, f"json_schema non_streaming response: {json_schema_response['genre']} is not a valid book-type" + + +def test_structured_outputs_structural_tag(openai_client): + """ + Test structured outputs structural_tag functionality with the local service + """ + content_str = """ + You have the following function available: + + { + "name": "get_current_date", + 
"description": "Get current date and time for given timezone", + "parameters": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "Timezone to get current date/time, e.g.: Asia/Shanghai", + } + }, + "required": ["timezone"], + } + } + + If you choose to call only this function, reply in this format: + <{start_tag}={function_name}>{parameters}{end_tag} + where: + + start_tag => ` JSON dictionary with parameter names as keys + end_tag => `` + + Example: + {"param": "value"} + + Note: + - Function call must follow specified format + - Required parameters must be specified + - Only one function can be called at a time + - Place entire function call response on a single line + + You are an AI assistant. Answer the following question. + """ + + structural_tag_param = { + "temperature": 1, + "max_tokens": 1024, + "messages": [ + { + "role": "system", + "content": content_str, + }, + { + "role": "user", + "content": "You're traveling to Shanghai today", + }, + ], + "response_format": { + "type": "structural_tag", + "structures": [ + { + "begin": "", + "schema": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "Timezone to get current date/time, e.g.: Asia/Shanghai", + } + }, + "required": ["timezone"], + }, + "end": "", + } + ], + "triggers": ["" text "" + + style_attribute ::= " style=" dq style_value dq + + style_value ::= (font_style ("; " font_weight)?) | (font_weight ("; " font_style)?) + + font_style ::= "font-family: '" font_name "'" + + font_weight ::= "font-weight: " weight_value + + font_name ::= "Arial" | "Times New Roman" | "Courier New" + + weight_value ::= "normal" | "bold" + + text ::= [A-Za-z0-9 ]+ + + dq ::= ["] + """ + + grammar_param = { + "temperature": 1, + "max_tokens": 1024, + "messages": [ + { + "role": "user", + "content": "Generate HTML code for this heading in bold Times New Roman font: ERNIE Bot", + } + ], + "extra_body": {"guided_grammar": html_h1_grammar}, + } + + import re + + pattern = r'^[A-Za-z0-9 ]+$' + response = streaming_chat_base(openai_client, grammar_param) + assert re.fullmatch(pattern, response), f"grammar streaming response: {response} is not as expected" + response = non_streaming_chat_base(openai_client, grammar_param) + assert re.fullmatch(pattern, response), f"grammar non_streaming response: {response} is not as expected" + + def test_profile_reset_block_num(): """测试profile reset_block_num功能,与baseline diff不能超过5%""" log_file = "./log/config.log"