[Feature] support logits processors (#4515)

* [feat] provide an interface for logits processors and a builtin LogitBiasLogitsProcessor

* [chore] fix code style

* [fix] add unit test & fix existing bugs

* [feat] add engine/worker arg --logits-processors

* [fix] redefine user args as logits_processors_args and fix some bugs

* [fix] fix test_sampler

* Update fastdeploy/model_executor/logits_processor/builtin.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/model_executor/logits_processor/__init__.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/model_executor/test_logits_processor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix typo

* Update fastdeploy/engine/sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* [fix] fix bracelet

* [chore] redefine logits processor interface: pass the entire share_inputs into LP, do not copy share_inputs and logits

* [doc] add docs

* [fix] fix logit bias processor not applied when decoding is too fast & add docs and tests

* [fix] fix redundant code

* [feat] skip apply() if no bias is specified

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
李泳桦
2025-10-29 00:08:53 +08:00
committed by GitHub
parent 24b9505971
commit a012e3608b
18 changed files with 882 additions and 14 deletions

View File

@@ -103,6 +103,7 @@ class SamplingParams:
bad_words: Optional[List[str]] = None
guided_decoding: Optional[GuidedDecodingParams] = None
bad_words_token_ids: Optional[List[int]] = None
logits_processors_args: Optional[dict[str, Any]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
@@ -136,6 +137,7 @@ class SamplingParams:
bad_words=None,
guided_decoding=None,
bad_words_token_ids=None,
logits_processors_args=None,
) -> SamplingParams:
"""Create instance from command line arguments"""
return cls(
@@ -158,6 +160,7 @@ class SamplingParams:
bad_words=bad_words,
guided_decoding=guided_decoding,
bad_words_token_ids=bad_words_token_ids,
logits_processors_args=logits_processors_args,
)
def __post_init__(self):
@@ -208,6 +211,24 @@ class SamplingParams:
if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
# Verify logits processors arguments
if self.logits_processors_args is not None:
if self.logits_processors_args.get("logit_bias") is not None:
logit_bias = self.logits_processors_args.get("logit_bias")
if not isinstance(logit_bias, dict):
raise TypeError(f"logit_bias must be a dict, but got {type(logit_bias)}")
elif not all(isinstance(k, int) and isinstance(v, float) for k, v in logit_bias.items()):
# try to cast the dict to the correct type first
try:
cast_logit_bias = {}
for k, v in logit_bias.items():
cast_logit_bias[int(k)] = float(v)
self.logits_processors_args["logit_bias"] = cast_logit_bias
except:
raise TypeError(
"failed to cast logit_bias to the correct {key -> value} type, expected {int -> float}"
)
@dataclass
class BeamSearchParams: