[Feature] support logits processors (#4515)

* [feat] provide an interface for logits processors and a builtin LogitBiasLogitsProcessor

* [chore] fix code style

* [fix] add unit test & fix existing bugs

* [feat] add engine/worker arg --logits-processors

* [fix] redefine user args as logits_processors_args and fix some bugs

* [fix] fix test_sampler

* Update fastdeploy/model_executor/logits_processor/builtin.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/logits_processor/__init__.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/model_executor/test_logits_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix typo

* Update fastdeploy/engine/sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix bracket

* [chore] redefine logits processor interface: pass the entire share_inputs into LP, do not copy share_inputs and logits

* [doc] add docs

* [fix] fix logit bias processor not applied when decoding is too fast & add docs and tests

* [fix] fix redundant code

* [feat] skip apply() if no bias is specified

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
李泳桦
2025-10-29 00:08:53 +08:00
committed by GitHub
parent 24b9505971
commit a012e3608b
18 changed files with 882 additions and 14 deletions

View File

@@ -53,9 +53,9 @@ def top_p_normalize_probs_paddle(
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
class SamplerProcessor:
class GuidedDecoding:
"""
SamplingProcessor for guided decoding.
processor for guided decoding.
"""
def __init__(self):
@@ -75,7 +75,7 @@ class SamplerProcessor:
future: Optional[Any] = None,
prefill_tokens: List[int] = [],
):
"""add logits processor to SamplerProcessor"""
"""add logits processor to GuidedDecoding"""
with self.logits_lock:
if future is None:
if ids in self.logits_processor:
@@ -216,7 +216,7 @@ class Sampler(nn.Layer):
else:
raise NotImplementedError
self.processor = SamplerProcessor()
self.guided_decoding = GuidedDecoding()
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
if (
@@ -230,19 +230,19 @@ class Sampler(nn.Layer):
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
"""set reasoning parser"""
self.processor.apply_reasoning_parser(reasoning_parser)
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
"""apply logits processor to sampler"""
self.processor.add_logits_processor(ids, future, prefill_tokens)
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
def pre_process(self, skip_idx_list: List[int] = []):
"""pre process before running"""
self.processor.pre_process(skip_idx_list)
self.guided_decoding.pre_process(skip_idx_list)
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
"""post process after running"""
self.processor.update_output_tokens(next_tokens, skip_idx_list)
self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)
def compute_logprobs(
self,
@@ -332,7 +332,7 @@ class Sampler(nn.Layer):
skip_idx_list: List[int] = [],
) -> SamplerOutput:
""" """
logits = self.processor.apply_token_mask(logits, skip_idx_list)
logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
@@ -341,6 +341,9 @@ class Sampler(nn.Layer):
elif self.logprobs_mode == "raw_logits":
raw_logprobs = logits.clone()
for proc in sampling_metadata.logits_processors or []:
logits = proc.apply(logits)
logits = apply_penalty_multi_scores(
sampling_metadata.pre_token_ids,
sampling_metadata.prompt_ids,