[Bug fix] Fix bug for seq_len_encoder is 1 (#3467)

This commit is contained in:
chenjian
2025-08-19 15:21:32 +08:00
committed by GitHub
parent aba94169dc
commit d2f6c3b998
3 changed files with 22 additions and 15 deletions

View File

@@ -445,8 +445,8 @@ class MTPSampler(nn.Layer):
sampling_metadata.min_dec_lens,
sampling_metadata.eos_token_ids,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
share_inputs["seq_lens_decoder"],
share_inputs["output_padding_offset"],
share_inputs["output_cum_offsets"],
max_model_len,
)
probs = F.softmax(logits)