mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] support stop_token_ids (#5399)
* support stop_token_ids
* fix
* delete Chinese
* support both
* delete print
This commit is contained in:
@@ -32,6 +32,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
const int accept_tokens_len,
|
||||
const int stop_seqs_bs,
|
||||
const int stop_seqs_max_len,
|
||||
const int64_t *min_tokens,
|
||||
const int pre_ids_len) {
|
||||
const int bid = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
@@ -46,6 +47,10 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
|
||||
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
|
||||
const int accept_num = accept_nums[bid];
|
||||
const int64_t step_idx_now = step_idx[bid];
|
||||
const int64_t min_token_limit = min_tokens[bid];
|
||||
|
||||
const bool can_stop = (step_idx_now >= min_token_limit);
|
||||
if (!can_stop) return;
|
||||
if (!stop_flags[bid]) {
|
||||
int accept_idx = 0;
|
||||
bool is_end = false;
|
||||
@@ -138,7 +143,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
|
||||
const paddle::Tensor &seq_lens,
|
||||
const paddle::Tensor &stop_seqs,
|
||||
const paddle::Tensor &stop_seqs_len,
|
||||
const paddle::Tensor &end_ids) {
|
||||
const paddle::Tensor &end_ids,
|
||||
const paddle::Tensor &min_tokens) {
|
||||
PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
|
||||
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
|
||||
|
||||
@@ -166,6 +172,7 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
|
||||
accept_tokens_len,
|
||||
stop_seqs_bs,
|
||||
stop_seqs_max_len,
|
||||
min_tokens.data<int64_t>(),
|
||||
pre_ids_len);
|
||||
}
|
||||
|
||||
@@ -178,7 +185,8 @@ PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs)
|
||||
"seq_lens",
|
||||
"stop_seqs",
|
||||
"stop_seqs_len",
|
||||
"end_ids"})
|
||||
"end_ids",
|
||||
"min_tokens"})
|
||||
.Outputs({"accept_tokens_out", "stop_flags_out"})
|
||||
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
|
||||
{"stop_flags", "stop_flags_out"}})
|
||||
|
||||
Reference in New Issue
Block a user