From 5585cf7aa5822c37feaa41285efa45cc3a2c4dd8 Mon Sep 17 00:00:00 2001 From: chen <103103266+ckl117@users.noreply.github.com> Date: Mon, 18 Aug 2025 16:12:42 +0800 Subject: [PATCH] fix mtp_rej_topp input (#3450) --- .../model_executor/layers/sample/ops/top_k_top_p_sampling.py | 4 ++-- fastdeploy/model_executor/layers/sample/sampler.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py index 2b0e522cc..ad8058df0 100644 --- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py +++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py @@ -141,7 +141,7 @@ def rejection_top_p_sampling( top_k_renorm_probs, ) - if not any(x > 0 for x in top_k_list): + if top_k_list and not any(x > 0 for x in top_k_list): ids = rejection_top_p_sampling( x, top_p, @@ -177,7 +177,7 @@ def min_p_sampling( """ min_p_sampling """ - if not any(x > 0 for x in min_p_arr_cpu): + if min_p_arr_cpu and not any(x > 0 for x in min_p_arr_cpu): return probs else: if current_platform.is_cuda(): diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 1cc26e4fb..5f7a7d157 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -457,5 +457,7 @@ class MTPSampler(nn.Layer): ) probs = F.softmax(logits) - _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) + _, next_tokens = top_k_top_p_sampling( + probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list + ) return next_tokens