[BugFix] fix thinking bug (#4710)

* fix thinking bug

* fix ut

* update

* fix
This commit is contained in:
Yuanle Liu
2025-10-31 22:00:31 +08:00
committed by GitHub
parent 10358bf1a0
commit b301bd6c31
8 changed files with 458 additions and 290 deletions

View File

@@ -19,81 +19,98 @@ __global__ void speculate_limit_thinking_content_length_kernel_v1(
int64_t* next_tokens,
const int* max_think_lens,
int64_t* step_idx,
const int64_t* eos_token_ids,
int* limit_think_status,
int* accept_num,
int* seq_lens_decoder,
bool* stop_flags,
const int64_t think_end_id,
const int tokens_per_step,
const int bs) {
int bid = threadIdx.x;
if (bid >= bs) return;
const int bs,
const int eos_token_id_len) {
int bid = threadIdx.x;
if (bid >= bs) return;
const int original_accept_num = accept_num[bid];
if (original_accept_num <= 0) return;
const int original_accept_num = accept_num[bid];
if (original_accept_num <= 0) return;
// 如果该序列未启用思考功能,则直接返回,默认值为 -1表示不限制思考长度
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行.
if (current_limit_think_status == 3) {
return;
}
// 如果该序列未启用思考功能,则直接返回,默认值为 -1表示不限制思考长度
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行
if (current_limit_think_status == 2 && stop_flags[bid]) {
return;
}
int new_accept_num = original_accept_num;
int new_accept_num = original_accept_num;
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
const int token_idx = bid * tokens_per_step + token_offset;
int64_t next_token = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
const int token_idx = bid * tokens_per_step + token_offset;
int64_t next_token = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
bool condition_triggered = false;
bool condition_triggered = false;
// ======================= 思考阶段控制 =======================
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束
// 阶段 2: 在替换 (status == 1), 检查是否替换结束
if (current_limit_think_status < 1) {
// 当开启思考长度控制时,检查是否超时
if (current_step >= max_think_len) {
// 强制将当前token替换为结束思考的token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // 因为修改了token需要截断
// ======================= 思考阶段控制 =======================
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束
// 阶段 2: 在替换 (status == 1), 检查是否替换结束
if (current_limit_think_status < 1) {
// 当开启思考长度控制时,检查是否超时
if (current_step >= max_think_len) {
// 强制将当前token替换为结束思考的token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // 因为修改了token需要截断
} else {
// 检查是否生成了EOS
for (int i = 0; i < eos_token_id_len; i++) {
if (eos_token_ids[i] == next_token) {
// 强制将当前token替换为结束思考的token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // 因为修改了token需要截断
if (stop_flags[bid]) {
stop_flags[bid] = false;
}
}
// ======================= 思考结束处理 =======================
// 阶段 3: 检查是否已满足结束思考的条件 (status == 0 || status == 2)
// 这种情况会处理两种场景:
// 1. status == 0: 模型可能自己生成了 </think>
// 2. status == 2: 上一阶段强制注入了 </think>
if (current_limit_think_status < 2) {
if (next_token == think_end_id) {
// 确认思考结束,将状态推进到 2 (响应阶段)
current_limit_think_status = 2;
}
}
next_tokens[token_idx] = next_token;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
}
}
}
// 更新全局状态
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
// ======================= 思考结束处理 =======================
// 阶段 3: 检查是否已满足结束思考的条件 (status == 0 || status == 2)
// 这种情况会处理两种场景:
// 1. status == 0: 模型可能自己生成了 </think>
// 2. status == 2: 上一阶段强制注入了 </think>
if (current_limit_think_status < 2) {
if (next_token == think_end_id) {
// 确认思考结束,将状态推进到 2 (响应阶段)
current_limit_think_status = 2;
}
}
accept_num[bid] = new_accept_num;
limit_think_status[bid] = current_limit_think_status;
next_tokens[token_idx] = next_token;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
}
// 更新全局状态
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
accept_num[bid] = new_accept_num;
limit_think_status[bid] = current_limit_think_status;
}
void SpeculateLimitThinkingContentLengthV1(
@@ -103,20 +120,26 @@ void SpeculateLimitThinkingContentLengthV1(
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& eos_token_ids,
const int64_t think_end_id) {
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
const int eos_token_id_len = eos_token_ids.shape()[0];
speculate_limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
think_end_id,
tokens_per_step,
batch_size);
speculate_limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
eos_token_ids.data<int64_t>(),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<bool*>(stop_flags.data<bool>()),
think_end_id,
tokens_per_step,
batch_size,
eos_token_id_len);
}
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1)
@@ -125,7 +148,9 @@ PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1)
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder"})
"seq_lens_decoder",
"stop_flags",
"eos_token_ids"})
.Attrs({"think_end_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})

View File

@@ -26,99 +26,100 @@ __global__ void speculate_limit_thinking_content_length_kernel_v2(
int* limit_think_status,
int* accept_num,
int* seq_lens_decoder,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
const int tokens_per_step,
const int bs) {
int bid = threadIdx.x;
if (bid >= bs) return;
int bid = threadIdx.x;
if (bid >= bs) return;
const int original_accept_num = accept_num[bid];
if (original_accept_num <= 0) return;
const int original_accept_num = accept_num[bid];
if (original_accept_num <= 0) return;
// 如果该序列未启用思考功能,则直接返回,默认值为 -1表示不限制思考长度
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行.
if (current_limit_think_status == 3) {
return;
// 如果该序列未启用思考功能,则直接返回,默认值为 -1表示不限制思考长度
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行.
if (current_limit_think_status == 3 && stop_flags[bid]) {
return;
}
int new_accept_num = original_accept_num;
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
const int token_idx = bid * tokens_per_step + token_offset;
int64_t next_token = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
bool condition_triggered = false;
// ======================= 思考阶段控制 =======================
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束
// 阶段 2: 在替换 (status == 1), 检查是否替换结束
if (current_limit_think_status <= 1) {
// 当开启思考长度控制时,检查是否超时
if (current_step == max_think_len) {
// 强制将当前token替换为结束思考的token
next_token = line_break_id;
current_limit_think_status = 1;
condition_triggered = true; // 因为修改了token需要截断
} else if (current_step == max_think_len + 1) {
// 强制将当前token替换为结束思考的token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // 因为修改了token需要截断
} else if (current_step == max_think_len + 2) {
// 强制将当前token替换为结束思考的token
next_token = line_break_id;
current_limit_think_status = 1;
condition_triggered = true; // 因为修改了token需要截断
} else if (current_step == max_think_len + 3) {
// 强制将当前token替换为结束思考的token
next_token = line_break_id;
// 将状态推进到 1, 表示 "正在结束思考"
current_limit_think_status = 2;
condition_triggered = true; // 因为修改了token需要截断
}
}
int new_accept_num = original_accept_num;
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
const int token_idx = bid * tokens_per_step + token_offset;
int64_t next_token = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
bool condition_triggered = false;
// ======================= 思考阶段控制 =======================
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束
// 阶段 2: 在替换 (status == 1), 检查是否替换结束
if (current_limit_think_status <= 1) {
// 当开启思考长度控制时,检查是否超时
if (current_step == max_think_len) {
// 强制将当前token替换为结束思考的token
next_token = line_break_id;
current_limit_think_status = 1;
condition_triggered = true; // 因为修改了token需要截断
} else if (current_step == max_think_len + 1) {
// 强制将当前token替换为结束思考的token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // 因为修改了token需要截断
} else if (current_step == max_think_len + 2) {
// 强制将当前token替换为结束思考的token
next_token = line_break_id;
current_limit_think_status = 1;
condition_triggered = true; // 因为修改了token需要截断
} else if (current_step == max_think_len + 3) {
// 强制将当前token替换为结束思考的token
next_token = line_break_id;
// 将状态推进到 1, 表示 "正在结束思考"
current_limit_think_status = 2;
condition_triggered = true; // 因为修改了token需要截断
}
}
// ======================= 思考结束处理 =======================
// 阶段 3: 检查是否已满足结束思考的条件 (status == 0 || status == 2)
// 这种情况会处理两种场景:
// 1. status == 0: 模型可能自己生成了 </think>
// 2. status == 2: 上一阶段强制注入了 \n</think>\n\n
if (current_limit_think_status == 0) {
if (next_token == think_end_id) {
// 确认思考结束,将状态推进到 3 (响应阶段)
current_limit_think_status = 3;
}
}
if (current_limit_think_status == 2) {
// 确认思考结束,将状态推进到 3 (响应阶段)
current_limit_think_status = 3;
}
next_tokens[token_idx] = next_token;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
// ======================= 思考结束处理 =======================
// 阶段 3: 检查是否已满足结束思考的条件 (status == 0 || status == 2)
// 这种情况会处理两种场景:
// 1. status == 0: 模型可能自己生成了 </think>
// 2. status == 2: 上一阶段强制注入了 \n</think>\n\n
if (current_limit_think_status == 0) {
if (next_token == think_end_id) {
// 确认思考结束,将状态推进到 3 (响应阶段)
current_limit_think_status = 3;
}
}
if (current_limit_think_status == 2) {
// 确认思考结束,将状态推进到 3 (响应阶段)
current_limit_think_status = 3;
}
// 更新全局状态
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
next_tokens[token_idx] = next_token;
accept_num[bid] = new_accept_num;
limit_think_status[bid] = current_limit_think_status;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
}
// 更新全局状态
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
accept_num[bid] = new_accept_num;
limit_think_status[bid] = current_limit_think_status;
}
void SpeculateLimitThinkingContentLengthV2(
@@ -128,22 +129,24 @@ void SpeculateLimitThinkingContentLengthV2(
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const int64_t think_end_id,
const int64_t line_break_id) {
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
speculate_limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
think_end_id,
line_break_id,
tokens_per_step,
batch_size);
speculate_limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
stop_flags.data<bool>(),
think_end_id,
line_break_id,
tokens_per_step,
batch_size);
}
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2)
@@ -152,7 +155,8 @@ PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2)
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder"})
"seq_lens_decoder",
"stop_flags"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})