Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -15,9 +15,10 @@
"""
from __future__ import annotations
from dataclasses import dataclass, fields
from typing import Any, Optional, Union, List
import random
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union
@dataclass
class SamplingParams:
    """Sampling parameters for text generation.

    The penalty, temperature and top_p fields default to None here, meaning
    "not specified"; ``from_optional`` substitutes the conventional defaults
    (0.0 / 1.0 / 0.7) and validation skips None values.

    Args:
        n: Number of output sequences to return for the prompt.
        best_of: Number of sequences generated before selecting the best n.
        presence_penalty: Penalizes tokens already present in the output
            (must be in [-2, 2] when set).
        frequency_penalty: Penalizes tokens by their frequency so far
            (must be in [-2, 2] when set).
        repetition_penalty: Multiplicative penalty on repeated tokens
            (must be > 0 when set; 1.0 means no penalty).
        temperature: Sampling temperature; must be non-negative when set.
        top_p: Nucleus sampling probability mass (in [0, 1] when set).
        seed: Random seed; a random one is chosen in __post_init__ when None.
        stop: String(s) whose generation terminates the output.
        stop_token_ids: Token-id sequence(s) that terminate generation.
        max_tokens: Maximum number of tokens to generate per output sequence.
        reasoning_max_tokens: Maximum number of tokens to generate for
            reasoning per output sequence; derived from max_tokens when unset.
        min_tokens: Minimum number of tokens to generate per output sequence
            before EOS or stop_token_ids can be generated.
        logprobs: Number of log probabilities to return per output token.
        bad_words: Words that are not allowed to be generated.
    """

    n: int = 1
    best_of: Optional[int] = None
    # NOTE(review): annotated ``float`` but defaulting to None — effectively
    # Optional[float]; None means "use the engine default" downstream.
    presence_penalty: float = None
    frequency_penalty: float = None
    repetition_penalty: float = None
    temperature: float = None
    top_p: float = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
    max_tokens: Optional[int] = None
    reasoning_max_tokens: Optional[int] = None
    min_tokens: int = 1
    logprobs: Optional[int] = None
    bad_words: Optional[List[str]] = None
@classmethod
def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
    """Create a SamplingParams instance from a dictionary.

    Args:
        req_dict: Mapping whose keys match SamplingParams field names;
            fields missing from the mapping keep their declared defaults.

    Returns:
        SamplingParams: A new instance initialized from the dictionary.
    """
    # dict.get with field.default is equivalent to the explicit
    # "key in dict" ternary but does a single lookup.
    return cls(**{
        field.name: req_dict.get(field.name, field.default)
        for field in fields(cls)
    })
@classmethod
def from_optional(cls,
                  n,
                  best_of,
                  presence_penalty,
                  frequency_penalty,
                  repetition_penalty,
                  temperature,
                  top_p,
                  seed=None,
                  stop=None,
                  stop_token_ids=None,
                  max_tokens=None,
                  reasoning_max_tokens=None,
                  min_tokens=1,
                  logprobs=None,
                  bad_words=None) -> "SamplingParams":
    """Create a SamplingParams from possibly-None arguments.

    None values for the core sampling controls are replaced with the
    conventional defaults: n=1, presence/frequency penalty 0.0,
    repetition_penalty 1.0, temperature 1.0, top_p 0.7, max_tokens 8192.
    All other arguments are passed through unchanged.

    Returns:
        SamplingParams: A new instance with provided or default values.
    """
    return cls(
        n=1 if n is None else n,
        best_of=best_of,
        presence_penalty=(presence_penalty
                          if presence_penalty is not None else 0.0),
        frequency_penalty=(frequency_penalty
                           if frequency_penalty is not None else 0.0),
        repetition_penalty=(repetition_penalty
                            if repetition_penalty is not None else 1.0),
        temperature=temperature if temperature is not None else 1.0,
        top_p=top_p if top_p is not None else 0.7,
        seed=seed,
        stop=stop,
        stop_token_ids=stop_token_ids,
        max_tokens=max_tokens if max_tokens is not None else 8192,
        reasoning_max_tokens=reasoning_max_tokens,
        min_tokens=min_tokens,
        logprobs=logprobs,
        bad_words=bad_words)
def __post_init__(self):
    """Finalize derived fields after dataclass construction.

    Picks a random seed when none was supplied, derives
    reasoning_max_tokens as 80% of max_tokens (at least 1) when only
    max_tokens is set, then validates every parameter.
    """
    if self.seed is None:
        self.seed = random.randint(0, 922337203685477580)
    if self.reasoning_max_tokens is None and self.max_tokens is not None:
        self.reasoning_max_tokens = max(int(self.max_tokens * 0.8), 1)
    self._verify_args()
def _verify_args(self) -> None:
"""Validate all sampling parameters.
Raises:
ValueError: If any parameter is outside its valid range or of incorrect type
"""
if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
raise ValueError(
f"n must be an int, but is of type {type(self.n)}")
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
if not -2.0 <= self.presence_penalty <= 2.0:
if self.presence_penalty is not None and (
not -2.0 <= self.presence_penalty <= 2.0):
raise ValueError("presence_penalty must be in [-2, 2], got "
f"{self.presence_penalty}.")
if not -2.0 <= self.frequency_penalty <= 2.0:
if self.frequency_penalty is not None and (
not -2.0 <= self.frequency_penalty <= 2.0):
raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.")
if self.repetition_penalty <= 0.0:
if self.repetition_penalty is not None and self.repetition_penalty <= 0.0:
raise ValueError(
"repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}.")
if self.temperature < 0.0:
if self.temperature is not None and self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}.")
if not 0.0 <= self.top_p <= 1.0:
if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.reasoning_max_tokens is not None and self.reasoning_max_tokens > self.max_tokens:
raise ValueError(
f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")
if self.min_tokens < 0:
raise ValueError(f"min_tokens must be greater than or equal to 0, "
f"got {self.min_tokens}.")
@@ -215,33 +193,17 @@ class SamplingParams:
raise ValueError("seed must be in [0, 922337203685477580], got "
f"{self.seed}.")
def update_from_tokenizer(self, tokenizer):
    """Update sampling parameters from tokenizer configuration.

    Placeholder: stop-token and bad-words handling are not implemented
    yet, so this is currently a no-op.

    Args:
        tokenizer: The tokenizer instance to derive configuration from.
    """
    # TODO: Implement stop tokens and bad words support.
    pass
@dataclass
class BeamSearchParams:
"""Parameters for beam search text generation.
Args:
beam_width: Number of beams to maintain during search
max_tokens: Maximum number of tokens to generate
ignore_eos: Whether to ignore EOS tokens (default: False)
temperature: Sampling temperature (0 means greedy, default: 0.0)
length_penalty: Penalty applied to length (1.0 means no penalty, default: 1.0)
include_stop_str_in_output: Whether to include stop strings in output (default: False)
"""
"""Beam search parameters for text generation."""
beam_width: int
max_tokens: int
ignore_eos: bool = False