mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
@@ -53,6 +53,8 @@ class SamplingMetadata:
|
||||
stop_flags: Optional[paddle.Tensor] = None
|
||||
prompt_ids: Optional[paddle.Tensor] = None
|
||||
prompt_lens: Optional[paddle.Tensor] = None
|
||||
temp_scaled_logprobs_flag: Optional[bool] = None
|
||||
top_p_normalized_logprobs_flag: Optional[bool] = None
|
||||
temp_scaled_logprobs: Optional[paddle.Tensor] = None
|
||||
top_p_normalized_logprobs: Optional[paddle.Tensor] = None
|
||||
share_inputs: Optional[Dict[str, paddle.Tensor]] = None
|
||||
|
||||
@@ -375,7 +375,7 @@ class Sampler(nn.Layer):
|
||||
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
|
||||
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
|
||||
share_inputs = sampling_metadata.share_inputs
|
||||
if temp_scaled_logprobs is not None:
|
||||
if temp_scaled_logprobs is not None and sampling_metadata.temp_scaled_logprobs_flag:
|
||||
real_bsz_temp_scaled = temp_scaled_logprobs[:real_bsz]
|
||||
temperature = sampling_metadata.temperature[:real_bsz]
|
||||
temp_temperature = paddle.where(real_bsz_temp_scaled, temperature, paddle.ones_like(temperature))
|
||||
@@ -385,7 +385,11 @@ class Sampler(nn.Layer):
|
||||
top_p_logprob = None
|
||||
top_p_req_mask = None
|
||||
|
||||
if top_p_normalized_logprobs is not None and share_inputs is not None:
|
||||
if (
|
||||
top_p_normalized_logprobs is not None
|
||||
and share_inputs is not None
|
||||
and sampling_metadata.top_p_normalized_logprobs_flag
|
||||
):
|
||||
seq_lens_this_time = share_inputs["seq_lens_this_time"].reshape([-1, 1])[:real_bsz]
|
||||
seq_lens_encoder = share_inputs["seq_lens_encoder"].reshape([-1, 1])[:real_bsz]
|
||||
seq_lens_decoder = share_inputs["seq_lens_decoder"].reshape([-1, 1])[:real_bsz]
|
||||
|
||||
Reference in New Issue
Block a user