[Feature] mm and thinking model support structured output (#2749)

* mm support structured output

* update code

* update code

* update format

* update code

* update code

* add enable_thinking default

* update code

* add structured_outputs test case

* add ci install xgrammar

* increase ci timeout

* update test for structured_outputs

* update code

* add error traceback info

* update error msg

* update structured output code

* update code

* update code

* update config

* update torch version

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: kevin
Date: 2025-09-02 16:21:09 +08:00
Committed by: GitHub
Parent: 0e4df5a6f4
Commit: 1908465542
17 changed files with 1168 additions and 83 deletions
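
This commit lets multimodal and "thinking" (reasoning) models use structured
output: the grammar bitmask is withheld while a request is still inside its
reasoning segment and only constrains the final answer. A minimal client-side
sketch of the combination, assuming an OpenAI-compatible FastDeploy endpoint;
the guided_json extra-body field and the chat_template_kwargs enable_thinking
toggle are illustrative assumptions, not confirmed by this diff:

    import openai

    client = openai.OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

    # JSON schema the final (post-reasoning) answer must satisfy.
    schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    }

    resp = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Answer as JSON."}],
        extra_body={
            "guided_json": schema,  # assumed field name
            "chat_template_kwargs": {"enable_thinking": True},  # assumed toggle
        },
    )
    print(resp.choices[0].message.content)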


@@ -23,9 +23,7 @@ import paddle.nn.functional as F
 from paddle import nn

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.guided_decoding.base_guided_decoding import (
-    LogitsProcessorBase,
-)
+from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase
 from fastdeploy.model_executor.layers.sample.early_stopper import (
     get_early_stopper_cls_from_stragegy,
 )
@@ -37,6 +35,7 @@ from fastdeploy.model_executor.layers.sample.ops import (
     top_k_top_p_sampling,
 )
 from fastdeploy.platforms import current_platform
+from fastdeploy.reasoning import ReasoningParser
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
@@ -63,6 +62,10 @@ class SamplerProcessor:
         self.logits_processor: Dict[int, Optional[Any]] = dict()
         self.executor = ThreadPoolExecutor()
         self.logits_lock = threading.Lock()
+        self.reasoning_parser = None
+
+    def apply_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+        self.reasoning_parser = reasoning_parser

     def add_logits_processor(
         self,
@@ -139,9 +142,14 @@
         if available_processors is None:
             return logits

-        indices = list(self.logits_processor.keys())
-        mask_idx = [i for i in indices if i not in skip_idx_list]
-        return available_processors.apply_token_mask(logits, self.token_bitmask, indices=mask_idx)
+        indices = []
+        for idx, processor in self.logits_processor.items():
+            if processor is None or idx in skip_idx_list:
+                continue
+            if self.reasoning_parser is None or not processor.enable_reasoning or processor.reasoning_ended:
+                indices.append(idx)
+
+        return available_processors.apply_token_mask(logits, self.token_bitmask, indices=indices)

     def _accept_token(self, idx: int, token: int):
         """accept token"""
@@ -151,6 +159,15 @@
         if self.logits_processor[idx].is_terminated():
             return

+        if (
+            self.reasoning_parser is not None
+            and self.logits_processor[idx].enable_reasoning
+            and not self.logits_processor[idx].reasoning_ended
+        ):
+            reasoning_ended = self.reasoning_parser.is_reasoning_end([token])
+            self.logits_processor[idx].reasoning_ended = reasoning_ended
+            return
+
         self.logits_processor[idx].accept_token(token)

     def update_output_tokens(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
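
Until the thinking segment closes, _accept_token now routes each sampled token
to the reasoning parser instead of the grammar matcher; the early return keeps
reasoning tokens out of the grammar's accept stream. A hypothetical minimal
parser showing the contract is_reasoning_end must fulfil, assuming the segment
is terminated by one special token (real parsers live in fastdeploy.reasoning):

    from typing import List

    class ThinkEndParser:  # hypothetical, for illustration only
        def __init__(self, think_end_token_id: int):
            self.think_end_token_id = think_end_token_id

        def is_reasoning_end(self, token_ids: List[int]) -> bool:
            # The sampler passes the single newly accepted token as a list.
            return self.think_end_token_id in token_ids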
@@ -206,12 +223,11 @@
         self.early_stopper = early_stopper_cls()
         self.early_stopper.initialize(fd_config.parallel_config.max_num_seqs, fd_config.early_stop_config)

-    def apply_logits_processor(
-        self,
-        ids: int,
-        future: Optional[Any] = None,
-        prefill_tokens: List[int] = [],
-    ):
+    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+        """set reasoning parser"""
+        self.processor.apply_reasoning_parser(reasoning_parser)
+
+    def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
         """apply logits processor to sampler"""
         self.processor.add_logits_processor(ids, future, prefill_tokens)

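
Sampler.set_reasoning_parser just forwards to the processor; the call site is
not part of this hunk, but the wiring is presumably a one-time setup step by
whatever owns the sampler, roughly (get_reasoning_parser is an assumed helper,
not shown in this commit):

    sampler = Sampler(fd_config)
    reasoning_parser = get_reasoning_parser(fd_config)  # assumed factory
    sampler.set_reasoning_parser(reasoning_parser)  # None skips the gating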
@@ -219,6 +235,10 @@
         """pre process before running"""
         self.processor.pre_process(skip_idx_list)

+    def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+        """post process after running"""
+        self.processor.update_output_tokens(next_tokens, skip_idx_list)
+
     def compute_logprobs(
         self,
         logits: paddle.Tensor,
@@ -307,12 +327,12 @@
         skip_idx_list: List[int] = [],
     ) -> SamplerOutput:
         """ """
-        logits = self.processor.apply_token_mask(logits, skip_idx_list)
-
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
             raw_logprobs = self.compute_logprobs(logits, sampling_metadata)

+        logits = self.processor.apply_token_mask(logits, skip_idx_list)
+
         logits = apply_penalty_multi_scores(
             sampling_metadata.pre_token_ids,
             sampling_metadata.prompt_ids,
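
apply_token_mask moves below the logprob computation, so logprobs reported to
the client reflect the model's unconstrained distribution while sampling still
sees grammar-masked logits. Schematically, the order after this change (not
verbatim source):

    raw_logprobs = compute_logprobs(logits)           # unmasked model scores
    logits = apply_token_mask(logits, skip_idx_list)  # grammar constraints
    logits = apply_penalty_multi_scores(logits, ...)  # penalties, temperature
    next_tokens = sample(logits)                      # constrained sampling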
@@ -347,8 +367,6 @@
             assert sampling_metadata.stop_flags is not None, "need stop_flags for eary stop"
             self.early_stopper.process(probs, next_tokens, sampling_metadata.stop_flags)

-        self.processor.update_output_tokens(next_tokens, skip_idx_list)
-
         sampler_output = SamplerOutput(
             # The sampled tokens are expanded to 2D tensor with shape
             # [num_requests, 1], where each row represents one generated
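
update_output_tokens no longer runs inside forward_cuda; together with the new
post_process hooks (stubs on the speculative and MTP samplers), the caller now
decides when accepted tokens are fed back into the grammar matchers. A
hypothetical driver-side sequence, assuming the interfaces shown in this diff:

    sampler.pre_process(skip_idx_list)
    sampler_output = sampler(logits, sampling_metadata, skip_idx_list)
    sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)  # field name assumed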
@@ -380,12 +398,15 @@ class SpeculativeSampler(nn.Layer):
         """pre process before running"""
         pass

-    def apply_logits_processor(
-        self,
-        ids: int,
-        future: Optional[Any] = None,
-        prefill_tokens: List[int] = [],
-    ):
+    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+        """set reasoning parser"""
+        pass
+
+    def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+        """post process after running"""
+        pass
+
+    def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
         """apply logits processor to sampler"""
         pass

@@ -480,6 +501,14 @@ class MTPSampler(nn.Layer):
         """apply logits processor to sampler"""
         pass

+    def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
+        """set reasoning parser"""
+        pass
+
+    def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
+        """post process after running"""
+        pass
+
     def forward_cuda(
         self,
         logits: paddle.Tensor,