[BugFix] fix speculate_limit_thinking_content_length (#5590)

* fix speculate_limit_thinking_content_length

* update
This commit is contained in:
Yuanle Liu
2025-12-16 20:31:45 +08:00
committed by GitHub
parent 7140939c51
commit 867803ae10
5 changed files with 1 addition and 57 deletions

View File

@@ -22,7 +22,6 @@ __global__ void speculate_limit_thinking_content_length_kernel_v1(
const int64_t* eos_token_ids,
int* limit_think_status,
int* accept_num,
int* seq_lens_decoder,
bool* stop_flags,
const int64_t think_end_id,
const int tokens_per_step,
@@ -106,7 +105,6 @@ __global__ void speculate_limit_thinking_content_length_kernel_v1(
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
accept_num[bid] = new_accept_num;
@@ -119,7 +117,6 @@ void SpeculateLimitThinkingContentLengthV1(
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& eos_token_ids,
const int64_t think_end_id) {
@@ -134,7 +131,6 @@ void SpeculateLimitThinkingContentLengthV1(
eos_token_ids.data<int64_t>(),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<bool*>(stop_flags.data<bool>()),
think_end_id,
tokens_per_step,
@@ -148,7 +144,6 @@ PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1)
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder",
"stop_flags",
"eos_token_ids"})
.Attrs({"think_end_id: int64_t"})

View File

@@ -25,7 +25,6 @@ __global__ void speculate_limit_thinking_content_length_kernel_v2(
int64_t* step_idx,
int* limit_think_status,
int* accept_num,
int* seq_lens_decoder,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
@@ -115,7 +114,6 @@ __global__ void speculate_limit_thinking_content_length_kernel_v2(
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
accept_num[bid] = new_accept_num;
@@ -128,7 +126,6 @@ void SpeculateLimitThinkingContentLengthV2(
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const int64_t think_end_id,
const int64_t line_break_id) {
@@ -141,7 +138,6 @@ void SpeculateLimitThinkingContentLengthV2(
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
stop_flags.data<bool>(),
think_end_id,
line_break_id,
@@ -155,7 +151,6 @@ PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2)
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder",
"stop_flags"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.Outputs({"next_tokens_out"})