mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-30 14:22:27 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -94,54 +94,54 @@ class SamplingParams:
|
||||
bad_words: Optional[List[str]] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
|
||||
def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams:
|
||||
"""Create instance from command line arguments"""
|
||||
return cls(
|
||||
**{
|
||||
field.name:
|
||||
req_dict[field.name] if field.name in
|
||||
req_dict else field.default
|
||||
field.name: (req_dict[field.name] if field.name in req_dict else field.default)
|
||||
for field in fields(cls)
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_optional(cls,
|
||||
n,
|
||||
best_of,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
repetition_penalty,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
seed=None,
|
||||
stop=None,
|
||||
stop_token_ids=None,
|
||||
max_tokens=None,
|
||||
reasoning_max_tokens=None,
|
||||
min_tokens=1,
|
||||
logprobs=None,
|
||||
bad_words=None) -> "SamplingParams":
|
||||
def from_optional(
|
||||
cls,
|
||||
n,
|
||||
best_of,
|
||||
presence_penalty,
|
||||
frequency_penalty,
|
||||
repetition_penalty,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
seed=None,
|
||||
stop=None,
|
||||
stop_token_ids=None,
|
||||
max_tokens=None,
|
||||
reasoning_max_tokens=None,
|
||||
min_tokens=1,
|
||||
logprobs=None,
|
||||
bad_words=None,
|
||||
) -> SamplingParams:
|
||||
"""Create instance from command line arguments"""
|
||||
return cls(n=1 if n is None else n,
|
||||
best_of=best_of,
|
||||
presence_penalty=presence_penalty
|
||||
if presence_penalty is not None else 0.0,
|
||||
frequency_penalty=frequency_penalty
|
||||
if frequency_penalty is not None else 0.0,
|
||||
repetition_penalty=repetition_penalty
|
||||
if repetition_penalty is not None else 1.0,
|
||||
temperature=temperature if temperature is not None else 1.0,
|
||||
top_p=top_p,
|
||||
top_k=top_k if top_k is not None else 0,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
max_tokens=max_tokens if max_tokens is not None else 8192,
|
||||
reasoning_max_tokens=reasoning_max_tokens,
|
||||
min_tokens=min_tokens,
|
||||
logprobs=logprobs,
|
||||
bad_words=bad_words)
|
||||
return cls(
|
||||
n=1 if n is None else n,
|
||||
best_of=best_of,
|
||||
presence_penalty=(presence_penalty if presence_penalty is not None else 0.0),
|
||||
frequency_penalty=(frequency_penalty if frequency_penalty is not None else 0.0),
|
||||
repetition_penalty=(repetition_penalty if repetition_penalty is not None else 1.0),
|
||||
temperature=temperature if temperature is not None else 1.0,
|
||||
top_p=top_p,
|
||||
top_k=top_k if top_k is not None else 0,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
max_tokens=max_tokens if max_tokens is not None else 8192,
|
||||
reasoning_max_tokens=reasoning_max_tokens,
|
||||
min_tokens=min_tokens,
|
||||
logprobs=logprobs,
|
||||
bad_words=bad_words,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.seed is None:
|
||||
@@ -152,60 +152,44 @@ class SamplingParams:
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if not isinstance(self.n, int):
|
||||
raise ValueError(
|
||||
f"n must be an int, but is of type {type(self.n)}")
|
||||
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
|
||||
if self.n < 1:
|
||||
raise ValueError(f"n must be at least 1, got {self.n}.")
|
||||
if self.presence_penalty is not None and (
|
||||
not -2.0 <= self.presence_penalty <= 2.0):
|
||||
raise ValueError("presence_penalty must be in [-2, 2], got "
|
||||
f"{self.presence_penalty}.")
|
||||
if self.frequency_penalty is not None and (
|
||||
not -2.0 <= self.frequency_penalty <= 2.0):
|
||||
raise ValueError("frequency_penalty must be in [-2, 2], got "
|
||||
f"{self.frequency_penalty}.")
|
||||
if self.presence_penalty is not None and (not -2.0 <= self.presence_penalty <= 2.0):
|
||||
raise ValueError("presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}.")
|
||||
if self.frequency_penalty is not None and (not -2.0 <= self.frequency_penalty <= 2.0):
|
||||
raise ValueError("frequency_penalty must be in [-2, 2], got " f"{self.frequency_penalty}.")
|
||||
if self.repetition_penalty is not None and self.repetition_penalty <= 0.0:
|
||||
raise ValueError(
|
||||
"repetition_penalty must be greater than zero, got "
|
||||
f"{self.repetition_penalty}.")
|
||||
raise ValueError("repetition_penalty must be greater than zero, got " f"{self.repetition_penalty}.")
|
||||
if self.temperature is not None and self.temperature < 0.0:
|
||||
raise ValueError(
|
||||
f"temperature must be non-negative, got {self.temperature}.")
|
||||
raise ValueError(f"temperature must be non-negative, got {self.temperature}.")
|
||||
if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
|
||||
raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
|
||||
# quietly accept -1 as disabled, but prefer 0
|
||||
if self.top_k < -1:
|
||||
raise ValueError(f"top_k must be 0 (disable), or at least 1, "
|
||||
f"got {self.top_k}.")
|
||||
raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
|
||||
if not isinstance(self.top_k, int):
|
||||
raise TypeError(
|
||||
f"top_k must be an integer, got {type(self.top_k).__name__}")
|
||||
raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
|
||||
|
||||
if self.max_tokens is not None and self.max_tokens < 1:
|
||||
raise ValueError(
|
||||
f"max_tokens must be at least 1, got {self.max_tokens}.")
|
||||
raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
|
||||
|
||||
if self.reasoning_max_tokens is not None and self.reasoning_max_tokens > self.max_tokens:
|
||||
raise ValueError(
|
||||
f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
|
||||
raise ValueError(f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
|
||||
|
||||
if self.min_tokens < 0:
|
||||
raise ValueError(f"min_tokens must be greater than or equal to 0, "
|
||||
f"got {self.min_tokens}.")
|
||||
raise ValueError(f"min_tokens must be greater than or equal to 0, " f"got {self.min_tokens}.")
|
||||
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
|
||||
raise ValueError(
|
||||
f"min_tokens must be less than or equal to "
|
||||
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
|
||||
f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
|
||||
)
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
if self.logprobs is not None and self.logprobs > 20:
|
||||
raise ValueError(
|
||||
"Invalid value for 'top_logprobs': must be less than or equal to 20.")
|
||||
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
|
||||
|
||||
if not 0 <= self.seed <= 922337203685477580:
|
||||
raise ValueError("seed must be in [0, 922337203685477580], got "
|
||||
f"{self.seed}.")
|
||||
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
|
||||
|
||||
def update_from_tokenizer(self, tokenizer):
|
||||
"""
|
||||
@@ -218,6 +202,7 @@ class SamplingParams:
|
||||
@dataclass
|
||||
class BeamSearchParams:
|
||||
"""Beam search parameters for text generation."""
|
||||
|
||||
beam_width: int
|
||||
max_tokens: int
|
||||
ignore_eos: bool = False
|
||||
|
Reference in New Issue
Block a user