[Feature] mm and thinking model support structured output (#2749)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* mm support structured output

* update code

* update code

* update format

* update code

* update code

* add enable_thinking default

* update code

* add structured_outputs test case

* add ci install xgrammar

* add ci timeout time

* update test for structured_outputs

* update code

* add error traceback info

* update error msg

* update structured output code

* update code

* update code

* update config

* update torch version

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
kevin
2025-09-02 16:21:09 +08:00
committed by GitHub
parent 0e4df5a6f4
commit 1908465542
17 changed files with 1168 additions and 83 deletions

View File

@@ -100,6 +100,7 @@ class SamplingParams:
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
bad_words: Optional[List[str]] = None
guided_decoding: Optional[GuidedDecodingParams] = None
bad_words_token_ids: Optional[List[int]] = None
@classmethod
@@ -132,6 +133,7 @@ class SamplingParams:
min_tokens=1,
logprobs=None,
bad_words=None,
guided_decoding=None,
bad_words_token_ids=None,
) -> SamplingParams:
"""Create instance from command line arguments"""
@@ -153,6 +155,7 @@ class SamplingParams:
min_tokens=min_tokens,
logprobs=logprobs,
bad_words=bad_words,
guided_decoding=guided_decoding,
bad_words_token_ids=bad_words_token_ids,
)
@@ -217,3 +220,51 @@ class BeamSearchParams:
temperature: float = 0.0
length_penalty: float = 1.0
include_stop_str_in_output: bool = False
@dataclass
class GuidedDecodingParams:
    """Guided decoding parameters for text generation.

    At most one of the constraint fields may be set; ``__post_init__``
    enforces this at construction time.
    """

    # JSON schema (a dict, or a schema serialized as a JSON string) the output must conform to.
    json: Optional[Union[str, dict]] = None
    # Regular expression the output must match.
    regex: Optional[str] = None
    # Closed set of allowed output strings.
    choice: Optional[List[str]] = None
    # Grammar specification the output must follow.
    grammar: Optional[str] = None
    # If set, constrain output to be a JSON object (no specific schema).
    json_object: Optional[bool] = None
    # Structural-tag constraint specification.
    structural_tag: Optional[str] = None

    def to_dict(self) -> dict:
        """Return only the fields that were set, keyed by their ``guided_*`` names."""
        key_map = {
            "guided_json": self.json,
            "guided_regex": self.regex,
            "guided_choice": self.choice,
            "guided_grammar": self.grammar,
            "structural_tag": self.structural_tag,
            "guided_json_object": self.json_object,
        }
        # Drop unset (None) entries so callers see only the active constraint.
        return {key: value for key, value in key_map.items() if value is not None}

    def __post_init__(self):
        """Verify the arguments.

        Raises:
            ValueError: if more than one guided-decoding constraint is set.
        """
        guided_count = sum(
            constraint is not None
            for constraint in (
                self.json,
                self.regex,
                self.choice,
                self.grammar,
                self.json_object,
                self.structural_tag,
            )
        )
        if guided_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('json', 'json_object', 'regex', 'choice', 'grammar', 'structural_tag')."
            )