mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] support logits processors (#4515)
* [feat] provide an interface for logits processors and a builtin LogitBiasLogitsProcessor * [chore] fix code style * [fix] add unit test & fix existing bugs * [feat] add engine/worker arg --logits-processors * [fix] redefine user args as logits_processors_args and fix some bugs * [fix] fix test_sampler * Update fastdeploy/model_executor/logits_processor/builtin.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fastdeploy/model_executor/logits_processor/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/model_executor/test_logits_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * [fix] fix typo * Update fastdeploy/engine/sampling_params.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * [fix] fix bracelet * [chore] redefine logits processor interface: pass the entire share_inputs into LP, do not copy share_inputs and logits * [doc] add docs * [fix] fix logit bias processor not applied when decoding is too fast & add docs and tests * [fix] fix redundant code * [feat] skip apply() if no bias is specified --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -53,9 +53,9 @@ def top_p_normalize_probs_paddle(
|
||||
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
|
||||
|
||||
|
||||
class SamplerProcessor:
|
||||
class GuidedDecoding:
|
||||
"""
|
||||
SamplingProcessor for guided decoding.
|
||||
processor for guided decoding.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -75,7 +75,7 @@ class SamplerProcessor:
|
||||
future: Optional[Any] = None,
|
||||
prefill_tokens: List[int] = [],
|
||||
):
|
||||
"""add logits processor to SamplerProcessor"""
|
||||
"""add logits processor to GuidedDecoding"""
|
||||
with self.logits_lock:
|
||||
if future is None:
|
||||
if ids in self.logits_processor:
|
||||
@@ -216,7 +216,7 @@ class Sampler(nn.Layer):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.processor = SamplerProcessor()
|
||||
self.guided_decoding = GuidedDecoding()
|
||||
self.logprobs_mode = fd_config.model_config.logprobs_mode if fd_config is not None else logprobs_mode
|
||||
# Can only be created when fd_config.early_stopper_config.enable_early_stop = True
|
||||
if (
|
||||
@@ -230,19 +230,19 @@ class Sampler(nn.Layer):
|
||||
|
||||
def set_reasoning_parser(self, reasoning_parser: Optional[ReasoningParser] = None):
|
||||
"""set reasoning parser"""
|
||||
self.processor.apply_reasoning_parser(reasoning_parser)
|
||||
self.guided_decoding.apply_reasoning_parser(reasoning_parser)
|
||||
|
||||
def apply_logits_processor(self, ids: int, future: Optional[Any] = None, prefill_tokens: List[int] = []):
|
||||
"""apply logits processor to sampler"""
|
||||
self.processor.add_logits_processor(ids, future, prefill_tokens)
|
||||
self.guided_decoding.add_logits_processor(ids, future, prefill_tokens)
|
||||
|
||||
def pre_process(self, skip_idx_list: List[int] = []):
|
||||
"""pre process before running"""
|
||||
self.processor.pre_process(skip_idx_list)
|
||||
self.guided_decoding.pre_process(skip_idx_list)
|
||||
|
||||
def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []):
|
||||
"""post process after running"""
|
||||
self.processor.update_output_tokens(next_tokens, skip_idx_list)
|
||||
self.guided_decoding.update_output_tokens(next_tokens, skip_idx_list)
|
||||
|
||||
def compute_logprobs(
|
||||
self,
|
||||
@@ -332,7 +332,7 @@ class Sampler(nn.Layer):
|
||||
skip_idx_list: List[int] = [],
|
||||
) -> SamplerOutput:
|
||||
""" """
|
||||
logits = self.processor.apply_token_mask(logits, skip_idx_list)
|
||||
logits = self.guided_decoding.apply_token_mask(logits, skip_idx_list)
|
||||
|
||||
num_logprobs = sampling_metadata.max_num_logprobs
|
||||
if num_logprobs is not None:
|
||||
@@ -341,6 +341,9 @@ class Sampler(nn.Layer):
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
raw_logprobs = logits.clone()
|
||||
|
||||
for proc in sampling_metadata.logits_processors or []:
|
||||
logits = proc.apply(logits)
|
||||
|
||||
logits = apply_penalty_multi_scores(
|
||||
sampling_metadata.pre_token_ids,
|
||||
sampling_metadata.prompt_ids,
|
||||
|
||||
Reference in New Issue
Block a user