polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -94,54 +94,54 @@ class SamplingParams:
bad_words: Optional[List[str]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
    """Create a SamplingParams instance from a request dictionary.

    Each dataclass field is looked up by name in ``req_dict``; when the
    key is absent, the field's declared default is used instead.  Keys in
    ``req_dict`` that do not match a field name are ignored.

    Args:
        req_dict: Mapping of field names to their requested values.

    Returns:
        A new SamplingParams built from the provided values and defaults.
    """
    return cls(
        **{
            field.name: (req_dict[field.name] if field.name in req_dict else field.default)
            for field in fields(cls)
        }
    )
@classmethod
def from_optional(
    cls,
    n,
    best_of,
    presence_penalty,
    frequency_penalty,
    repetition_penalty,
    temperature,
    top_p,
    top_k,
    seed=None,
    stop=None,
    stop_token_ids=None,
    max_tokens=None,
    reasoning_max_tokens=None,
    min_tokens=1,
    logprobs=None,
    bad_words=None,
) -> "SamplingParams":
    """Create a SamplingParams instance, replacing ``None`` with defaults.

    ``None`` arguments are mapped to fixed fallbacks:
    n -> 1, presence_penalty -> 0.0, frequency_penalty -> 0.0,
    repetition_penalty -> 1.0, temperature -> 1.0, top_k -> 0,
    max_tokens -> 8192.  All other arguments pass through unchanged.

    Returns:
        A new SamplingParams instance.
    """
    return cls(
        n=1 if n is None else n,
        best_of=best_of,
        presence_penalty=(presence_penalty if presence_penalty is not None else 0.0),
        frequency_penalty=(frequency_penalty if frequency_penalty is not None else 0.0),
        repetition_penalty=(repetition_penalty if repetition_penalty is not None else 1.0),
        temperature=temperature if temperature is not None else 1.0,
        top_p=top_p,
        top_k=top_k if top_k is not None else 0,
        seed=seed,
        stop=stop,
        stop_token_ids=stop_token_ids,
        max_tokens=max_tokens if max_tokens is not None else 8192,
        reasoning_max_tokens=reasoning_max_tokens,
        min_tokens=min_tokens,
        logprobs=logprobs,
        bad_words=bad_words,
    )
def __post_init__(self):
if self.seed is None:
@@ -152,60 +152,44 @@ class SamplingParams:
def _verify_args(self) -> None:
if not isinstance(self.n, int):
raise ValueError(
f"n must be an int, but is of type {type(self.n)}")
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
if self.presence_penalty is not None and (
not -2.0 <= self.presence_penalty <= 2.0):
raise ValueError("presence_penalty must be in [-2, 2], got "
f"{self.presence_penalty}.")
if self.frequency_penalty is not None and (
not -2.0 <= self.frequency_penalty <= 2.0):
raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.")
if self.presence_penalty is not None and (not -2.0 <= self.presence_penalty <= 2.0):
raise ValueError("presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}.")
if self.frequency_penalty is not None and (not -2.0 <= self.frequency_penalty <= 2.0):
raise ValueError("frequency_penalty must be in [-2, 2], got " f"{self.frequency_penalty}.")
if self.repetition_penalty is not None and self.repetition_penalty <= 0.0:
raise ValueError(
"repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}.")
raise ValueError("repetition_penalty must be greater than zero, got " f"{self.repetition_penalty}.")
if self.temperature is not None and self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}.")
raise ValueError(f"temperature must be non-negative, got {self.temperature}.")
if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
# quietly accept -1 as disabled, but prefer 0
if self.top_k < -1:
raise ValueError(f"top_k must be 0 (disable), or at least 1, "
f"got {self.top_k}.")
raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
if not isinstance(self.top_k, int):
raise TypeError(
f"top_k must be an integer, got {type(self.top_k).__name__}")
raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.reasoning_max_tokens is not None and self.reasoning_max_tokens > self.max_tokens:
raise ValueError(
f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
raise ValueError(f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
if self.min_tokens < 0:
raise ValueError(f"min_tokens must be greater than or equal to 0, "
f"got {self.min_tokens}.")
raise ValueError(f"min_tokens must be greater than or equal to 0, " f"got {self.min_tokens}.")
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
raise ValueError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
)
if self.logprobs is not None and self.logprobs < 0:
raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.")
raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
if self.logprobs is not None and self.logprobs > 20:
raise ValueError(
"Invalid value for 'top_logprobs': must be less than or equal to 20.")
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got "
f"{self.seed}.")
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
def update_from_tokenizer(self, tokenizer):
"""
@@ -218,6 +202,7 @@ class SamplingParams:
@dataclass
class BeamSearchParams:
    """Beam search parameters for text generation."""

    # Number of beams kept at each decoding step.
    beam_width: int
    # Maximum number of tokens to generate.
    max_tokens: int
    # NOTE(review): presumably skips the end-of-sequence stop condition
    # when True — confirm against the decoding loop that consumes this.
    ignore_eos: bool = False