[Feature 2.0.2] support top_k_top_p sampling (#2789)

* support top_k_top_p sampling

* fix

* add api param

* add api para

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* change func name
This commit is contained in:
Sunny-bot1
2025-07-10 12:01:51 +08:00
committed by GitHub
parent 1fe37cb7e8
commit 1107e08cd9
18 changed files with 524 additions and 134 deletions

View File

@@ -27,7 +27,7 @@ from fastdeploy.model_executor.guided_decoding.base_guided_decoding import \
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
top_p_sampling)
top_k_top_p_sampling)
from fastdeploy.platforms import current_platform
@@ -213,7 +213,7 @@ class Sampler(nn.Layer):
probs = F.softmax(logits)
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
self.processor.update_output_tokens(next_tokens, skip_idx_list)
return next_tokens
@@ -364,5 +364,5 @@ class MTPSampler(nn.Layer):
)
probs = F.softmax(logits)
_, next_tokens = top_p_sampling(probs, sampling_metadata.top_p)
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
return next_tokens