mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -15,9 +15,10 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Optional, Union, List
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
|
||||
@dataclass
class SamplingParams:
    """Sampling parameters for text generation.

    Args:
        n: Number of output sequences to return.
        best_of: Number of sequences generated before selecting the best.
        presence_penalty: Penalty for tokens already present (range [-2, 2];
            unset means 0.0).
        frequency_penalty: Penalty based on token frequency (range [-2, 2];
            unset means 0.0).
        repetition_penalty: Multiplicative penalty for repeated tokens
            (must be > 0; unset means 1.0).
        temperature: Sampling temperature (must be >= 0; unset means 1.0).
        top_p: Nucleus sampling probability (range [0, 1]; unset means 0.7).
        seed: Random seed; a random one is drawn when omitted.
        stop: String or list of strings that stop generation when produced.
        stop_token_ids: Token-id sequence(s) that stop generation.
        max_tokens: Maximum number of tokens to generate per output sequence.
        reasoning_max_tokens: Maximum number of tokens to generate for
            reasoning per output sequence. When unset and max_tokens is set,
            defaults to 80% of max_tokens (at least 1).
        min_tokens: Minimum number of tokens to generate per output sequence
            before EOS or stop_token_ids can be generated.
        logprobs: Number of log probabilities to return per output token.
        bad_words: List of words that are not allowed to be generated.
    """

    # None on the float fields means "not set by the caller":
    # from_optional() maps None to the documented defaults, and
    # _verify_args() skips range checks for unset values.
    n: int = 1
    best_of: Optional[int] = None
    presence_penalty: float = None
    frequency_penalty: float = None
    repetition_penalty: float = None
    temperature: float = None
    top_p: float = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
    max_tokens: Optional[int] = None
    reasoning_max_tokens: Optional[int] = None
    min_tokens: int = 1
    logprobs: Optional[int] = None
    bad_words: Optional[List[str]] = None

    @classmethod
    def from_dict(cls, req_dict: dict[str, Any]) -> "SamplingParams":
        """Create a SamplingParams instance from a dictionary.

        Args:
            req_dict: Dictionary whose keys match SamplingParams field names;
                missing keys fall back to the field's declared default.

        Returns:
            SamplingParams: A new instance initialized from the dictionary.
        """
        return cls(**{
            field.name: req_dict[field.name] if field.name in req_dict else field.default
            for field in fields(cls)
        })

    @classmethod
    def from_optional(cls,
                      n,
                      best_of,
                      presence_penalty,
                      frequency_penalty,
                      repetition_penalty,
                      temperature,
                      top_p,
                      seed=None,
                      stop=None,
                      stop_token_ids=None,
                      max_tokens=None,
                      reasoning_max_tokens=None,
                      min_tokens=1,
                      logprobs=None,
                      bad_words=None) -> "SamplingParams":
        """Create a SamplingParams instance from optional arguments,
        substituting the documented defaults for any value that is None.

        Returns:
            SamplingParams: A new instance with provided or default values.
        """
        return cls(n=1 if n is None else n,
                   best_of=best_of,
                   presence_penalty=presence_penalty
                   if presence_penalty is not None else 0.0,
                   frequency_penalty=frequency_penalty
                   if frequency_penalty is not None else 0.0,
                   repetition_penalty=repetition_penalty
                   if repetition_penalty is not None else 1.0,
                   temperature=temperature if temperature is not None else 1.0,
                   top_p=top_p if top_p is not None else 0.7,
                   seed=seed,
                   stop=stop,
                   stop_token_ids=stop_token_ids,
                   max_tokens=max_tokens if max_tokens is not None else 8192,
                   reasoning_max_tokens=reasoning_max_tokens,
                   min_tokens=min_tokens,
                   logprobs=logprobs,
                   bad_words=bad_words)

    def __post_init__(self):
        """Initialize sampling parameters after instance creation.

        Sets a random seed if none was provided, derives the default
        reasoning budget from max_tokens, and validates all parameters.
        """
        if self.seed is None:
            self.seed = random.randint(0, 922337203685477580)
        # Default reasoning budget: 80% of max_tokens, but never below 1.
        if self.max_tokens is not None and self.reasoning_max_tokens is None:
            self.reasoning_max_tokens = max(int(self.max_tokens * 0.8), 1)
        self._verify_args()

    def _verify_args(self) -> None:
        """Validate all sampling parameters.

        Unset (None) optional values are skipped — only values the caller
        actually provided (or defaults derived in __post_init__) are checked.

        Raises:
            ValueError: If any parameter is outside its valid range or of
                incorrect type.
        """
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if self.presence_penalty is not None and (
                not -2.0 <= self.presence_penalty <= 2.0):
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
        if self.frequency_penalty is not None and (
                not -2.0 <= self.frequency_penalty <= 2.0):
            raise ValueError("frequency_penalty must be in [-2, 2], got "
                             f"{self.frequency_penalty}.")
        if self.repetition_penalty is not None and self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}.")
        if self.temperature is not None and self.temperature < 0.0:
            raise ValueError(
                f"temperature must be non-negative, got {self.temperature}.")
        if self.top_p is not None and not 0.0 <= self.top_p <= 1.0:
            raise ValueError(f"top_p must be in [0, 1], got {self.top_p}.")

        if self.max_tokens is not None and self.max_tokens < 1:
            raise ValueError(
                f"max_tokens must be at least 1, got {self.max_tokens}.")

        # Guard on max_tokens is not None: comparing against None would raise
        # TypeError when reasoning_max_tokens is set while max_tokens is unset.
        if (self.reasoning_max_tokens is not None
                and self.max_tokens is not None
                and self.reasoning_max_tokens > self.max_tokens):
            raise ValueError(
                f"reasoning_max_tokens must be less than max_tokens, got {self.reasoning_max_tokens}.")

        if self.min_tokens < 0:
            raise ValueError(f"min_tokens must be greater than or equal to 0, "
                             f"got {self.min_tokens}.")

        # NOTE(review): the diff view elides several lines here; only the
        # seed bound check, reconstructed from its visible error message,
        # is restored — confirm the elided checks against upstream.
        if self.seed is not None and not 0 <= self.seed <= 922337203685477580:
            raise ValueError("seed must be in [0, 922337203685477580], got "
                             f"{self.seed}.")

    def update_from_tokenizer(self, tokenizer):
        """Update sampling parameters based on tokenizer configuration.

        Note: Currently a placeholder for future implementation of:
            - Stop tokens handling
            - Bad words filtering

        Args:
            tokenizer: The tokenizer instance to use for configuration
        """
        # TODO: Implement stop tokens and bad words support
        # Currently stop tokens and bad words are not supported yet
        pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeamSearchParams:
|
||||
"""Parameters for beam search text generation.
|
||||
|
||||
Args:
|
||||
beam_width: Number of beams to maintain during search
|
||||
max_tokens: Maximum number of tokens to generate
|
||||
ignore_eos: Whether to ignore EOS tokens (default: False)
|
||||
temperature: Sampling temperature (0 means greedy, default: 0.0)
|
||||
length_penalty: Penalty applied to length (1.0 means no penalty, default: 1.0)
|
||||
include_stop_str_in_output: Whether to include stop strings in output (default: False)
|
||||
"""
|
||||
"""Beam search parameters for text generation."""
|
||||
beam_width: int
|
||||
max_tokens: int
|
||||
ignore_eos: bool = False
|
||||
|
||||
Reference in New Issue
Block a user