fix mtp_rej_topp input (#3450)

This commit is contained in:
chen
2025-08-18 16:12:42 +08:00
committed by GitHub
parent 246cd7b3a5
commit 5585cf7aa5
2 changed files with 5 additions and 3 deletions

View File

@@ -141,7 +141,7 @@ def rejection_top_p_sampling(
top_k_renorm_probs, top_k_renorm_probs,
) )
if not any(x > 0 for x in top_k_list): if top_k_list and not any(x > 0 for x in top_k_list):
ids = rejection_top_p_sampling( ids = rejection_top_p_sampling(
x, x,
top_p, top_p,
@@ -177,7 +177,7 @@ def min_p_sampling(
""" """
min_p_sampling min_p_sampling
""" """
if not any(x > 0 for x in min_p_arr_cpu): if min_p_arr_cpu and not any(x > 0 for x in min_p_arr_cpu):
return probs return probs
else: else:
if current_platform.is_cuda(): if current_platform.is_cuda():

View File

@@ -457,5 +457,7 @@ class MTPSampler(nn.Layer):
) )
probs = F.softmax(logits) probs = F.softmax(logits)
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) _, next_tokens = top_k_top_p_sampling(
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
)
return next_tokens return next_tokens