mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-03 15:56:49 +08:00
[stop sequence] support stop sequence (#3025)
* stop seqs in multi-ends * unittest for gpu stop op * kernel tid==0
This commit is contained in:
@@ -266,13 +266,12 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
|
||||
const paddle::Tensor &seq_lens,
|
||||
const paddle::Tensor &end_ids,
|
||||
const paddle::Tensor &next_tokens,
|
||||
const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx,
|
||||
const paddle::Tensor &stop_seqs,
|
||||
const paddle::Tensor &stop_seqs_len,
|
||||
const bool beam_search);
|
||||
|
||||
void GetStopFlagsMultiSeqs(
|
||||
const paddle::Tensor &topk_ids, const paddle::Tensor &pre_ids,
|
||||
const paddle::Tensor &step_idx, const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &seq_lens, const paddle::Tensor &stop_seqs,
|
||||
const paddle::Tensor &stop_seqs_len, const paddle::Tensor &end_ids);
|
||||
|
||||
void UpdateInputes(const paddle::Tensor &stop_flags,
|
||||
const paddle::Tensor &not_need_stop, // only on cpu
|
||||
@@ -954,12 +953,6 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
|
||||
m.def("set_stop_value_multi_ends", &GetStopFlagsMulti,
|
||||
"update_inputs function");
|
||||
|
||||
/**
|
||||
* stop_generation_multi_stop_seqs.cu
|
||||
* set_stop_value_multi_seqs
|
||||
*/
|
||||
m.def("set_stop_value_multi_seqs", &GetStopFlagsMultiSeqs,
|
||||
"update_inputs function");
|
||||
|
||||
/**
|
||||
* update_inputs.cu
|
||||
|
Reference in New Issue
Block a user