[stop sequence] support stop sequence (#3025)

* stop seqs in multi-ends

* unittest for gpu stop op

* kernel tid==0
This commit is contained in:
JYChen
2025-07-29 14:17:37 +08:00
committed by GitHub
parent 1a815b7a2a
commit dafe02a7b9
11 changed files with 193 additions and 189 deletions

View File

@@ -210,15 +210,29 @@ def post_process_normal(
paddle.logical_or(model_output.stop_flags, length_cond),
model_output.stop_flags,
)
# TODO(gongshaotian): Add use_stop_seqs
set_stop_value_multi_ends(
sampler_output.sampled_token_ids,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.eos_token_id,
model_output.next_tokens,
False,
) # multi ends
if current_platform.is_cuda():
set_stop_value_multi_ends(
sampler_output.sampled_token_ids,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.eos_token_id,
model_output.next_tokens,
model_output.pre_ids,
model_output.step_idx,
model_output.stop_token_ids,
model_output.stop_seqs_len,
False,
) # multi ends
else:
set_stop_value_multi_ends(
sampler_output.sampled_token_ids,
model_output.stop_flags,
model_output.seq_lens_this_time,
model_output.eos_token_id,
model_output.next_tokens,
False,
)
# 2. Update the input buffer of the model
with paddle.framework._no_check_dy2st_diff():