[BugFix] fix thinking bug (#4710)

* fix thinking bug

* fix ut

* update

* fix
This commit is contained in:
Yuanle Liu
2025-10-31 22:00:31 +08:00
committed by GitHub
parent 10358bf1a0
commit b301bd6c31
8 changed files with 458 additions and 290 deletions

View File

@@ -987,12 +987,15 @@ void LimitThinkingContentLengthV1(const paddle::Tensor& next_tokens,
const paddle::Tensor& max_think_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& stop_flags,
const paddle::Tensor& eos_token_ids,
const int64_t think_end_id);
void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
const paddle::Tensor& max_think_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& stop_flags,
const int64_t think_end_id,
const int64_t line_break_id);
@@ -1003,6 +1006,8 @@ void SpeculateLimitThinkingContentLengthV1(
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& eos_token_ids,
const int64_t think_end_id);
void SpeculateLimitThinkingContentLengthV2(
@@ -1012,6 +1017,7 @@ void SpeculateLimitThinkingContentLengthV2(
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const int64_t think_end_id,
const int64_t line_break_id);

View File

@@ -19,69 +19,97 @@ __global__ void limit_thinking_content_length_kernel_v1(
int64_t *next_tokens,
const int *max_think_lens,
const int64_t *step_idx,
const int64_t *eos_token_ids,
int *limit_think_status,
bool *stop_flags,
const int64_t think_end_id,
const int bs) {
int bid = threadIdx.x;
if (bid >= bs) return;
const int bs,
const int eos_token_id_len) {
int bid = threadIdx.x;
if (bid >= bs) return;
// If thinking is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If already in the response phase and the stop flag has been triggered, return immediately; nothing more needs to be done.
if (current_limit_think_status == 2) {
return;
}
// If thinking is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If already in the response phase and the stop flag has been triggered, return immediately; nothing more needs to be done.
if (current_limit_think_status == 2 && stop_flags[bid]) {
return;
}
int64_t next_token = next_tokens[bid];
const int64_t step = step_idx[bid];
int64_t next_token = next_tokens[bid];
const int64_t step = step_idx[bid];
// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be force-ended
if (current_limit_think_status < 1) {
// With thinking-length control enabled, check whether the budget is exhausted
if (step >= max_think_len) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
// Advance the status to 1, meaning "finishing thinking"
current_limit_think_status = 1;
// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be force-ended
if (current_limit_think_status < 1) {
// With thinking-length control enabled, check whether the budget is exhausted
if (step >= max_think_len) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
// Advance the status to 1, meaning "finishing thinking"
current_limit_think_status = 1;
} else {
// Check whether an EOS token was generated
for (int i = 0; i < eos_token_id_len; i++) {
if (eos_token_ids[i] == next_token) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
// Advance the status to 1, meaning "finishing thinking"
current_limit_think_status = 1;
if (stop_flags[bid]) {
stop_flags[bid] = false;
}
break;
}
}
}
// ======================= End-of-thinking handling =======================
// Phase 2: check whether the end-of-thinking condition has been met (status < 2)
// This covers two scenarios:
// 1. status == 0: the model generated think_end_id on its own
// 2. status == 1: think_end_id was force-injected in the previous phase
if (current_limit_think_status < 2) {
if (next_token == think_end_id) {
// Thinking is confirmed to be over; advance the status to 2 (response phase)
current_limit_think_status = 2;
}
}
// ======================= End-of-thinking handling =======================
// Phase 2: check whether the end-of-thinking condition has been met (status < 2)
// This covers two scenarios:
// 1. status == 0: the model generated think_end_id on its own
// 2. status == 1: think_end_id was force-injected in the previous phase
if (current_limit_think_status < 2) {
if (next_token == think_end_id) {
// Thinking is confirmed to be over; advance the status to 2 (response phase)
current_limit_think_status = 2;
}
// Write back the updated token
next_tokens[bid] = next_token;
// Update the global state
limit_think_status[bid] = current_limit_think_status;
}
// Write back the updated token
next_tokens[bid] = next_token;
// Update the global state
limit_think_status[bid] = current_limit_think_status;
}
void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens,
const paddle::Tensor &max_think_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &limit_think_status,
const paddle::Tensor &stop_flags,
const paddle::Tensor &eos_token_ids,
const int64_t think_end_id) {
const int batch_size = next_tokens.shape()[0];
limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),
const_cast<int *>(limit_think_status.data<int>()),
think_end_id,
batch_size);
const int batch_size = next_tokens.shape()[0];
const int eos_token_id_len = eos_token_ids.shape()[0];
limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),
eos_token_ids.data<int64_t>(),
const_cast<int *>(limit_think_status.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
think_end_id,
batch_size,
eos_token_id_len);
}
PD_BUILD_STATIC_OP(limit_thinking_content_length_v1)
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
.Inputs({"next_tokens",
"max_think_lens",
"step_idx",
"limit_think_status",
"stop_flags",
"eos_token_ids"})
.Attrs({"think_end_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
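
The v1 operator implements a small per-sequence state machine: status 0 means the sequence is still thinking, 1 means think_end_id has just been force-injected, and 2 means the response phase has begun. The new eos_token_ids/stop_flags inputs let the kernel catch an EOS emitted while still thinking, rewrite it to think_end_id, and clear the stop flag so the sequence continues into its answer. As a reading aid only, here is a minimal single-sequence Python sketch of that logic (not part of the patch; all names are illustrative):

def limit_thinking_v1_reference(next_token, step, max_think_len, status,
                                stop_flag, eos_token_ids, think_end_id):
    # max_think_len < 0 disables thinking-length control for this sequence.
    if max_think_len < 0 or (status == 2 and stop_flag):
        return next_token, status, stop_flag
    if status < 1:
        if step >= max_think_len:
            # Thinking budget exhausted: force the end-of-thinking token.
            next_token, status = think_end_id, 1
        elif next_token in eos_token_ids:
            # EOS during thinking: rewrite it and clear the stop flag so decoding continues.
            next_token, status, stop_flag = think_end_id, 1, False
    if status < 2 and next_token == think_end_id:
        # The model (or the forced injection) ended thinking; enter the response phase.
        status = 2
    return next_token, status, stop_flag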

View File

@@ -24,87 +24,94 @@ __global__ void limit_thinking_content_length_kernel_v2(
const int *max_think_lens,
const int64_t *step_idx,
int *limit_think_status,
const bool *stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
const int bs) {
int bid = threadIdx.x;
if (bid >= bs) return;
// If thinking is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If already in the response phase and the stop flag has been triggered, return immediately; nothing more needs to be done.
if (current_limit_think_status == 3) {
return;
}
int bid = threadIdx.x;
if (bid >= bs) return;
// If thinking is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If already in the response phase and the stop flag has been triggered, return immediately; nothing more needs to be done.
if (current_limit_think_status == 3 && stop_flags[bid]) {
return;
}
int64_t next_token = next_tokens[bid];
const int64_t step = step_idx[bid];
int64_t next_token = next_tokens[bid];
const int64_t step = step_idx[bid];
// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be force-ended
// Phase 2: injecting (status == 1); check whether the injection is finished
if (current_limit_think_status <= 1) {
// With thinking-length control enabled, check whether the budget is exhausted
if (step == max_think_len) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
current_limit_think_status = 1;
} else if (step == max_think_len + 1) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
current_limit_think_status = 1;
} else if (step == max_think_len + 2) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
current_limit_think_status = 1;
} else if (step == max_think_len + 3) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
// Advance the status to 2; the forced injection sequence is complete
current_limit_think_status = 2;
}
// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be force-ended
// Phase 2: injecting (status == 1); check whether the injection is finished
if (current_limit_think_status <= 1) {
// With thinking-length control enabled, check whether the budget is exhausted
if (step == max_think_len) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
current_limit_think_status = 1;
} else if (step == max_think_len + 1) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
current_limit_think_status = 1;
} else if (step == max_think_len + 2) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
current_limit_think_status = 1;
} else if (step == max_think_len + 3) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
// Advance the status to 2; the forced injection sequence is complete
current_limit_think_status = 2;
}
// ======================= End-of-thinking handling =======================
// Phase 3: check whether the end-of-thinking condition has been met (status == 0 || status == 2)
// This covers two scenarios:
// 1. status == 0: the model may have generated </think> on its own
// 2. status == 2: \n</think>\n\n was force-injected in the previous phase
if (current_limit_think_status == 0) {
if (next_token == think_end_id) {
// Thinking is confirmed to be over; advance the status to 3 (response phase)
current_limit_think_status = 3;
}
}
// ======================= End-of-thinking handling =======================
// Phase 3: check whether the end-of-thinking condition has been met (status == 0 || status == 2)
// This covers two scenarios:
// 1. status == 0: the model may have generated </think> on its own
// 2. status == 2: \n</think>\n\n was force-injected in the previous phase
if (current_limit_think_status == 0) {
if (next_token == think_end_id) {
// Thinking is confirmed to be over; advance the status to 3 (response phase)
current_limit_think_status = 3;
}
if (current_limit_think_status == 2) {
// Thinking is confirmed to be over; advance the status to 3 (response phase)
current_limit_think_status = 3;
}
// Write back the updated token
next_tokens[bid] = next_token;
// Update the global state
limit_think_status[bid] = current_limit_think_status;
}
if (current_limit_think_status == 2) {
// Thinking is confirmed to be over; advance the status to 3 (response phase)
current_limit_think_status = 3;
}
// Write back the updated token
next_tokens[bid] = next_token;
// Update the global state
limit_think_status[bid] = current_limit_think_status;
}
void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens,
const paddle::Tensor &max_think_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &limit_think_status,
const paddle::Tensor &stop_flags,
const int64_t think_end_id,
const int64_t line_break_id) {
const int batch_size = next_tokens.shape()[0];
limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),
const_cast<int *>(limit_think_status.data<int>()),
think_end_id,
line_break_id,
batch_size);
const int batch_size = next_tokens.shape()[0];
limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),
const_cast<int *>(limit_think_status.data<int>()),
stop_flags.data<bool>(),
think_end_id,
line_break_id,
batch_size);
}
PD_BUILD_STATIC_OP(limit_thinking_content_length_v2)
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
.Inputs({"next_tokens",
"max_think_lens",
"step_idx",
"limit_think_status",
"stop_flags"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
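
In the v2 variant the limit strategy is the literal sequence "\n</think>\n\n": starting at step == max_think_len the kernel injects line_break_id, think_end_id, line_break_id, line_break_id on four consecutive steps, and the status runs 0 (thinking) → 1 (injecting) → 2 (injection done) → 3 (response phase). As a reading aid only, a minimal single-sequence Python sketch of the same logic (not part of the patch; names are illustrative):

def limit_thinking_v2_reference(next_token, step, max_think_len, status,
                                stop_flag, think_end_id, line_break_id):
    # status: 0 = thinking, 1 = injecting, 2 = injection finished, 3 = response phase.
    if max_think_len < 0 or (status == 3 and stop_flag):
        return next_token, status
    if status <= 1:
        # From step == max_think_len onward, inject "\n", "</think>", "\n", "\n".
        forced = {0: line_break_id, 1: think_end_id, 2: line_break_id, 3: line_break_id}
        offset = step - max_think_len
        if offset in forced:
            next_token = forced[offset]
            status = 2 if offset == 3 else 1
    if status == 0 and next_token == think_end_id:
        # The model closed the thinking block on its own.
        status = 3
    if status == 2:
        # The forced injection is complete; move to the response phase.
        status = 3
    return next_token, status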

View File

@@ -19,81 +19,98 @@ __global__ void speculate_limit_thinking_content_length_kernel_v1(
int64_t* next_tokens,
const int* max_think_lens,
int64_t* step_idx,
const int64_t* eos_token_ids,
int* limit_think_status,
int* accept_num,
int* seq_lens_decoder,
bool* stop_flags,
const int64_t think_end_id,
const int tokens_per_step,
const int bs) {
int bid = threadIdx.x;
if (bid >= bs) return;
const int bs,
const int eos_token_id_len) {
int bid = threadIdx.x;
if (bid >= bs) return;
const int original_accept_num = accept_num[bid];
if (original_accept_num <= 0) return;
const int original_accept_num = accept_num[bid];
if (original_accept_num <= 0) return;
// If thinking is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If already in the response phase and the stop flag has been triggered, return immediately; nothing more needs to be done.
if (current_limit_think_status == 3) {
return;
}
// If thinking is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If already in the response phase and the stop flag has been triggered, return immediately; nothing more needs to be done.
if (current_limit_think_status == 2 && stop_flags[bid]) {
return;
}
int new_accept_num = original_accept_num;
int new_accept_num = original_accept_num;
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
const int token_idx = bid * tokens_per_step + token_offset;
int64_t next_token = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
const int token_idx = bid * tokens_per_step + token_offset;
int64_t next_token = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
bool condition_triggered = false;
bool condition_triggered = false;
// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be force-ended
// Phase 2: injecting (status == 1); check whether the injection is finished
if (current_limit_think_status < 1) {
// With thinking-length control enabled, check whether the budget is exhausted
if (current_step >= max_think_len) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be force-ended
// Phase 2: injecting (status == 1); check whether the injection is finished
if (current_limit_think_status < 1) {
// With thinking-length control enabled, check whether the budget is exhausted
if (current_step >= max_think_len) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
} else {
// Check whether an EOS token was generated
for (int i = 0; i < eos_token_id_len; i++) {
if (eos_token_ids[i] == next_token) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
if (stop_flags[bid]) {
stop_flags[bid] = false;
}
}
// ======================= End-of-thinking handling =======================
// Phase 3: check whether the end-of-thinking condition has been met (status == 0 || status == 2)
// This covers two scenarios:
// 1. status == 0: the model may have generated </think> on its own
// 2. status == 2: </think> was force-injected in the previous phase
if (current_limit_think_status < 2) {
if (next_token == think_end_id) {
// Thinking is confirmed to be over; advance the status to 2 (response phase)
current_limit_think_status = 2;
}
}
next_tokens[token_idx] = next_token;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
}
}
}
// Update the global state
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
// ======================= End-of-thinking handling =======================
// Phase 3: check whether the end-of-thinking condition has been met (status == 0 || status == 2)
// This covers two scenarios:
// 1. status == 0: the model may have generated </think> on its own
// 2. status == 2: </think> was force-injected in the previous phase
if (current_limit_think_status < 2) {
if (next_token == think_end_id) {
// Thinking is confirmed to be over; advance the status to 2 (response phase)
current_limit_think_status = 2;
}
}
accept_num[bid] = new_accept_num;
limit_think_status[bid] = current_limit_think_status;
next_tokens[token_idx] = next_token;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
}
// Update the global state
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
accept_num[bid] = new_accept_num;
limit_think_status[bid] = current_limit_think_status;
}
void SpeculateLimitThinkingContentLengthV1(
@@ -103,20 +120,26 @@ void SpeculateLimitThinkingContentLengthV1(
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& eos_token_ids,
const int64_t think_end_id) {
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
const int eos_token_id_len = eos_token_ids.shape()[0];
speculate_limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
think_end_id,
tokens_per_step,
batch_size);
speculate_limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
eos_token_ids.data<int64_t>(),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<bool*>(stop_flags.data<bool>()),
think_end_id,
tokens_per_step,
batch_size,
eos_token_id_len);
}
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1)
@@ -125,7 +148,9 @@ PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1)
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder"})
"seq_lens_decoder",
"stop_flags",
"eos_token_ids"})
.Attrs({"think_end_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
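
For the speculative-decoding variants, the kernel walks the window of accepted draft tokens for each sequence; as soon as one token has to be rewritten (budget exhausted, or an EOS seen while still thinking), the remaining draft tokens are discarded and step_idx / seq_lens_decoder are rolled back accordingly. A minimal Python sketch of that truncation logic for one sequence, offered as a reading aid only (names are illustrative, not part of the patch):

def truncate_accepted_v1_reference(next_tokens, accept_num, step_idx, seq_len_decoder,
                                   max_think_len, status, eos_token_ids, think_end_id):
    # Scan the accepted draft tokens; if one must be rewritten, keep it and drop the rest.
    base_step = step_idx - accept_num + 1
    new_accept = accept_num
    for off in range(accept_num):
        step = base_step + off
        tok = next_tokens[off]
        triggered = False
        if status < 1 and (step >= max_think_len or tok in eos_token_ids):
            tok, status, triggered = think_end_id, 1, True
        if status < 2 and tok == think_end_id:
            status = 2
        next_tokens[off] = tok
        if triggered:
            new_accept = off + 1
            break
    discarded = accept_num - new_accept
    # Roll back the per-sequence counters for the discarded draft tokens.
    return new_accept, step_idx - discarded, seq_len_decoder - discarded, status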

View File

@@ -26,99 +26,100 @@ __global__ void speculate_limit_thinking_content_length_kernel_v2(
int* limit_think_status,
int* accept_num,
int* seq_lens_decoder,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
const int tokens_per_step,
const int bs) {
int bid = threadIdx.x;
if (bid >= bs) return;
int bid = threadIdx.x;
if (bid >= bs) return;
const int original_accept_num = accept_num[bid];
if (original_accept_num <= 0) return;
const int original_accept_num = accept_num[bid];
if (original_accept_num <= 0) return;
// If thinking is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If already in the response phase and the stop flag has been triggered, return immediately; nothing more needs to be done.
if (current_limit_think_status == 3) {
return;
// If thinking is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If already in the response phase and the stop flag has been triggered, return immediately; nothing more needs to be done.
if (current_limit_think_status == 3 && stop_flags[bid]) {
return;
}
int new_accept_num = original_accept_num;
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
const int token_idx = bid * tokens_per_step + token_offset;
int64_t next_token = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
bool condition_triggered = false;
// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be force-ended
// Phase 2: injecting (status == 1); check whether the injection is finished
if (current_limit_think_status <= 1) {
// With thinking-length control enabled, check whether the budget is exhausted
if (current_step == max_think_len) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
current_limit_think_status = 1;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
} else if (current_step == max_think_len + 1) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
} else if (current_step == max_think_len + 2) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
current_limit_think_status = 1;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
} else if (current_step == max_think_len + 3) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
// Advance the status to 2; the forced injection sequence is complete
current_limit_think_status = 2;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
}
}
int new_accept_num = original_accept_num;
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
const int token_idx = bid * tokens_per_step + token_offset;
int64_t next_token = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;
bool condition_triggered = false;
// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be force-ended
// Phase 2: injecting (status == 1); check whether the injection is finished
if (current_limit_think_status <= 1) {
// With thinking-length control enabled, check whether the budget is exhausted
if (current_step == max_think_len) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
current_limit_think_status = 1;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
} else if (current_step == max_think_len + 1) {
// Force-replace the current token with the end-of-thinking token
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
} else if (current_step == max_think_len + 2) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
current_limit_think_status = 1;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
} else if (current_step == max_think_len + 3) {
// Force-replace the current token with a line-break token
next_token = line_break_id;
// Advance the status to 2; the forced injection sequence is complete
current_limit_think_status = 2;
condition_triggered = true; // The token was modified, so the accepted window must be truncated here
}
}
// ======================= End-of-thinking handling =======================
// Phase 3: check whether the end-of-thinking condition has been met (status == 0 || status == 2)
// This covers two scenarios:
// 1. status == 0: the model may have generated </think> on its own
// 2. status == 2: \n</think>\n\n was force-injected in the previous phase
if (current_limit_think_status == 0) {
if (next_token == think_end_id) {
// Thinking is confirmed to be over; advance the status to 3 (response phase)
current_limit_think_status = 3;
}
}
if (current_limit_think_status == 2) {
// Thinking is confirmed to be over; advance the status to 3 (response phase)
current_limit_think_status = 3;
}
next_tokens[token_idx] = next_token;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
// ======================= End-of-thinking handling =======================
// Phase 3: check whether the end-of-thinking condition has been met (status == 0 || status == 2)
// This covers two scenarios:
// 1. status == 0: the model may have generated </think> on its own
// 2. status == 2: \n</think>\n\n was force-injected in the previous phase
if (current_limit_think_status == 0) {
if (next_token == think_end_id) {
// Thinking is confirmed to be over; advance the status to 3 (response phase)
current_limit_think_status = 3;
}
}
if (current_limit_think_status == 2) {
// Thinking is confirmed to be over; advance the status to 3 (response phase)
current_limit_think_status = 3;
}
// Update the global state
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
next_tokens[token_idx] = next_token;
accept_num[bid] = new_accept_num;
limit_think_status[bid] = current_limit_think_status;
if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
}
// Update the global state
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
accept_num[bid] = new_accept_num;
limit_think_status[bid] = current_limit_think_status;
}
void SpeculateLimitThinkingContentLengthV2(
@@ -128,22 +129,24 @@ void SpeculateLimitThinkingContentLengthV2(
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const int64_t think_end_id,
const int64_t line_break_id) {
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];
speculate_limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
think_end_id,
line_break_id,
tokens_per_step,
batch_size);
speculate_limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
stop_flags.data<bool>(),
think_end_id,
line_break_id,
tokens_per_step,
batch_size);
}
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2)
@@ -152,7 +155,8 @@ PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2)
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder"})
"seq_lens_decoder",
"stop_flags"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})

View File

@@ -97,6 +97,8 @@ def limit_thinking_content_length(
max_think_lens: paddle.Tensor,
step_idx: paddle.Tensor,
limit_think_status: paddle.Tensor,
stop_flags: paddle.Tensor,
eos_token_ids: paddle.Tensor,
think_end_id: int,
line_break_id: int = None,
):
@@ -107,6 +109,8 @@ def limit_thinking_content_length(
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
eos_token_ids,  # handles the case where the model emits an eos token during the thinking phase due to model-quality issues
think_end_id,
)
elif limit_strategy == "\n</think>\n\n":
@@ -117,6 +121,7 @@ def limit_thinking_content_length(
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
think_end_id,
line_break_id,
)
@@ -132,6 +137,8 @@ def speculate_limit_thinking_content_length(
limit_think_status: paddle.Tensor,
accept_num: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
stop_flags: paddle.Tensor,
eos_token_ids: paddle.Tensor,
think_end_id: int,
line_break_id: int = None,
):
@@ -144,6 +151,8 @@ def speculate_limit_thinking_content_length(
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,  # handles the case where the model emits an eos token during the thinking phase due to model-quality issues
think_end_id,
)
elif limit_strategy == "\n</think>\n\n":
@@ -156,6 +165,7 @@ def speculate_limit_thinking_content_length(
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -271,6 +281,8 @@ def post_process_normal(
max_think_lens=share_inputs["max_think_lens"],
step_idx=share_inputs["step_idx"],
limit_think_status=share_inputs["limit_think_status"],
stop_flags=share_inputs["stop_flags"],
eos_token_ids=share_inputs["eos_token_id"],
think_end_id=think_end_id,
line_break_id=line_break_id,
)
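
For reference, a standalone sketch of driving the updated v1 custom op directly, mirroring the unit tests below (tensor values and think_end_id are illustrative, and the import location of the op is an assumption rather than something shown in this diff):

import paddle
# Assumed import path; the op is registered as limit_thinking_content_length_v1 by PD_BUILD_STATIC_OP above.
from fastdeploy.model_executor.ops.gpu import limit_thinking_content_length_v1

next_tokens = paddle.to_tensor([[100], [200]], dtype="int64")
max_think_lens = paddle.to_tensor([5, -1], dtype="int32")  # -1 disables the limit for sequence 1
step_idx = paddle.to_tensor([[6], [3]], dtype="int64")
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
stop_flags = paddle.to_tensor([False, False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999

limit_thinking_content_length_v1(
    next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
)
# Sequence 0 exceeded its thinking budget, so next_tokens[0, 0] becomes 999 and its status advances to 2.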

View File

@@ -33,10 +33,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
max_think_lens = paddle.to_tensor([10, 15], dtype="int32")
step_idx = paddle.to_tensor([[5], [8]], dtype="int64")
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
stop_flags = paddle.to_tensor([False, False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
limit_thinking_content_length_v1(
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
)
# Verify: tokens unchanged, status unchanged
assert next_tokens.numpy()[0, 0] == 100
@@ -50,10 +54,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
max_think_lens = paddle.to_tensor([5, 8], dtype="int32")
step_idx = paddle.to_tensor([[5], [10]], dtype="int64") # Both exceed or equal limit
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
stop_flags = paddle.to_tensor([False, False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
limit_thinking_content_length_v1(
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
)
# Verify: tokens replaced with think_end_id, status changed to 2
assert next_tokens.numpy()[0, 0] == 999 # Replaced
@@ -67,10 +75,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
max_think_lens = paddle.to_tensor([10], dtype="int32")
step_idx = paddle.to_tensor([[3]], dtype="int64") # Still within limit
limit_think_status = paddle.to_tensor([0], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
think_end_id = 999
# Run operator
limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
limit_thinking_content_length_v1(
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
)
# Verify: token unchanged (already think_end_id), status changed to 2
assert next_tokens.numpy()[0, 0] == 999
@@ -82,10 +94,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
max_think_lens = paddle.to_tensor([5], dtype="int32")
step_idx = paddle.to_tensor([[6]], dtype="int64")
limit_think_status = paddle.to_tensor([1], dtype="int32") # Status is 1
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
think_end_id = 999
# Run operator
limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
limit_thinking_content_length_v1(
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
)
# Verify: status changed to 2
assert limit_think_status.numpy()[0] == 2
@@ -96,10 +112,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
max_think_lens = paddle.to_tensor([-1], dtype="int32") # Disabled
step_idx = paddle.to_tensor([[100]], dtype="int64") # Would exceed limit if enabled
limit_think_status = paddle.to_tensor([0], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
think_end_id = 999
# Run operator
limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
limit_thinking_content_length_v1(
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
)
# Verify: nothing changed
assert next_tokens.numpy()[0, 0] == 100
@@ -111,10 +131,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
max_think_lens = paddle.to_tensor([5], dtype="int32")
step_idx = paddle.to_tensor([[10]], dtype="int64")
limit_think_status = paddle.to_tensor([2], dtype="int32") # Already in response phase
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2]], dtype="int64")
think_end_id = 999
# Run operator
limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
limit_thinking_content_length_v1(
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
)
# Verify: nothing changed
assert next_tokens.numpy()[0, 0] == 100
@@ -126,10 +150,14 @@ class TestLimitThinkingContentLengthV1(unittest.TestCase):
max_think_lens = paddle.to_tensor([10, 5, 8, -1], dtype="int32")
step_idx = paddle.to_tensor([[3], [5], [4], [100]], dtype="int64")
limit_think_status = paddle.to_tensor([0, 0, 0, 0], dtype="int32")
stop_flags = paddle.to_tensor([False, False, False, False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
limit_thinking_content_length_v1(next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id)
limit_thinking_content_length_v1(
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, eos_token_ids, think_end_id
)
# Verify each sequence
# Seq 0: step 3 < max 10, status 0, token unchanged
@@ -158,12 +186,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
max_think_lens = paddle.to_tensor([10, 15], dtype="int32")
step_idx = paddle.to_tensor([[5], [8]], dtype="int64")
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
stop_flags = paddle.to_tensor([False, False], dtype="bool")
think_end_id = 999
line_break_id = 888
# Run operator
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
# Verify: tokens unchanged, status unchanged
@@ -179,11 +208,12 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
max_think_lens = paddle.to_tensor([5], dtype="int32")
step_idx = paddle.to_tensor([[5]], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
assert next_tokens.numpy()[0, 0] == 888 # line_break_id
assert limit_think_status.numpy()[0] == 1
@@ -194,7 +224,7 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([1], dtype="int32")
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
assert next_tokens.numpy()[0, 0] == 999 # think_end_id
assert limit_think_status.numpy()[0] == 1
@@ -205,7 +235,7 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([1], dtype="int32")
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
assert next_tokens.numpy()[0, 0] == 888 # line_break_id
assert limit_think_status.numpy()[0] == 1
@@ -216,7 +246,7 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([1], dtype="int32")
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
assert next_tokens.numpy()[0, 0] == 888 # line_break_id
assert limit_think_status.numpy()[0] == 3 # Move to status 3
@@ -227,12 +257,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
max_think_lens = paddle.to_tensor([10], dtype="int32")
step_idx = paddle.to_tensor([[3]], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
# Run operator
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
# Verify: status changed to 3 (response phase)
@@ -245,12 +276,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
max_think_lens = paddle.to_tensor([5], dtype="int32")
step_idx = paddle.to_tensor([[9]], dtype="int64")
limit_think_status = paddle.to_tensor([2], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
# Run operator
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
# Verify: status changed to 3
@@ -262,12 +294,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
max_think_lens = paddle.to_tensor([-1], dtype="int32")
step_idx = paddle.to_tensor([[100]], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
# Run operator
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
# Verify: nothing changed
@@ -280,12 +313,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
max_think_lens = paddle.to_tensor([5], dtype="int32")
step_idx = paddle.to_tensor([[10]], dtype="int64")
limit_think_status = paddle.to_tensor([3], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
# Run operator
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
# Verify: nothing changed
@@ -298,12 +332,13 @@ class TestLimitThinkingContentLengthV2(unittest.TestCase):
max_think_lens = paddle.to_tensor([10, 5, 8, -1, 6], dtype="int32")
step_idx = paddle.to_tensor([[3], [5], [4], [100], [9]], dtype="int64")
limit_think_status = paddle.to_tensor([0, 0, 0, 0, 2], dtype="int32")
stop_flags = paddle.to_tensor([False, False, False, False, False], dtype="bool")
think_end_id = 999
line_break_id = 888
# Run operator
limit_thinking_content_length_v2(
next_tokens, max_think_lens, step_idx, limit_think_status, think_end_id, line_break_id
next_tokens, max_think_lens, step_idx, limit_think_status, stop_flags, think_end_id, line_break_id
)
# Seq 0: step 3 < max 10, status 0, unchanged

View File

@@ -37,6 +37,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
accept_num = paddle.to_tensor([3, 2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5, 8], dtype="int32")
stop_flags = paddle.to_tensor([False, False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
@@ -47,6 +49,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
)
@@ -72,6 +76,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([4], dtype="int32")
seq_lens_decoder = paddle.to_tensor([12], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
@@ -82,6 +88,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
)
@@ -103,6 +111,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([3], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
@@ -113,6 +123,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
)
@@ -129,6 +141,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([3], dtype="int32")
seq_lens_decoder = paddle.to_tensor([100], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
@@ -139,6 +153,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
)
@@ -155,6 +171,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([0], dtype="int32") # No tokens accepted
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
@@ -165,6 +183,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
)
@@ -180,6 +200,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status = paddle.to_tensor([3], dtype="int32") # Terminal status
accept_num = paddle.to_tensor([2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
@@ -190,6 +212,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
)
@@ -205,6 +229,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([9], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
speculate_limit_thinking_content_length_v1(
@@ -214,6 +240,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
)
@@ -231,6 +259,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status = paddle.to_tensor([0, 0, 0], dtype="int32")
accept_num = paddle.to_tensor([3, 3, 2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([6, 8, 50], dtype="int32")
stop_flags = paddle.to_tensor([False, False, False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
# Run operator
@@ -241,6 +271,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
)
@@ -271,6 +303,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
accept_num = paddle.to_tensor([3, 2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5, 8], dtype="int32")
stop_flags = paddle.to_tensor([False, False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -282,6 +315,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -302,6 +336,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([5], dtype="int32")
seq_lens_decoder = paddle.to_tensor([12], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -313,6 +348,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -336,6 +372,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([1], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
speculate_limit_thinking_content_length_v2(
next_tokens,
@@ -344,6 +381,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -364,6 +402,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -384,6 +423,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -404,6 +444,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -418,6 +459,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([3], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -429,6 +471,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -444,6 +487,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([2], dtype="int32")
accept_num = paddle.to_tensor([1], dtype="int32")
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -455,6 +499,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -470,6 +515,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([100], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -481,6 +527,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -497,6 +544,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -508,6 +556,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)
@@ -524,6 +573,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status = paddle.to_tensor([3], dtype="int32")
accept_num = paddle.to_tensor([1], dtype="int32")
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -535,6 +585,7 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
)