mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] support top_k_top_p sampling (#2753)
* support top_k_top_p sampling * fix * add api param * add api para * fix * fix * fix * fix * fix * fix * fix
This commit is contained in:
@@ -214,7 +214,7 @@ class Sampler(nn.Layer):
|
||||
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
|
||||
self.processor.update_output_tokens(next_tokens, skip_idx_list)
|
||||
return next_tokens
|
||||
@@ -367,5 +367,5 @@ class MTPSampler(nn.Layer):
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
|
||||
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
||||
return next_tokens
|
||||
|
||||
Reference in New Issue
Block a user