mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
Rename top_p_sampling to top_k_top_p_sampling (#2791)
This commit is contained in:
@@ -16,10 +16,10 @@
|
||||
|
||||
from .apply_penalty_multi_scores import (
|
||||
apply_penalty_multi_scores, apply_speculative_penalty_multi_scores)
|
||||
from .top_p_sampling import top_p_sampling
|
||||
from .top_k_top_p_sampling import top_k_top_p_sampling
|
||||
|
||||
__all__ = [
|
||||
"apply_penalty_multi_scores",
|
||||
"apply_speculative_penalty_multi_scores",
|
||||
"top_p_sampling",
|
||||
"top_k_top_p_sampling",
|
||||
]
|
||||
|
@@ -25,7 +25,7 @@ if current_platform.is_gcu():
|
||||
from fastdeploy.model_executor.ops.gcu import \
|
||||
top_p_sampling as gcu_top_p_sampling
|
||||
|
||||
def top_p_sampling(
|
||||
def top_k_top_p_sampling(
|
||||
x: paddle.Tensor,
|
||||
top_p: paddle.Tensor,
|
||||
top_k: Optional[paddle.Tensor] = None,
|
@@ -27,7 +27,7 @@ from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \
|
||||
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
|
||||
from fastdeploy.model_executor.layers.sample.ops import (
|
||||
apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
|
||||
top_p_sampling)
|
||||
top_k_top_p_sampling)
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ class Sampler(nn.Layer):
|
||||
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
|
||||
self.processor.update_output_tokens(next_tokens, skip_idx_list)
|
||||
return next_tokens
|
||||
@@ -367,5 +367,5 @@ class MTPSampler(nn.Layer):
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
return next_tokens
|
||||
|
Reference in New Issue
Block a user