[Feature] support stop_token_ids (#5399)

* support stop_token_ids

* fix

* delete chinese

* support both

* delete print
This commit is contained in:
lizexu123
2025-12-09 17:49:12 +08:00
committed by GitHub
parent df67379bc3
commit 95eab9f9ee
18 changed files with 377 additions and 127 deletions

View File

@@ -32,6 +32,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int64_t *min_tokens,
const int pre_ids_len) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
@@ -46,6 +47,10 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int accept_num = accept_nums[bid];
const int64_t step_idx_now = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];
const bool can_stop = (step_idx_now >= min_token_limit);
if (!can_stop) return;
if (!stop_flags[bid]) {
int accept_idx = 0;
bool is_end = false;
@@ -138,7 +143,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
const paddle::Tensor &seq_lens,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &end_ids) {
const paddle::Tensor &end_ids,
const paddle::Tensor &min_tokens) {
PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
@@ -166,6 +172,7 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
min_tokens.data<int64_t>(),
pre_ids_len);
}
@@ -178,7 +185,8 @@ PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs)
"seq_lens",
"stop_seqs",
"stop_seqs_len",
"end_ids"})
"end_ids",
"min_tokens"})
.Outputs({"accept_tokens_out", "stop_flags_out"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"stop_flags", "stop_flags_out"}})