Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[stop sequence] support stop sequence (#3025)
* stop seqs in multi-ends
* unittest for gpu stop op
* kernel tid==0
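The diff below threads per-request stop-sequence buffers (model_output.stop_token_ids and model_output.stop_seqs_len) into the GPU stop op. As a rough illustration of how such buffers are commonly laid out, here is a minimal sketch that packs tokenized stop sequences into a right-padded tensor plus a length vector; the helper name pack_stop_sequences, the shapes, and the -1 pad value are assumptions for this sketch, not code from the PR.

import numpy as np
import paddle

# Hypothetical helper, not part of FastDeploy or this PR: pack tokenized stop
# sequences into a right-padded [num_stop_seqs, max_stop_seq_len] buffer plus a
# per-sequence length vector, mirroring the general layout that the
# stop_token_ids / stop_seqs_len arguments suggest.
def pack_stop_sequences(stop_seqs, max_stop_seq_len, pad_id=-1):
    packed = np.full((len(stop_seqs), max_stop_seq_len), pad_id, dtype="int64")
    lens = np.zeros((len(stop_seqs),), dtype="int32")
    for i, seq in enumerate(stop_seqs):
        seq = seq[:max_stop_seq_len]          # truncate overly long sequences
        packed[i, : len(seq)] = seq
        lens[i] = len(seq)
    return paddle.to_tensor(packed), paddle.to_tensor(lens)

# Example: two stop sequences of different lengths share one padded buffer.
stop_token_ids, stop_seqs_len = pack_stop_sequences([[13, 198], [50256]], max_stop_seq_len=4)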
@@ -210,15 +210,29 @@ def post_process_normal(
         paddle.logical_or(model_output.stop_flags, length_cond),
         model_output.stop_flags,
     )
     # TODO(gongshaotian): Add use_stop_seqs
-    set_stop_value_multi_ends(
-        sampler_output.sampled_token_ids,
-        model_output.stop_flags,
-        model_output.seq_lens_this_time,
-        model_output.eos_token_id,
-        model_output.next_tokens,
-        False,
-    )  # multi ends
+    if current_platform.is_cuda():
+        set_stop_value_multi_ends(
+            sampler_output.sampled_token_ids,
+            model_output.stop_flags,
+            model_output.seq_lens_this_time,
+            model_output.eos_token_id,
+            model_output.next_tokens,
+            model_output.pre_ids,
+            model_output.step_idx,
+            model_output.stop_token_ids,
+            model_output.stop_seqs_len,
+            False,
+        )  # multi ends
+    else:
+        set_stop_value_multi_ends(
+            sampler_output.sampled_token_ids,
+            model_output.stop_flags,
+            model_output.seq_lens_this_time,
+            model_output.eos_token_id,
+            model_output.next_tokens,
+            False,
+        )

     # 2. Update the input buffer of the model
     with paddle.framework._no_check_dy2st_diff():
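For orientation, here is a minimal NumPy sketch of the check the extended multi-ends op conceptually performs: after the latest token is recorded in pre_ids at position step_idx, a request is stopped if its generated-token history ends with any configured stop sequence. This is an illustrative reference under assumed array layouts, not the CUDA kernel added in this PR; the function name and the plain-NumPy inputs are assumptions.

import numpy as np

def ends_with_stop_seq(pre_ids_row, step_idx, stop_token_ids, stop_seqs_len):
    """Illustrative check: does one request's generated-token history end with
    any configured stop sequence?

    pre_ids_row:    1-D array of generated token ids for one request
    step_idx:       index of the latest generated token within pre_ids_row
    stop_token_ids: [num_stop_seqs, max_stop_seq_len] right-padded stop sequences
    stop_seqs_len:  [num_stop_seqs] true length of each stop sequence
    """
    history = pre_ids_row[: step_idx + 1]
    for seq, n in zip(stop_token_ids, stop_seqs_len):
        n = int(n)
        if n == 0 or n > history.shape[0]:
            continue
        if np.array_equal(history[-n:], seq[:n]):
            return True
    return False

# Example: a history ending in [13, 198] matches the first stop sequence.
history = np.array([11, 22, 13, 198, 0, 0])                 # padded; latest token at index 3
stops = np.array([[13, 198, -1, -1], [50256, -1, -1, -1]])
stop_lens = np.array([2, 1])
assert ends_with_stop_seq(history, 3, stops, stop_lens)

Note that in the diff only the CUDA branch receives pre_ids, step_idx, stop_token_ids, and stop_seqs_len; the fallback branch keeps the original eos_token_id-only stopping behavior.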