[Feature] support stop_token_ids (#5399)

* support stop_token_ids

* fix

* delete chinese

* support both

* delete print
This commit is contained in:
lizexu123
2025-12-09 17:49:12 +08:00
committed by GitHub
parent df67379bc3
commit 95eab9f9ee
18 changed files with 377 additions and 127 deletions

View File

@@ -420,6 +420,7 @@ void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& stop_seqs,
const paddle::Tensor& stop_seqs_len,
const paddle::Tensor& min_tokens,
const bool beam_search);
void UpdateInputs(const paddle::Tensor& stop_flags,
@@ -764,7 +765,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
const paddle::Tensor& seq_lens,
const paddle::Tensor& stop_seqs,
const paddle::Tensor& stop_seqs_len,
const paddle::Tensor& end_ids);
const paddle::Tensor& end_ids,
const paddle::Tensor& min_tokens);
void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
const paddle::Tensor& accept_tokens,

View File

@@ -32,6 +32,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int64_t *min_tokens,
const int pre_ids_len) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
@@ -46,6 +47,10 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int accept_num = accept_nums[bid];
const int64_t step_idx_now = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];
const bool can_stop = (step_idx_now >= min_token_limit);
if (!can_stop) return;
if (!stop_flags[bid]) {
int accept_idx = 0;
bool is_end = false;
@@ -138,7 +143,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
const paddle::Tensor &seq_lens,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &end_ids) {
const paddle::Tensor &end_ids,
const paddle::Tensor &min_tokens) {
PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
@@ -166,6 +172,7 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
min_tokens.data<int64_t>(),
pre_ids_len);
}
@@ -178,7 +185,8 @@ PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs)
"seq_lens",
"stop_seqs",
"stop_seqs_len",
"end_ids"})
"end_ids",
"min_tokens"})
.Outputs({"accept_tokens_out", "stop_flags_out"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"stop_flags", "stop_flags_out"}})

View File

@@ -37,59 +37,67 @@ __global__ void set_value_by_flags(bool *stop_flags,
const int *stop_seqs_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int64_t *min_tokens,
bool beam_search,
bool prefill_one_step_stop) {
int tid = threadIdx.x;
int bid = blockIdx.x;
if (tid >= stop_seqs_bs) return;
if (bid < bs) {
if(tid == 0){
if (prefill_one_step_stop) {
stop_flags[bid] = true;
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
}
next_tokens[bid] = topk_ids[bid];
} else {
if (stop_flags[bid]) {
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
} else {
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
} else {
next_tokens[bid] = topk_ids[bid];
}
}
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
int tid = threadIdx.x;
int bid = blockIdx.x;
if (tid >= stop_seqs_bs) return;
if (bid < bs) {
const int64_t current_step = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];
const bool can_stop = (current_step >= min_token_limit);
if (tid == 0) {
if (prefill_one_step_stop) {
stop_flags[bid] = true;
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
}
// dealing stop_seqs
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;
const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];
bool is_end = true;
int count = 1;
for (int i = stop_seq_len - 1; i >= 0; --i) {
if ((step_idx_now - count) < 0 ||
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
is_end = false;
break;
}
}
if (is_end) {
next_tokens[bid] = end_ids[0];
stop_flags[bid] = true;
next_tokens[bid] = topk_ids[bid];
} else {
if (stop_flags[bid]) {
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
} else {
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
} else {
next_tokens[bid] = topk_ids[bid];
}
}
if (!beam_search && can_stop &&
is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
}
if (!can_stop) return;
// dealing stop_seqs
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;
const int64_t *stop_seq_now =
stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];
bool is_end = true;
int count = 1;
for (int i = stop_seq_len - 1; i >= 0; --i) {
if ((step_idx_now - count) < 0 ||
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
is_end = false;
break;
}
}
if (is_end) {
next_tokens[bid] = end_ids[0];
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
}
}
}
void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
@@ -101,50 +109,63 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &min_tokens,
const bool beam_search) {
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
auto cu_stream = dev_ctx->stream();
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(
topk_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = topk_ids.stream();
auto cu_stream = topk_ids.stream();
#endif
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int stop_seqs_bs = stop_seqs.shape()[1];
int stop_seqs_max_len = stop_seqs.shape()[2];
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
pre_ids.data<int64_t>(),
pre_ids.shape()[1],
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
stop_seqs_bs,
stop_seqs_max_len,
beam_search,
prefill_one_step_stop);
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int stop_seqs_bs = stop_seqs.shape()[1];
int stop_seqs_max_len = stop_seqs.shape()[2];
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
pre_ids.data<int64_t>(),
pre_ids.shape()[1],
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
stop_seqs_bs,
stop_seqs_max_len,
min_tokens.data<int64_t>(),
beam_search,
prefill_one_step_stop);
}
PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
.Inputs({"topk_ids",
"stop_flags",
"seq_lens",
"end_ids",
"next_tokens",
"pre_ids",
"step_idx",
"stop_seqs",
"stop_seqs_len",
"min_tokens"})
.Attrs({"beam_search: bool"})
.Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
.SetInplaceMap({{"topk_ids", "topk_ids_out"},