[Feature] support top_k_top_p sampling (#2753)

* support top_k_top_p sampling

* fix

* add api param

* add api param

* fix

* fix

* fix

* fix

* fix

* fix

* fix
This commit is contained in:
Sunny-bot1
2025-07-10 11:58:58 +08:00
committed by GitHub
parent b0f525955c
commit e45050cae3
15 changed files with 501 additions and 53 deletions

View File

@@ -214,7 +214,7 @@ class Sampler(nn.Layer):
probs = F.softmax(logits)
- _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
+ _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
self.processor.update_output_tokens(next_tokens, skip_idx_list)
return next_tokens
@@ -367,5 +367,5 @@ class MTPSampler(nn.Layer):
)
probs = F.softmax(logits)
- _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
+ _, next_tokens = top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
return next_tokens