Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 00:33:03 +08:00)
[Feature] mm and thinking model support structured output (#2749)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* mm support structured output
* update code
* update code
* update format
* update code
* update code
* add enable_thinking default
* update code
* add structured_outputs test case
* add ci install xgrammar
* add ci timeout time
* update test for structured_outputs
* update code
* add error traceback info
* update error msg
* update structured output code
* update code
* update code
* update config
* update torch version

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -330,3 +330,65 @@ ParsedChatCompletionMessage[Info](content='{"addr": "No.1 Century Avenue, Pudong
 Address: No.1 Century Avenue, Pudong New Area, Shanghai
 Height: 468
 ```
+
+### Offline Inference
+
+Offline inference allows restricting the model's output format by pre-specified constraints. In `FastDeploy`, constraints can be specified through the `GuidedDecodingParams` class in `SamplingParams`. `GuidedDecodingParams` supports the following constraint types, with usage similar to online inference:
+
+```python
+json: Optional[Union[str, dict]] = None
+regex: Optional[str] = None
+choice: Optional[List[str]] = None
+grammar: Optional[str] = None
+json_object: Optional[bool] = None
+structural_tag: Optional[str] = None
+```
+
+The following example demonstrates how to use offline inference to generate a structured json:
+
+```python
+from fastdeploy import LLM, SamplingParams
+from fastdeploy.engine.sampling_params import GuidedDecodingParams
+from pydantic import BaseModel
+from enum import Enum
+
+class BookType(str, Enum):
+    romance = "Romance"
+    historical = "Historical"
+    adventure = "Adventure"
+    mystery = "Mystery"
+    dystopian = "Dystopian"
+
+class BookDescription(BaseModel):
+    author: str
+    title: str
+    genre: BookType
+
+# Constrained decoding parameters
+guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema())
+
+# Sampling parameters
+sampling_params = SamplingParams(
+    top_p=0.95,
+    max_tokens=6400,
+    guided_decoding=guided_decoding_params,
+)
+
+# Load model
+llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")
+
+outputs = llm.generate(
+    prompts="Generate a JSON describing a literary work, including author, title and book type.",
+    sampling_params=sampling_params,
+)
+
+# Output results
+for output in outputs:
+    print(output.outputs.text)
+```
+
+Output:
+
+```
+{"author": "George Orwell", "title": "1984", "genre": "Dystopian"}
+```
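For reference, the other constraint types listed above follow the same offline flow. The sketch below swaps the JSON schema for a `choice` constraint; it is a minimal illustration that assumes the same local `ERNIE-4.5-0.3B` checkpoint as the example above, and the prompt and label strings are made up for the example.

```python
# Minimal sketch: constrain offline generation to a fixed set of strings.
from fastdeploy import LLM, SamplingParams
from fastdeploy.engine.sampling_params import GuidedDecodingParams

# Restrict the output to one of three labels.
guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative", "Neutral"])

sampling_params = SamplingParams(
    top_p=0.95,
    max_tokens=64,
    guided_decoding=guided_decoding_params,
)

# Model path and backend mirror the JSON example above; adjust to your local setup.
llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")

outputs = llm.generate(
    prompts="Classify the sentiment of this review: 'The battery life is excellent.'",
    sampling_params=sampling_params,
)

for output in outputs:
    print(output.outputs.text)  # e.g. "Positive"
```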
@@ -330,3 +330,67 @@ ParsedChatCompletionMessage[Info](content='{"addr": "上海市浦东新区世纪
 地址: 上海市浦东新区世纪大道1号
 高度: 468
 ```
+
+### Offline Inference
+
+Offline inference allows restricting the model's output format through pre-specified constraints. In `FastDeploy`, constraints can be specified through the `GuidedDecodingParams` class in `SamplingParams`. `GuidedDecodingParams` supports the following constraint types; usage can follow the online inference examples:
+
+```python
+json: Optional[Union[str, dict]] = None
+regex: Optional[str] = None
+choice: Optional[List[str]] = None
+grammar: Optional[str] = None
+json_object: Optional[bool] = None
+structural_tag: Optional[str] = None
+```
+
+The following example demonstrates how to use offline inference to generate structured JSON:
+
+```python
+
+from fastdeploy import LLM, SamplingParams
+from fastdeploy.engine.sampling_params import GuidedDecodingParams
+from pydantic import BaseModel
+from enum import Enum
+
+class BookType(str, Enum):
+    romance = "Romance"
+    historical = "Historical"
+    adventure = "Adventure"
+    mystery = "Mystery"
+    dystopian = "Dystopian"
+
+class BookDescription(BaseModel):
+    author: str
+    title: str
+    genre: BookType
+
+# Constrained decoding parameters
+guided_decoding_params = GuidedDecodingParams(json=BookDescription.model_json_schema())
+
+# Sampling parameters
+sampling_params = SamplingParams(
+    top_p=0.95,
+    max_tokens=6400,
+    guided_decoding=guided_decoding_params,
+)
+
+# Load model
+llm = LLM(model="ERNIE-4.5-0.3B", tensor_parallel_size=1, max_model_len=8192, guided_decoding_backend="auto")
+
+outputs = llm.generate(
+    prompts="生成一个JSON,描述一本中国的著作,要包含作者、标题和书籍类型。",
+    sampling_params=sampling_params,
+)
+
+# Output results
+for output in outputs:
+    print(output.outputs.text)
+
+```
+
+Output:
+
+```
+{"author": "曹雪芹", "title": "红楼梦", "genre": "Historical"}
+```
@@ -127,12 +127,13 @@ class ModelConfig:
         self.redundant_experts_num = 0
         self.seed = 0
         self.quantization = None
+        self.reasoning_parser = None
         self.pad_token_id: int = -1
         self.eos_tokens_lens: int = 2
         self.lm_head_fp32: bool = False
         self.model_format = "auto"
         for key, value in args.items():
-            if hasattr(self, key):
+            if hasattr(self, key) and value != "None":
                 setattr(self, key, value)
 
         assert self.model != ""
@@ -1249,7 +1250,8 @@ class FDConfig:
         self.cache_config.max_block_num_per_seq = int(self.max_model_len // self.cache_config.block_size)
 
         if self.guided_decoding_backend == "auto":
-            if self.model_config.enable_mm:
+            if current_platform.is_xpu() or self.speculative_config.method is not None:
+                logger.warning("Speculative Decoding and XPU currently do not support Guided decoding, set off.")
                 self.guided_decoding_backend = "off"
             else:
                 self.guided_decoding_backend = "xgrammar"
@@ -1319,12 +1321,10 @@ class FDConfig:
         ], f"Only support xgrammar、auto guided decoding backend, but got {self.guided_decoding_backend}."
 
         if self.guided_decoding_backend != "off":
-            # TODO: mm support guided_decoding
-            assert (
-                self.model_config.enable_mm is False
-            ), "Multimodal model currently do not support guided_decoding"
-
             # TODO: speculative decoding support guided_decoding
+            assert (
+                self.speculative_config.method is None
+            ), "speculative decoding currently do not support guided_decoding"
 
             # TODO: xpu support guided_decoding
             assert not current_platform.is_xpu(), "XPU currently do not support guided_decoding"
@@ -1335,6 +1335,7 @@ class FDConfig:
                 raise Exception(
                     f"import XGrammar failed, please install XGrammar use `pip install xgrammar==0.1.19`. \n\t {e}"
                 )
 
         if self.scheduler_config is not None:
             self.scheduler_config.check()
+
@@ -178,6 +178,22 @@ class LLMEngine:
 
     # _insert_task_to_worker moved to CommonEngine
 
+    def _has_guided_input(self, request):
+        """
+        Check if the request has any guided input.
+        """
+        return any(
+            x is not None
+            for x in (
+                request.guided_json,
+                request.guided_regex,
+                request.guided_choice,
+                request.structural_tag,
+                request.guided_grammar,
+                request.guided_json_object,
+            )
+        )
+
     def add_requests(self, task, sampling_params=None, **kwargs):
         """
         Add a new request to the queue.
@@ -249,8 +265,15 @@ class LLMEngine:
             llm_logger.error(error_msg)
             raise EngineError(error_msg, error_code=400)
 
-        if self.engine.guided_decoding_checker is not None:
-            request, err_msg = self.engine.guided_decoding_checker.schema_format(request)
+        if self._has_guided_input(request):
+            err_msg = None
+            if self.guided_decoding_checker is None:
+                err_msg = (
+                    "guided_backend is None, use --guided-decoding-backend to specify the backend at server startup."
+                )
+            else:
+                request, err_msg = self.guided_decoding_checker.schema_format(request)
+
             if err_msg is not None:
                 llm_logger.error(err_msg)
                 raise EngineError(err_msg, error_code=400)
@@ -469,6 +492,7 @@ class LLMEngine:
             f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
             f" --load_strategy {self.cfg.load_config.load_strategy}"
             f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
+            f" --reasoning_parser {self.cfg.reasoning_parser}"
            f" --load_choices {self.cfg.load_config.load_choices}"
             f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
             f" --ips {ips}"
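The hunk above routes a request through the schema checker only when it actually carries a guided input, and turns a missing backend into a clear 400 error. Below is a minimal, self-contained sketch of that gate; `FakeRequest` and `has_guided_input` are illustrative stand-ins for the engine's `Request` object and `LLMEngine._has_guided_input`, not FastDeploy code.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeRequest:
    # Same guided_* fields the engine inspects on a request.
    guided_json: Optional[dict] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[list] = None
    structural_tag: Optional[str] = None
    guided_grammar: Optional[str] = None
    guided_json_object: Optional[bool] = None


def has_guided_input(request) -> bool:
    # Same any(...) test as LLMEngine._has_guided_input in the hunk above.
    return any(
        x is not None
        for x in (
            request.guided_json,
            request.guided_regex,
            request.guided_choice,
            request.structural_tag,
            request.guided_grammar,
            request.guided_json_object,
        )
    )


plain = FakeRequest()
guided = FakeRequest(guided_json={"type": "object"})
print(has_guided_input(plain))   # False -> schema checker is skipped entirely
print(has_guided_input(guided))  # True  -> checker must exist, otherwise a 400 EngineError is raised
```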
@@ -263,13 +263,11 @@ class Request:
                 setattr(self, key, value)
 
     def __repr__(self) -> str:
-        return (
-            f"Request(request_id={self.request_id}, "
-            f"prompt={self.prompt!r}, "
-            f"prompt_token_ids={self.prompt_token_ids}, "
-            f"draft_token_ids={self.draft_token_ids}, "
-            f"sampling_params={self.sampling_params})"
-        )
+        non_none_fields = []
+        for attr, value in vars(self).items():
+            if value is not None and not attr.startswith("_"):
+                non_none_fields.append(f"{attr}={value!r}")
+        return f"Request({', '.join(non_none_fields)})"
 
 
 @dataclass(slots=True)
@@ -100,6 +100,7 @@ class SamplingParams:
     temp_scaled_logprobs: bool = False
     top_p_normalized_logprobs: bool = False
     bad_words: Optional[List[str]] = None
+    guided_decoding: Optional[GuidedDecodingParams] = None
     bad_words_token_ids: Optional[List[int]] = None
 
     @classmethod
@@ -132,6 +133,7 @@ class SamplingParams:
         min_tokens=1,
         logprobs=None,
         bad_words=None,
+        guided_decoding=None,
         bad_words_token_ids=None,
     ) -> SamplingParams:
         """Create instance from command line arguments"""
@@ -153,6 +155,7 @@ class SamplingParams:
             min_tokens=min_tokens,
             logprobs=logprobs,
             bad_words=bad_words,
+            guided_decoding=guided_decoding,
             bad_words_token_ids=bad_words_token_ids,
         )
 
@@ -217,3 +220,51 @@ class BeamSearchParams:
     temperature: float = 0.0
     length_penalty: float = 1.0
     include_stop_str_in_output: bool = False
+
+
+@dataclass
+class GuidedDecodingParams:
+    """Guided decoding parameters for text generation."""
+
+    json: Optional[Union[str, dict]] = None
+    regex: Optional[str] = None
+    choice: Optional[List[str]] = None
+    grammar: Optional[str] = None
+    json_object: Optional[bool] = None
+    structural_tag: Optional[str] = None
+
+    def to_dict(self):
+        """convert to dict"""
+        key_dict = {
+            "guided_json": self.json,
+            "guided_regex": self.regex,
+            "guided_choice": self.choice,
+            "guided_grammar": self.grammar,
+            "structural_tag": self.structural_tag,
+            "guided_json_object": self.json_object,
+        }
+
+        guided_dict = {}
+        for key, value in key_dict.items():
+            if value is not None:
+                guided_dict[key] = value
+        return guided_dict
+
+    def __post_init__(self):
+        """Verify the arguments."""
+        guided_count = sum(
+            [
+                self.json is not None,
+                self.regex is not None,
+                self.choice is not None,
+                self.grammar is not None,
+                self.json_object is not None,
+                self.structural_tag is not None,
+            ]
+        )
+
+        if guided_count > 1:
+            raise ValueError(
+                "You can only use one kind of guided decoding "
+                "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
+            )
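To make the new dataclass concrete, the short sketch below exercises only the `to_dict` renaming and the `__post_init__` mutual-exclusion check added in this hunk; the regex and schema values are arbitrary examples.

```python
from fastdeploy.engine.sampling_params import GuidedDecodingParams

# Exactly one constraint may be set; to_dict() renames it to the engine-side key.
params = GuidedDecodingParams(regex=r"\d{4}-\d{2}-\d{2}")
print(params.to_dict())  # {'guided_regex': '\\d{4}-\\d{2}-\\d{2}'}

# Setting two constraints at once trips the __post_init__ check.
try:
    GuidedDecodingParams(json={"type": "object"}, choice=["a", "b"])
except ValueError as e:
    print(e)  # "You can only use one kind of guided decoding ..."
```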
@@ -295,6 +295,9 @@ class LLM:
                 current_sampling_params = sampling_params[i]
             else:
                 current_sampling_params = sampling_params
+            if current_sampling_params.guided_decoding is not None:
+                guided_decoding_dict = current_sampling_params.guided_decoding.to_dict()
+                tasks.update(guided_decoding_dict)
             self.llm_engine.add_requests(tasks, current_sampling_params, **kwargs)
         return req_ids
 
@@ -15,8 +15,13 @@
 """
 
 # from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+    BackendBase,
+    BaseChecker,
+    LogitsProcessorBase,
+)
 
-__all__ = ["get_guided_backend", "schema_checker"]
+__all__ = ["get_guided_backend", "schema_checker", "LogitsProcessorBase", "BackendBase", "BaseChecker"]
 
 
 def get_guided_backend(
@@ -20,6 +20,7 @@ from concurrent.futures import ThreadPoolExecutor
 
 from fastdeploy.config import ErnieArchitectures, FDConfig
 from fastdeploy.engine.request import Request
+from fastdeploy.reasoning import ReasoningParserManager
 from fastdeploy.utils import llm_logger
 
 
@@ -35,8 +36,9 @@ class LogitsProcessorBase:
         None (all state should be managed by subclasses)
     """
 
-    def __init__(self):
-        pass
+    def __init__(self, enable_reasoning):
+        self.reasoning_ended = False
+        self.enable_reasoning = enable_reasoning
 
     def fill_token_bitmask(self, token_bitmask, idx):
         """
@@ -137,8 +139,14 @@ class BackendBase:
         self.fd_config = fd_config
         self.executor = ThreadPoolExecutor()
         self.max_cache_size = 2048
+        self.reasoning_parser = None
 
         self.hf_tokenizer = self._get_tokenizer_hf()
+        if self.fd_config.model_config.reasoning_parser:
+            reasoning_parser_obj = ReasoningParserManager.get_reasoning_parser(
+                self.fd_config.model_config.reasoning_parser
+            )
+            self.reasoning_parser = reasoning_parser_obj(self.hf_tokenizer)
 
     def _create_processor(self):
         """
@@ -149,70 +157,88 @@ class BackendBase:
         """
         raise NotImplementedError
 
-    def _json_processor(self, schemata):
+    def _json_processor(self, schemata, enable_thinking=False):
         """
         Process JSON schemata.
 
         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.
 
         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError
 
-    def _regex_processor(self, schemata):
+    def _regex_processor(self, schemata, enable_thinking=False):
         """
         Process regular expression schemata.
 
         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.
 
         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError
 
-    def _grammar_processor(self, schemata):
+    def _grammar_processor(self, schemata, enable_thinking=False):
         """
         Process grammar schemata.
 
         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.
 
         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError
 
-    def _structural_tag_processor(self, schemata):
+    def _structural_tag_processor(self, schemata, enable_thinking=False):
         """
         Process structural tag schemata.
 
         Args:
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.
 
         Raises:
             NotImplementedError: This method should be implemented in subclasses.
         """
         raise NotImplementedError
 
-    def _unsupported_processor_type(self, key_type, schemata):
+    def _unsupported_processor_type(self, key_type, schemata, enable_thinking=False):
         """
         Process unsupported type.
 
         Args:
             key_type (str): The key type string.
             schemata (str): The schemata string.
+            enable_thinking (bool): Whether to enable thinking mode.
         """
         raise Exception(f"Unsupported processor type {key_type}.")
 
-    def _init_logits_processor(self, schemata_key: tuple[str, str]) -> LogitsProcessorBase:
+    def get_reasoning_parser(self):
+        """
+        Get reasoning parser object.
+        Returns:
+            ReasoningParser: Reasoning parser object or None
+        """
+        return self.reasoning_parser
+
+    def _init_logits_processor(
+        self,
+        schemata_key: tuple[str, str],
+        enable_thinking: bool = False,
+    ) -> LogitsProcessorBase:
         """
         init logits processor by type and schemata.
 
         Args:
             schemata_key (tuple[str, str]): Tuple containing processor type and schema string
+            enable_thinking (bool): Whether to enable thinking step
 
         Returns:
             LogitsProcessorBase: Initialized logits processor instance
@@ -222,18 +248,22 @@ class BackendBase:
         """
         key_type, schemata = schemata_key
         if key_type == "json":
-            return self._json_processor(schemata)
+            return self._json_processor(schemata, enable_thinking)
         elif key_type == "regex":
-            return self._regex_processor(schemata)
+            return self._regex_processor(schemata, enable_thinking)
         elif key_type == "grammar":
-            return self._grammar_processor(schemata)
+            return self._grammar_processor(schemata, enable_thinking)
         elif key_type == "structural_tag":
-            return self._structural_tag_processor(schemata)
+            return self._structural_tag_processor(schemata, enable_thinking)
         else:
             llm_logger.error(f"Unsupported processor type {key_type}.")
             return None
 
-    def get_logits_processor(self, schemata_key: tuple[str, str]) -> tuple[LogitsProcessorBase, bool]:
+    def get_logits_processor(
+        self,
+        schemata_key: tuple[str, str],
+        enable_thinking: bool = False,
+    ) -> tuple[LogitsProcessorBase, bool]:
         """
         get logits processor by key from cache or create new one.
 
@@ -247,8 +277,10 @@ class BackendBase:
         """
         value = self.cache.get(schemata_key, None)
         if value:
-            return value.copy(), True
-        value = self.executor.submit(self._init_logits_processor, schemata_key)
+            value_copy = value.copy()
+            value_copy.enable_reasoning = enable_thinking
+            return value_copy, True
+        value = self.executor.submit(self._init_logits_processor, schemata_key, enable_thinking)
         return value, False
 
     def _get_tokenizer_hf(self):
@@ -267,7 +299,6 @@ class BackendBase:
         try:
             architectures = self.fd_config.model_config.architectures
             if not ErnieArchitectures.contains_ernie_arch(architectures):
-
                 from transformers import AutoTokenizer, PreTrainedTokenizerFast
 
                 tokenizer = AutoTokenizer.from_pretrained(
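A simplified sketch of the caching behaviour that `get_logits_processor` implements above: on a cache hit the compiled processor is copied and its per-request `enable_reasoning` flag is re-armed from the incoming `enable_thinking` value. `FakeProcessor` and the synchronous cache fill are illustrative simplifications (the real backend compiles schemas asynchronously via a thread pool), not FastDeploy code.

```python
import copy


class FakeProcessor:
    # Stand-in for a compiled grammar processor: per-request reasoning state plus copy().
    def __init__(self, enable_reasoning):
        self.enable_reasoning = enable_reasoning
        self.reasoning_ended = False

    def copy(self):
        return copy.copy(self)


cache = {}


def get_logits_processor(schemata_key, enable_thinking=False):
    # Mirrors the cache-hit path above: reuse the compiled processor,
    # but re-arm the per-request enable_reasoning flag.
    value = cache.get(schemata_key)
    if value:
        value_copy = value.copy()
        value_copy.enable_reasoning = enable_thinking
        return value_copy, True
    processor = FakeProcessor(enable_thinking)  # real code compiles the schema asynchronously
    cache[schemata_key] = processor
    return processor, False


p1, hit1 = get_logits_processor(("json", '{"type": "object"}'), enable_thinking=True)
p2, hit2 = get_logits_processor(("json", '{"type": "object"}'), enable_thinking=False)
print(hit1, hit2)                                # False True
print(p1.enable_reasoning, p2.enable_reasoning)  # True False
```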
@@ -0,0 +1,118 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+# refer to https://github.com/mlc-ai/xgrammar/blob/main/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py
+
+from typing import List, Optional
+
+import paddle
+
+try:
+    import triton
+    import triton.language as tl
+except ImportError as err:
+    raise ImportError("Triton is not installed") from err
+
+
+@triton.jit
+def apply_token_bitmask_inplace_kernel(
+    logits_ptr,
+    bitmask_ptr,
+    indices_ptr,
+    num_rows,
+    vocab_size,
+    logits_strides,
+    bitmask_strides,
+    NUM_SMS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Triton kernel for in-place logits masking using bitwise compression.
+
+    Processes logits tensor in blocks, applying bitmask to restrict vocabulary access.
+    Masked positions are set to -inf to ensure zero probability during sampling.
+
+    Note:
+        - Bitmask uses 32:1 compression (1 bit per vocabulary token)
+        - Optimized for GPU parallel processing with configurable block size
+    """
+    pid = tl.program_id(0)
+    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
+    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
+        row_id = work_id // num_blocks
+        block_offset = (work_id % num_blocks) * BLOCK_SIZE
+        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
+        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
+        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
+        vocab_mask = offsets < vocab_size
+        packed_bitmask_mask = bitmask_offsets < bitmask_strides
+        packed_bitmask = tl.load(bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets, packed_bitmask_mask)
+        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
+        bitmask = bitmask.reshape(BLOCK_SIZE)
+
+        tl.store(logits_ptr + batch_id * logits_strides + offsets, -float("inf"), vocab_mask & bitmask)
+
+
+def apply_token_bitmask_inplace_triton(
+    logits: paddle.Tensor,
+    bitmask: paddle.Tensor,
+    vocab_size: Optional[int] = None,
+    indices: Optional[List[int]] = None,
+):
+    """Applies vocabulary mask to logits tensor using Triton GPU kernel.
+
+    Args:
+        logits: Input logits tensor of shape [batch_size, vocab_size]
+        bitmask: Compressed mask tensor (int32) where each bit represents a token
+        vocab_size: Optional explicit vocabulary size (defaults to auto-detected)
+        indices: Optional list of batch indices to apply mask to
+
+    Note:
+        Requires CUDA GPU with Triton support
+        Bitmask must be int32 tensor with shape [batch_size, ceil(vocab_size/32)]
+    """
+    NUM_SMS = paddle.device.cuda.get_device_properties().multi_processor_count
+    BLOCK_SIZE = 4096
+
+    assert bitmask.dtype == paddle.int32, "bitmask must be of type int32"
+
+    detected_vocab_size = min(logits.shape[-1], bitmask.shape[-1] * 32)
+    if vocab_size is None:
+        vocab_size = detected_vocab_size
+    else:
+        assert (
+            vocab_size <= detected_vocab_size
+        ), f"vocab_size {vocab_size} is larger than the detected vocab_size {detected_vocab_size}"
+
+    num_rows = len(indices) if indices is not None else logits.shape[0] if logits.ndim == 2 else 1
+
+    if indices is not None:
+        indices = paddle.to_tensor(indices, dtype=paddle.int32, place=logits.place)
+
+    grid = (NUM_SMS,)
+
+    apply_token_bitmask_inplace_kernel[grid](
+        logits,
+        bitmask,
+        indices,
+        num_rows,
+        vocab_size,
+        logits.shape[-1],
+        bitmask.shape[-1],
+        NUM_SMS,
+        BLOCK_SIZE,
+        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
+        num_stages=3,
+    )
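For orientation, the bitmask convention the kernel above relies on can be stated in a few lines of plain Python: token id `i` is controlled by bit `i % 32` of `bitmask[i // 32]`, and a cleared bit forces that logit to `-inf`. The reference below is illustrative only; the real path runs the Triton kernel on GPU tensors.

```python
import math


def apply_token_bitmask_reference(logits_row, packed_bitmask):
    """Pure-Python reference of the 32:1 packed bitmask semantics (illustrative only)."""
    vocab_size = len(logits_row)
    assert len(packed_bitmask) >= math.ceil(vocab_size / 32)
    out = list(logits_row)
    for token_id in range(vocab_size):
        word = packed_bitmask[token_id // 32]
        bit = (word >> (token_id % 32)) & 1
        if bit == 0:  # same condition under which the kernel stores -inf
            out[token_id] = float("-inf")
    return out


# Allow only token ids 0 and 2 in an 8-token vocabulary: bits 0 and 2 set -> 0b101 = 5.
print(apply_token_bitmask_reference([0.1] * 8, [5]))
```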
@@ -24,7 +24,7 @@ import torch
 
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request
-from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+from fastdeploy.model_executor.guided_decoding import (
     BackendBase,
     BaseChecker,
     LogitsProcessorBase,
@@ -57,7 +57,6 @@ class XGrammarProcessor(LogitsProcessorBase):
         max_rollback_tokens (int): Maximum number of tokens to rollback on mismatch
         vocab_size (int): Size of the vocabulary
         batch_size (int): Batch size for processing
-        splitwise_role (str): Role for splitwise processing
         compiled_grammar (CompiledGrammar): Compiled grammar rules
         terminate_without_stop_token (bool): Whether to terminate without stop token
         override_stop_tokens (Optional[List[int]]): Custom stop tokens
@@ -71,13 +70,12 @@ class XGrammarProcessor(LogitsProcessorBase):
         override_stop_tokens: Optional[List[int]] = None,
         vocab_size: Optional[int] = None,
         batch_size: Optional[int] = None,
-        splitwise_role: str = "mixed",
+        enable_thinking: bool = False,
     ):
-        super().__init__()
+        super().__init__(enable_reasoning=enable_thinking)
         self.max_rollback_tokens = 200
         self.vocab_size = vocab_size
         self.batch_size = batch_size
-        self.splitwise_role = splitwise_role
         self.compiled_grammar = compiled_grammar
         self.terminate_without_stop_token = terminate_without_stop_token
         self.override_stop_tokens = override_stop_tokens
@@ -188,7 +186,6 @@ class XGrammarProcessor(LogitsProcessorBase):
             override_stop_tokens=self.override_stop_tokens,
             vocab_size=self.vocab_size,
             batch_size=self.batch_size,
-            splitwise_role=self.splitwise_role,
         )
 
 
@@ -203,7 +200,6 @@ class XGrammarBackend(BackendBase):
         vocab_size (int): Size of the vocabulary from config
         batch_size (int): Maximum batch size from config
         any_whitespace (bool): Whether to allow any whitespace in JSON
-        splitwise_role (str): Role for splitwise processing
         grammar_compiler (GrammarCompiler): Grammar compilation engine
     """
 
@@ -217,7 +213,6 @@ class XGrammarBackend(BackendBase):
         self.batch_size = fd_config.parallel_config.max_num_seqs
 
         self.any_whitespace = not fd_config.parallel_config.disable_any_whitespace
-        self.splitwise_role = fd_config.parallel_config.splitwise_role
 
         try:
             tokenizer_info = TokenizerInfo.from_huggingface(self.hf_tokenizer, vocab_size=self.vocab_size)
@@ -230,6 +225,7 @@ class XGrammarBackend(BackendBase):
         compiled_grammar: CompiledGrammar,
         terminate_without_stop_token: bool = False,
         override_stop_tokens: Optional[List[int]] = None,
+        enable_thinking: bool = False,
     ) -> XGrammarProcessor:
         """
         Create a logits processor instance for the given compiled grammar.
@@ -238,6 +234,7 @@ class XGrammarBackend(BackendBase):
             compiled_grammar (CompiledGrammar): Compiled grammar rules
             terminate_without_stop_token (bool): Whether to terminate without stop token
             override_stop_tokens (Optional[List[int]]): Custom stop tokens to override defaults
+            enable_thinking (bool): Whether to enable thinking mode
 
         Returns:
             XGrammarProcessor: Configured grammar processor instance
@@ -248,15 +245,16 @@ class XGrammarBackend(BackendBase):
             override_stop_tokens=override_stop_tokens,
             vocab_size=self.vocab_size,
             batch_size=self.batch_size,
-            splitwise_role=self.splitwise_role,
+            enable_thinking=enable_thinking,
         )
 
-    def _json_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+    def _json_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
         """
         Compile JSON schema into a grammar processor.
 
         Args:
             schemata (str): JSON schema string to compile
+            enable_thinking (bool): Whether to enable thinking mode
 
         Returns:
             Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -266,14 +264,15 @@ class XGrammarBackend(BackendBase):
         except Exception as e:
             llm_logger.error(f"Failed to compile json schema: {e}, {str(traceback.format_exc())}")
             return None
-        return self._create_processor(compiled_grammar)
+        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
 
-    def _regex_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+    def _regex_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
         """
         Compile regex pattern into a grammar processor.
 
         Args:
             schemata (str): Regex pattern string to compile
+            enable_thinking (bool): Whether to enable thinking mode
 
         Returns:
             Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -283,14 +282,15 @@ class XGrammarBackend(BackendBase):
         except Exception as e:
             llm_logger.error(f"Failed to compile regex schema: {e}, {str(traceback.format_exc())}")
             return None
-        return self._create_processor(compiled_grammar)
+        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
 
-    def _grammar_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+    def _grammar_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
         """
         Compile grammar (EBNF) into a grammar processor.
 
         Args:
             schemata (str): Grammar string in EBNF format
+            enable_thinking (bool): Whether to enable thinking mode
 
         Returns:
             Optional[XGrammarProcessor]: Configured processor if successful, None on failure
@@ -300,9 +300,9 @@ class XGrammarBackend(BackendBase):
         except Exception as e:
             llm_logger.error(f"Failed to compile ebnf schema: {e}, {str(traceback.format_exc())}")
             return None
-        return self._create_processor(compiled_grammar)
+        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
 
-    def _structural_tag_processor(self, schemata: str) -> Optional[XGrammarProcessor]:
+    def _structural_tag_processor(self, schemata: str, enable_thinking: bool = False) -> Optional[XGrammarProcessor]:
         """
         Compile structural tags into a grammar processor.
 
@@ -327,7 +327,7 @@ class XGrammarBackend(BackendBase):
         except Exception as e:
             llm_logger.error(f"Failed to compile structural tags schema: {e}, {str(traceback.format_exc())}")
             return None
-        return self._create_processor(compiled_grammar)
+        return self._create_processor(compiled_grammar, enable_thinking=enable_thinking)
 
 
 class XGrammarChecker(BaseChecker):
@@ -23,9 +23,7 @@ import paddle.nn.functional as F
 from paddle import nn
 
 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
-    LogitsProcessorBase,
-)
+from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
 from fastdeploy.model_executor.layers.sample.early_stopper import (
     get_early_stopper_cls_from_stragegy,
 )
@@ -37,6 +35,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
     top_k_top_p_sampling,
 )
 from fastdeploy.platforms import current_platform
+from fastdeploy.reasoning import ReasoningParser
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
 
 
@@ -63,6 +62,10 @@ class SamplerProcessor:
         self.logits_processor: Dict[int, Optional[Any]] = dict()
         self.executor = ThreadPoolExecutor()
         self.logits_lock = threading.Lock()
+        self.reasoning_parser = None
+
+    def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+        self.reasoning_parser = reasoning_parser
 
     def add_logits_processor(
         self,
@@ -139,9 +142,14 @@ class SamplerProcessor:
         if available_processors is None:
             return logits
 
-        indices = list(self.logits_processor.keys())
-        mask_idx = [i for i in indices if i not in skip_idx_list]
-        return available_processors.apply_token_mask(logits, self.token_bitmask, indices=mask_idx)
+        indices = []
+        for idx, processor in self.logits_processor.items():
+            if processor is None or idx in skip_idx_list:
+                continue
+            if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended:
+                indices.append(idx)
+
+        return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices)
 
     def _accept_token(self, idx: int, token: int):
         """accept token"""
@@ -151,6 +159,15 @@ class SamplerProcessor:
         if self.logits_processor[idx].is_terminated():
             return
 
+        if (
+            self.reasoning_parser is not None
+            and self.logits_processor[idx].enable_reasoning
+            and not self.logits_processor[idx].reasoning_ended
+        ):
+            reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
+            self.logits_processor[idx].reasoning_ended = reasoning_ended
+            return
+
         self.logits_processor[idx].accept_token(token)
 
     def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
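Taken together, the two hunks above gate grammar enforcement on the thinking phase: while a processor with `enable_reasoning` set has not yet seen the end of the reasoning segment, its row is left out of the token mask and accepted tokens only advance the end-of-reasoning check. The sketch below restates that control flow in isolation; `ThinkProcessor` and `END_THINK_TOKEN` are invented stand-ins, not FastDeploy classes.

```python
END_THINK_TOKEN = 42  # assumed id for the end-of-reasoning marker (e.g. "</think>")


class ThinkProcessor:
    """Invented stand-in mirroring the enable_reasoning / reasoning_ended fields."""

    def __init__(self, enable_reasoning=True):
        self.enable_reasoning = enable_reasoning
        self.reasoning_ended = False
        self.accepted = []  # tokens actually fed to the grammar matcher

    def should_apply_mask(self):
        # Mirrors apply_token_mask: no grammar mask while the request is still thinking.
        return not self.enable_reasoning or self.reasoning_ended

    def accept_token(self, token):
        # Mirrors _accept_token: during thinking, only watch for the end-of-reasoning token.
        if self.enable_reasoning and not self.reasoning_ended:
            self.reasoning_ended = token == END_THINK_TOKEN
            return
        self.accepted.append(token)


proc = ThinkProcessor()
for token in [7, 9, END_THINK_TOKEN, 11, 13]:
    print(token, "masked" if proc.should_apply_mask() else "free")
    proc.accept_token(token)
print(proc.accepted)  # [11, 13] -- only post-reasoning tokens reach the grammar state
```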
@@ -206,12 +223,11 @@ class Sampler(nn.Layer):
             self.early_stopper = early_stopper_cls()
             self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)
 
-    def apply_logits_processor(
-        self,
-        ids: int,
-        future: Optional[Any] = None,
-        prefill_tokens: List[int] = [],
-    ):
+    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+        """set reasoning parser"""
+        self.processor.apply_reasoning_parser(reasoning_parser)
+
+    def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
         """apply logits processor to sampler"""
         self.processor.add_logits_processor(ids, future, prefill_tokens)
 
@@ -219,6 +235,10 @@ class Sampler(nn.Layer):
         """pre process before running"""
         self.processor.pre_process(skip_idx_list)
 
+    def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+        """post process after running"""
+        self.processor.update_output_tokens(next_tokens, skip_idx_list)
+
     def compute_logprobs(
         self,
         logits: paddle.Tensor,
@@ -307,12 +327,12 @@ class Sampler(nn.Layer):
         skip_idx_list: List[int] = [],
     ) -> SamplerOutput:
         """ """
+        logits = self.processor.apply_token_mask(logits, skip_idx_list)
+
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
             raw_logprobs = self.compute_logprobs(logits, sampling_metadata)
 
-        logits = self.processor.apply_token_mask(logits, skip_idx_list)
-
         logits = apply_penalty_multi_scores(
             sampling_metadata.pre_token_ids,
             sampling_metadata.prompt_ids,
@@ -347,8 +367,6 @@ class Sampler(nn.Layer):
             assert sampling_metadata.stop_flags is not None, "need stop_flags for eary stop"
             self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)
 
-        self.processor.update_output_tokens(next_tokens, skip_idx_list)
-
         sampler_output = SamplerOutput(
             # The sampled tokens are expanded to 2D tensor with shape
             # [num_requests, 1], where each row represents one generated
@@ -380,12 +398,15 @@ class SpeculativeSampler(nn.Layer):
         """pre process before running"""
         pass
 
-    def apply_logits_processor(
-        self,
-        ids: int,
-        future: Optional[Any] = None,
-        prefill_tokens: List[int] = [],
-    ):
+    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+        """set reasoning parser"""
+        pass
+
+    def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+        """post process after running"""
+        pass
+
+    def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
         """apply logits processor to sampler"""
         pass
 
@@ -480,6 +501,14 @@ class MTPSampler(nn.Layer):
         """apply logits processor to sampler"""
         pass
 
+    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+        """set reasoning parser"""
+        pass
+
+    def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+        """post process after running"""
+        pass
+
     def forward_cuda(
         self,
         logits: paddle.Tensor,
@@ -29,9 +29,9 @@ from fastdeploy.model_executor.graph_optimization.utils import (
     profile_run_guard,
     sot_warmup_guard,
 )
-from fastdeploy.model_executor.guided_decoding import get_guided_backend
-from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
+from fastdeploy.model_executor.guided_decoding import (
     LogitsProcessorBase,
+    get_guided_backend,
 )
 from fastdeploy.model_executor.layers.attention import get_attention_backend
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
@@ -97,10 +97,6 @@ class GPUModelRunner(ModelRunnerBase):
         self.enable_logprob = fd_config.model_config.enable_logprob
         self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
 
-        self.guided_backend = None
-        if self.fd_config.parallel_config.guided_decoding_backend != "off":
-            self.guided_backend = get_guided_backend(fd_config=self.fd_config)
-
         # VL model config:
         if self.enable_mm:
             if "ernie" in self.fd_config.model_config.model_type:
@@ -129,6 +125,11 @@ class GPUModelRunner(ModelRunnerBase):
         else:
             self.sampler = SpeculativeSampler(fd_config)
 
+        self.guided_backend = None
+        if self.fd_config.parallel_config.guided_decoding_backend != "off":
+            self.guided_backend = get_guided_backend(fd_config=self.fd_config)
+            self.sampler.set_reasoning_parser(self.guided_backend.get_reasoning_parser())
+
         # Lazy initialize kv cache after model loading
         # self.kv_caches: list[paddle.Tensor] = []
 
@@ -206,7 +207,16 @@ class GPUModelRunner(ModelRunnerBase):
         elif request.structural_tag is not None:
             schemata_key = ("structural_tag", request.structural_tag)
 
-        return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key
+        enable_thinking = request.get("enable_thinking", True)
+        enable_thinking = enable_thinking if enable_thinking is not None else True
+
+        return (
+            self.guided_backend.get_logits_processor(
+                schemata_key=schemata_key,
+                enable_thinking=enable_thinking,
+            ),
+            schemata_key,
+        )
 
     def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
         """
@@ -1336,10 +1346,10 @@ class GPUModelRunner(ModelRunnerBase):
         Returns:
             A list of indices corresponding to the requests that need to be skipped.
         """
-        skip_idx_list = []
-        if not self.cache_config.enable_chunked_prefill or self.guided_backend is None:
-            return skip_idx_list
+        if not self.cache_config.enable_chunked_prefill or self.guided_backend is None or model_forward_batch is None:
+            return []
 
+        skip_idx_list = []
         for task in model_forward_batch:
             if task.get("prefill_chunk_info", None) is None or task.chunk_idx >= len(task.prefill_chunk_info):
                 continue
@@ -1505,6 +1515,8 @@ class GPUModelRunner(ModelRunnerBase):
                 speculative_decoding=self.speculative_decoding,
                 skip_save_output=skip_save_output,
             )
+            if self.guided_backend is not None and sampler_output is not None:
+                self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
 
             # 6. Speculative decode
             if self.speculative_decoding:
@@ -1538,7 +1550,7 @@ class GPUModelRunner(ModelRunnerBase):
         """
         Add cache for guided decoding.
         """
-        if self.guided_backend is None:
+        if self.guided_backend is None or model_forward_batch is None:
             return
 
         for request in model_forward_batch:
@@ -590,6 +590,12 @@ def parse_args():
         action="store_true",
         help="Enable output of token-level log probabilities.",
     )
+    parser.add_argument(
+        "--reasoning_parser",
+        type=str,
+        default=None,
+        help="Flag specifies the reasoning parser to use for extracting reasoning content from the model output",
+    )
     parser.add_argument(
         "--early_stop_config",
         type=json.loads,
@@ -7,6 +7,7 @@ python -m pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/p

python -m pip install -r requirements.txt
python -m pip install jsonschema aistudio_sdk==0.3.5
python -m pip install xgrammar==0.1.19 torch==2.6.0

failed_files=()
run_path="$DIR/../tests/ci_use/"
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import re
import shutil
@@ -110,6 +111,8 @@ def setup_and_run_server():
        "--use-cudagraph",
        "--graph-optimization-config",
        '{"cudagraph_capture_sizes": [1]}',
        "--guided-decoding-backend",
        "auto",
    ]

    # Start subprocess in new process group
@@ -1142,3 +1145,336 @@ def test_profile_reset_block_num():
        f"Reset total_block_num {actual_value} vs baseline {baseline}: diff must be within 5%. "
        f"Allowed range: [{lower_bound:.1f}, {upper_bound:.1f}]"
    )


def streaming_chat_base(openai_client, chat_param):
    """
    Test streaming chat base functionality with the local service
    """
    assert isinstance(chat_param, dict), f"{chat_param} should be a dict"
    assert "messages" in chat_param, f"{chat_param} should contain messages"

    response = openai_client.chat.completions.create(
        model="default",
        stream=True,
        **chat_param,
    )

    output = []
    for chunk in response:
        if hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"):
            output.append(chunk.choices[0].delta.content)
    assert len(output) > 2
    return "".join(output)


def non_streaming_chat_base(openai_client, chat_param):
    """
    Test non streaming chat base functionality with the local service
    """
    assert isinstance(chat_param, dict), f"{chat_param} should be a dict"
    assert "messages" in chat_param, f"{chat_param} should contain messages"

    response = openai_client.chat.completions.create(
        model="default",
        stream=False,
        **chat_param,
    )

    assert hasattr(response, "choices")
    assert len(response.choices) > 0
    assert hasattr(response.choices[0], "message")
    assert hasattr(response.choices[0].message, "content")
    return response.choices[0].message.content


def test_structured_outputs_json_schema(openai_client):
    """
    Test structured outputs json_schema functionality with the local service
    """
    chat_param = {
        "temperature": 1,
        "max_tokens": 1024,
    }

    # json_object
    json_chat_param = {
        "messages": [
            {
                "role": "user",
                "content": "Generate a JSON object containing: names of China's Four Great Inventions, their dynasties of origin, and brief descriptions (each under 50 characters)",
            }
        ],
        "response_format": {"type": "json_object"},
    }
    json_chat_param.update(chat_param)

    response = streaming_chat_base(openai_client, json_chat_param)
    try:
        json.loads(response)
        is_valid = True
    except ValueError:
        is_valid = False

    assert is_valid, f"json_schema streaming response: {response} is not a valid json"

    response = non_streaming_chat_base(openai_client, json_chat_param)
    try:
        json.loads(response)
        is_valid = True
    except ValueError:
        is_valid = False

    assert is_valid, f"json_schema non_streaming response: {response} is not a valid json"

    # json_schema
    from enum import Enum

    from pydantic import BaseModel

    class BookType(str, Enum):
        romance = "Romance"
        historical = "Historical"
        adventure = "Adventure"
        mystery = "Mystery"
        dystopian = "Dystopian"

    class BookDescription(BaseModel):
        author: str
        title: str
        genre: BookType

    json_schema_param = {
        "messages": [
            {
                "role": "user",
                "content": "Generate a JSON describing a literary work, including author, title and book type.",
            }
        ],
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "book-description", "schema": BookDescription.model_json_schema()},
        },
    }
    json_schema_param.update(chat_param)
    response = streaming_chat_base(openai_client, json_schema_param)
    try:
        json_schema_response = json.loads(response)
        is_valid = True
    except ValueError:
        is_valid = False

    assert is_valid, f"json_schema streaming response: {response} is not a valid json"
    assert (
        "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response
    ), f"json_schema streaming response: {response} is not a valid book-description"
    assert json_schema_response["genre"] in {
        genre.value for genre in BookType
    }, f"json_schema streaming response: {json_schema_response['genre']} is not a valid book-type"

    response = non_streaming_chat_base(openai_client, json_schema_param)
    try:
        json_schema_response = json.loads(response)
        is_valid = True
    except ValueError:
        is_valid = False

    assert is_valid, f"json_schema non_streaming response: {response} is not a valid json"
    assert (
        "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response
    ), f"json_schema non_streaming response: {response} is not a valid book-description"
    assert json_schema_response["genre"] in {
        genre.value for genre in BookType
    }, f"json_schema non_streaming response: {json_schema_response['genre']} is not a valid book-type"


def test_structured_outputs_structural_tag(openai_client):
    """
    Test structured outputs structural_tag functionality with the local service
    """
    content_str = """
You have the following function available:

{
    "name": "get_current_date",
    "description": "Get current date and time for given timezone",
    "parameters": {
        "type": "object",
        "properties": {
            "timezone": {
                "type": "string",
                "description": "Timezone to get current date/time, e.g.: Asia/Shanghai",
            }
        },
        "required": ["timezone"],
    }
}

If you choose to call only this function, reply in this format:
<{start_tag}={function_name}>{parameters}{end_tag}
where:

start_tag => `<function`
parameters => JSON dictionary with parameter names as keys
end_tag => `</function>`

Example:
<function=example_function>{"param": "value"}</function>

Note:
- Function call must follow specified format
- Required parameters must be specified
- Only one function can be called at a time
- Place entire function call response on a single line

You are an AI assistant. Answer the following question.
"""

    structural_tag_param = {
        "temperature": 1,
        "max_tokens": 1024,
        "messages": [
            {
                "role": "system",
                "content": content_str,
            },
            {
                "role": "user",
                "content": "You're traveling to Shanghai today",
            },
        ],
        "response_format": {
            "type": "structural_tag",
            "structures": [
                {
                    "begin": "<function=get_current_date>",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "timezone": {
                                "type": "string",
                                "description": "Timezone to get current date/time, e.g.: Asia/Shanghai",
                            }
                        },
                        "required": ["timezone"],
                    },
                    "end": "</function>",
                }
            ],
            "triggers": ["<function="],
        },
    }

    expect_str1 = "get_current_date"
    expect_str2 = "Asia/Shanghai"
    response = streaming_chat_base(openai_client, structural_tag_param)
    assert expect_str1 in response, f"structural_tag streaming response: {response} is not as expected"
    assert expect_str2 in response, f"structural_tag streaming response: {response} is not as expected"

    response = non_streaming_chat_base(openai_client, structural_tag_param)
    assert expect_str1 in response, f"structural_tag non_streaming response: {response} is not as expected"
    assert expect_str2 in response, f"structural_tag non_streaming response: {response} is not as expected"
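
Because the `structural_tag` constraint above forces replies into the `<function=name>{json arguments}</function>` shape, a client can recover the call with a few lines of parsing. The helper below is an illustrative sketch only; the regex and the `parse_function_call` name are not part of the test suite.

```python
import json
import re


def parse_function_call(text: str):
    """Extract (function_name, arguments) from a '<function=...>{...}</function>' reply."""
    match = re.search(r"<function=([\w.]+)>(\{.*?\})</function>", text, re.DOTALL)
    if match is None:
        return None
    return match.group(1), json.loads(match.group(2))


# Example with the reply shape the test above expects:
print(parse_function_call('<function=get_current_date>{"timezone": "Asia/Shanghai"}</function>'))
# ('get_current_date', {'timezone': 'Asia/Shanghai'})
```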

def test_structured_outputs_choice(openai_client):
    """
    Test structured outputs choice functionality with the local service
    """
    choice_param = {
        "temperature": 1,
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": "What is the landmark building in Shenzhen?"}],
        "extra_body": {
            "guided_choice": ["Ping An Finance Centre", "China Resources Headquarters", "KK100", "Diwang Mansion"]
        },
    }

    response = streaming_chat_base(openai_client, choice_param)
    assert response in [
        "Ping An Finance Centre",
        "China Resources Headquarters",
        "KK100",
        "Diwang Mansion",
    ], f"choice streaming response: {response} is not as expected"
    response = non_streaming_chat_base(openai_client, choice_param)
    assert response in [
        "Ping An Finance Centre",
        "China Resources Headquarters",
        "KK100",
        "Diwang Mansion",
    ], f"choice non_streaming response: {response} is not as expected"


def test_structured_outputs_regex(openai_client):
    """
    Test structured outputs regex functionality with the local service
    """
    regex_param = {
        "temperature": 1,
        "max_tokens": 1024,
        "messages": [
            {
                "role": "user",
                "content": "Generate a standard format web address including protocol and domain.\n",
            }
        ],
        "extra_body": {"guided_regex": r"^https:\/\/www\.[a-zA-Z]+\.com\/?$\n"},
    }

    import re

    response = streaming_chat_base(openai_client, regex_param)
    assert re.fullmatch(
        r"^https:\/\/www\.[a-zA-Z]+\.com\/?$\n", response
    ), f"regex streaming response: {response} is not as expected"
    response = non_streaming_chat_base(openai_client, regex_param)
    assert re.fullmatch(
        r"^https:\/\/www\.[a-zA-Z]+\.com\/?$\n", response
    ), f"regex non_streaming response: {response} is not as expected"


def test_structured_outputs_grammar(openai_client):
    """
    Test structured outputs grammar functionality with the local service
    """
    html_h1_grammar = """
root ::= html_statement

html_statement ::= "<h1" style_attribute? ">" text "</h1>"

style_attribute ::= " style=" dq style_value dq

style_value ::= (font_style ("; " font_weight)?) | (font_weight ("; " font_style)?)

font_style ::= "font-family: '" font_name "'"

font_weight ::= "font-weight: " weight_value

font_name ::= "Arial" | "Times New Roman" | "Courier New"

weight_value ::= "normal" | "bold"

text ::= [A-Za-z0-9 ]+

dq ::= ["]
"""

    grammar_param = {
        "temperature": 1,
        "max_tokens": 1024,
        "messages": [
            {
                "role": "user",
                "content": "Generate HTML code for this heading in bold Times New Roman font: ERNIE Bot",
            }
        ],
        "extra_body": {"guided_grammar": html_h1_grammar},
    }

    import re

    pattern = r'^<h1( style="[^"]*")?>[A-Za-z0-9 ]+</h1>$'
    response = streaming_chat_base(openai_client, grammar_param)
    assert re.fullmatch(pattern, response), f"grammar streaming response: {response} is not as expected"
    response = non_streaming_chat_base(openai_client, grammar_param)
    assert re.fullmatch(pattern, response), f"grammar non_streaming response: {response} is not as expected"
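
For reference, one output that satisfies both the EBNF grammar and the `re.fullmatch` check above is shown below; this is an illustrative expectation derived from the grammar rules, not a value asserted by the test itself.

```python
import re

# A string the grammar can derive: optional style attribute with font-family and font-weight.
sample = "<h1 style=\"font-family: 'Times New Roman'; font-weight: bold\">ERNIE Bot</h1>"
assert re.fullmatch(r'^<h1( style="[^"]*")?>[A-Za-z0-9 ]+</h1>$', sample)
```
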
@@ -119,6 +119,8 @@ def setup_and_run_server():
        "wint4",
        "--reasoning-parser",
        "ernie-45-vl",
        "--guided-decoding-backend",
        "auto",
    ]

    # Start subprocess in new process group
@@ -540,6 +542,348 @@ def test_chat_with_thinking(openai_client, capsys):
    assert reasoning_tokens <= reasoning_max_tokens


def streaming_chat_base(openai_client, chat_param):
    """
    Test streaming chat base functionality with the local service
    """
    assert isinstance(chat_param, dict), f"{chat_param} should be a dict"
    assert "messages" in chat_param, f"{chat_param} should contain messages"

    response = openai_client.chat.completions.create(
        model="default",
        stream=True,
        **chat_param,
    )

    output = []
    for chunk in response:
        if hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"):
            output.append(chunk.choices[0].delta.content)
    assert len(output) > 2
    return "".join(output)


def non_streaming_chat_base(openai_client, chat_param):
    """
    Test non streaming chat base functionality with the local service
    """
    assert isinstance(chat_param, dict), f"{chat_param} should be a dict"
    assert "messages" in chat_param, f"{chat_param} should contain messages"

    response = openai_client.chat.completions.create(
        model="default",
        stream=False,
        **chat_param,
    )

    assert hasattr(response, "choices")
    assert len(response.choices) > 0
    assert hasattr(response.choices[0], "message")
    assert hasattr(response.choices[0].message, "content")
    return response.choices[0].message.content


def test_structured_outputs_json_schema(openai_client):
    """
    Test structured outputs json_schema functionality with the local service
    """
    chat_param = {
        "temperature": 1,
        "max_tokens": 1024,
    }

    # json_object
    json_chat_param = {
        "messages": [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://paddlenlp.bj.bcebos.com/datasets/paddlemix/demo_images/example2.jpg",
                            "detail": "high",
                        },
                    },
                    {"type": "text", "text": "Describe the content of the image and output the result in JSON format"},
                ],
            },
        ],
        "response_format": {"type": "json_object"},
    }
    json_chat_param.update(chat_param)

    outputs = []
    outputs.append(streaming_chat_base(openai_client, json_chat_param))
    outputs.append(non_streaming_chat_base(openai_client, json_chat_param))

    json_chat_param["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}
    outputs.append(streaming_chat_base(openai_client, json_chat_param))
    outputs.append(non_streaming_chat_base(openai_client, json_chat_param))

    for response in outputs:
        try:
            json.loads(response)
            is_valid = True
        except ValueError:
            is_valid = False

        assert is_valid, f"json_object response: {response} is not a valid json"

    # json_schema
    from enum import Enum

    from pydantic import BaseModel

    class BookType(str, Enum):
        romance = "Romance"
        historical = "Historical"
        adventure = "Adventure"
        mystery = "Mystery"
        dystopian = "Dystopian"

    class BookDescription(BaseModel):
        author: str
        title: str
        genre: BookType

    json_schema_param = {
        "messages": [
            {
                "role": "user",
                "content": "Generate a JSON describing a literary work, including author, title and book type.",
            }
        ],
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "book-description", "schema": BookDescription.model_json_schema()},
        },
    }
    json_schema_param.update(chat_param)
    response = streaming_chat_base(openai_client, json_schema_param)
    try:
        json_schema_response = json.loads(response)
        is_valid = True
    except ValueError:
        is_valid = False

    assert is_valid, f"json_schema streaming response: {response} is not a valid json"
    assert (
        "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response
    ), f"json_schema streaming response: {response} is not a valid book-description"
    assert json_schema_response["genre"] in {
        genre.value for genre in BookType
    }, f"json_schema streaming response: {json_schema_response['genre']} is not a valid book-type"

    response = non_streaming_chat_base(openai_client, json_schema_param)
    try:
        json_schema_response = json.loads(response)
        is_valid = True
    except ValueError:
        is_valid = False

    assert is_valid, f"json_schema non_streaming response: {response} is not a valid json"
    assert (
        "author" in json_schema_response and "title" in json_schema_response and "genre" in json_schema_response
    ), f"json_schema non_streaming response: {response} is not a valid book-description"
    assert json_schema_response["genre"] in {
        genre.value for genre in BookType
    }, f"json_schema non_streaming response: {json_schema_response['genre']} is not a valid book-type"


def test_structured_outputs_structural_tag(openai_client):
    """
    Test structured outputs structural_tag functionality with the local service
    """
    content_str = """
You have the following function available:

{
    "name": "get_current_date",
    "description": "Get current date and time for given timezone",
    "parameters": {
        "type": "object",
        "properties": {
            "timezone": {
                "type": "string",
                "description": "Timezone to get current date/time, e.g.: Asia/Shanghai",
            }
        },
        "required": ["timezone"],
    }
}

If you choose to call only this function, reply in this format:
<{start_tag}={function_name}>{parameters}{end_tag}
where:

start_tag => `<function`
parameters => JSON dictionary with parameter names as keys
end_tag => `</function>`

Example:
<function=example_function>{"param": "value"}</function>

Note:
- Function call must follow specified format
- Required parameters must be specified
- Only one function can be called at a time
- Place entire function call response on a single line

You are an AI assistant. Answer the following question.
"""

    structural_tag_param = {
        "temperature": 1,
        "max_tokens": 1024,
        "messages": [
            {
                "role": "system",
                "content": content_str,
            },
            {
                "role": "user",
                "content": "You're traveling to Shanghai today",
            },
        ],
        "response_format": {
            "type": "structural_tag",
            "structures": [
                {
                    "begin": "<function=get_current_date>",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "timezone": {
                                "type": "string",
                                "description": "Timezone to get current date/time, e.g.: Asia/Shanghai",
                            }
                        },
                        "required": ["timezone"],
                    },
                    "end": "</function>",
                }
            ],
            "triggers": ["<function="],
        },
    }

    expect_str1 = "get_current_date"
    expect_str2 = "Asia/Shanghai"
    response = streaming_chat_base(openai_client, structural_tag_param)
    assert expect_str1 in response, f"structural_tag streaming response: {response} is not as expected"
    assert expect_str2 in response, f"structural_tag streaming response: {response} is not as expected"

    response = non_streaming_chat_base(openai_client, structural_tag_param)
    assert expect_str1 in response, f"structural_tag non_streaming response: {response} is not as expected"
    assert expect_str2 in response, f"structural_tag non_streaming response: {response} is not as expected"


def test_structured_outputs_choice(openai_client):
    """
    Test structured outputs choice functionality with the local service
    """
    choice_param = {
        "temperature": 1,
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": "What is the landmark building in Shenzhen?"}],
        "extra_body": {
            "guided_choice": ["Ping An Finance Centre", "China Resources Headquarters", "KK100", "Diwang Mansion"]
        },
    }

    response = streaming_chat_base(openai_client, choice_param)
    assert response in [
        "Ping An Finance Centre",
        "China Resources Headquarters",
        "KK100",
        "Diwang Mansion",
    ], f"choice streaming response: {response} is not as expected"
    response = non_streaming_chat_base(openai_client, choice_param)
    assert response in [
        "Ping An Finance Centre",
        "China Resources Headquarters",
        "KK100",
        "Diwang Mansion",
    ], f"choice non_streaming response: {response} is not as expected"


def test_structured_outputs_regex(openai_client):
    """
    Test structured outputs regex functionality with the local service
    """
    regex_param = {
        "temperature": 1,
        "max_tokens": 1024,
        "messages": [
            {
                "role": "user",
                "content": "Generate a standard format web address including protocol and domain.\n",
            }
        ],
        "extra_body": {"guided_regex": r"^https:\/\/www\.[a-zA-Z]+\.com\/?$\n"},
    }

    import re

    response = streaming_chat_base(openai_client, regex_param)
    assert re.fullmatch(
        r"^https:\/\/www\.[a-zA-Z]+\.com\/?$\n", response
    ), f"regex streaming response: {response} is not as expected"
    response = non_streaming_chat_base(openai_client, regex_param)
    assert re.fullmatch(
        r"^https:\/\/www\.[a-zA-Z]+\.com\/?$\n", response
    ), f"regex non_streaming response: {response} is not as expected"


def test_structured_outputs_grammar(openai_client):
    """
    Test structured outputs grammar functionality with the local service
    """
    html_h1_grammar = """
root ::= html_statement

html_statement ::= "<h1" style_attribute? ">" text "</h1>"

style_attribute ::= " style=" dq style_value dq

style_value ::= (font_style ("; " font_weight)?) | (font_weight ("; " font_style)?)

font_style ::= "font-family: '" font_name "'"

font_weight ::= "font-weight: " weight_value

font_name ::= "Arial" | "Times New Roman" | "Courier New"

weight_value ::= "normal" | "bold"

text ::= [A-Za-z0-9 ]+

dq ::= ["]
"""

    grammar_param = {
        "temperature": 1,
        "max_tokens": 1024,
        "messages": [
            {
                "role": "user",
                "content": "Generate HTML code for this heading in bold Times New Roman font: ERNIE Bot",
            }
        ],
        "extra_body": {"guided_grammar": html_h1_grammar},
    }

    import re

    pattern = r'^<h1( style="[^"]*")?>[A-Za-z0-9 ]+</h1>$'
    response = streaming_chat_base(openai_client, grammar_param)
    assert re.fullmatch(pattern, response), f"grammar streaming response: {response} is not as expected"
    response = non_streaming_chat_base(openai_client, grammar_param)
    assert re.fullmatch(pattern, response), f"grammar non_streaming response: {response} is not as expected"

def test_profile_reset_block_num():
    """Test the profile reset_block_num feature; the diff from baseline must stay within 5%"""
    log_file = "./log/config.log"