mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] support logits processors (#4515)
* [feat] provide an interface for logits processors and a builtin LogitBiasLogitsProcessor
* [chore] fix code style
* [fix] add unit test & fix existing bugs
* [feat] add engine/worker arg --logits-processors
* [fix] redefine user args as logits_processors_args and fix some bugs
* [fix] fix test_sampler
* Update fastdeploy/model_executor/logits_processor/builtin.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update fastdeploy/model_executor/logits_processor/__init__.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/model_executor/test_logits_processor.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* [fix] fix typo
* Update fastdeploy/engine/sampling_params.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* [fix] fix bracelet
* [chore] redefine logits processor interface: pass the entire share_inputs into LP, do not copy share_inputs and logits
* [doc] add docs
* [fix] fix logit bias processor not applied when decoding is too fast & add docs and tests
* [fix] fix redundant code
* [feat] skip apply() if no bias is specified

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -103,6 +103,7 @@ class SamplingParams:
     bad_words: Optional[List[str]] = None
     guided_decoding: Optional[GuidedDecodingParams] = None
     bad_words_token_ids: Optional[List[int]] = None
+    logits_processors_args: Optional[dict[str, Any]] = None

     @classmethod
     def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -136,6 +137,7 @@ class SamplingParams:
         bad_words=None,
         guided_decoding=None,
         bad_words_token_ids=None,
+        logits_processors_args=None,
     ) -> SamplingParams:
         """Create instance from command line arguments"""
         return cls(
@@ -158,6 +160,7 @@ class SamplingParams:
             bad_words=bad_words,
             guided_decoding=guided_decoding,
             bad_words_token_ids=bad_words_token_ids,
+            logits_processors_args=logits_processors_args,
         )

     def __post_init__(self):
@@ -208,6 +211,24 @@ class SamplingParams:
             if not 0 <= self.seed <= 922337203685477580:
                 raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")

+        # Verify logits processors arguments
+        if self.logits_processors_args is not None:
+            if self.logits_processors_args.get("logit_bias") is not None:
+                logit_bias = self.logits_processors_args.get("logit_bias")
+                if not isinstance(logit_bias, dict):
+                    raise TypeError(f"logit_bias must be a dict, but got {type(logit_bias)}")
+                elif not all(isinstance(k, int) and isinstance(v, float) for k, v in logit_bias.items()):
+                    # try to cast the dict to the correct type first
+                    try:
+                        cast_logit_bias = {}
+                        for k, v in logit_bias.items():
+                            cast_logit_bias[int(k)] = float(v)
+                        self.logits_processors_args["logit_bias"] = cast_logit_bias
+                    except:
+                        raise TypeError(
+                            "failed to cast logit_bias to the correct {key -> value} type, expected {int -> float}"
+                        )
+

 @dataclass
 class BeamSearchParams:
Reference in New Issue
Block a user