[Optimization] compulte real max_logprobs in batch (#5430)

This commit is contained in:
chen
2025-12-09 14:15:05 +08:00
committed by GitHub
parent f7e832efaf
commit 76649b45c1
4 changed files with 48 additions and 6 deletions

View File

@@ -375,7 +375,7 @@ class Sampler(nn.Layer):
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
share_inputs = sampling_metadata.share_inputs
if temp_scaled_logprobs is not None:
if temp_scaled_logprobs is not None and sampling_metadata.temp_scaled_logprobs_flag:
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
temperature = sampling_metadata.temperature[:real_bsz]
temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
@@ -385,7 +385,11 @@ class Sampler(nn.Layer):
top_p_logprob = None
top_p_req_mask = None
if top_p_normalized_logprobs is not None and share_inputs is not None:
if (
top_p_normalized_logprobs is not None
and share_inputs is not None
and sampling_metadata.top_p_normalized_logprobs_flag
):
seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]