mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
fix mtp_rej_topp input (#3450)
This commit is contained in:
@@ -141,7 +141,7 @@ def rejection_top_p_sampling(
|
|||||||
top_k_renorm_probs,
|
top_k_renorm_probs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not any(x > 0 for x in top_k_list):
|
if top_k_list and not any(x > 0 for x in top_k_list):
|
||||||
ids = rejection_top_p_sampling(
|
ids = rejection_top_p_sampling(
|
||||||
x,
|
x,
|
||||||
top_p,
|
top_p,
|
||||||
@@ -177,7 +177,7 @@ def min_p_sampling(
|
|||||||
"""
|
"""
|
||||||
min_p_sampling
|
min_p_sampling
|
||||||
"""
|
"""
|
||||||
if not any(x > 0 for x in min_p_arr_cpu):
|
if min_p_arr_cpu and not any(x > 0 for x in min_p_arr_cpu):
|
||||||
return probs
|
return probs
|
||||||
else:
|
else:
|
||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
|
@@ -457,5 +457,7 @@ class MTPSampler(nn.Layer):
|
|||||||
)
|
)
|
||||||
probs = F.softmax(logits)
|
probs = F.softmax(logits)
|
||||||
|
|
||||||
_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
|
_, next_tokens = top_k_top_p_sampling(
|
||||||
|
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
|
||||||
|
)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
Reference in New Issue
Block a user