fix mtp_rej_topp input (#3450)

This commit is contained in:
chen
2025-08-18 16:12:42 +08:00
committed by GitHub
parent 246cd7b3a5
commit 5585cf7aa5
2 changed files with 5 additions and 3 deletions

View File

@@ -457,5 +457,7 @@ class MTPSampler(nn.Layer):
)
probs = F.softmax(logits)
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
_, next_tokens = top_k_top_p_sampling(
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
)
return next_tokens