mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-28 21:32:29 +08:00
[Feature] support top_k_top_p sampling (#2753)
* support top_k_top_p sampling * fix * add api param * add api para * fix * fix * fix * fix * fix * fix * fix
This commit is contained in:
@@ -52,6 +52,7 @@ class SamplingParams:
|
||||
the model more random. Zero means greedy sampling.
|
||||
top_p: Float that controls the cumulative probability of the top tokens
|
||||
to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
|
||||
top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
|
||||
seed: Random seed to use for the generation.
|
||||
stop: list of strings that stop the generation when they are generated.
|
||||
The returned output will not contain the stop strings.
|
||||
@@ -81,7 +82,8 @@ class SamplingParams:
|
||||
frequency_penalty: float = None
|
||||
repetition_penalty: float = None
|
||||
temperature: float = None
|
||||
top_p: float = None
|
||||
top_p: float = 1.0
|
||||
top_k: int = 0
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
@@ -111,6 +113,7 @@ class SamplingParams:
|
||||
repetition_penalty,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
seed=None,
|
||||
stop=None,
|
||||
stop_token_ids=None,
|
||||
@@ -129,7 +132,8 @@ class SamplingParams:
|
||||
repetition_penalty=repetition_penalty
|
||||
if repetition_penalty is not None else 1.0,
|
||||
temperature=temperature if temperature is not None else 1.0,
|
||||
top_p=top_p if top_p is not None else 0.7,
|
||||
top_p=top_p if top_p is not None else 1.0,
|
||||
top_k=top_k if top_k is not None else 0,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
|
Reference in New Issue
Block a user