Optimize the performance of the think-length limit using custom operators (#4279)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

* delete impl

* delete min_length&max_length

* support limit thinking content strategy

* fix

* fix

* fix

* update

* fix set_value_by_flags_and_idx

* fix

* fix

* fix

* fix

* update

* fix

* fix

* fix typo

* fix ci

* fix

* fix

* support mtp

* fix

* fix

* update

* update
This commit is contained in:
Yuanle Liu
2025-10-20 21:09:13 +08:00
committed by GitHub
parent 36af88ff3f
commit cef3164c3b
31 changed files with 747 additions and 1032 deletions

View File

@@ -0,0 +1,132 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
// Enforces the thinking-length limit for speculative decoding (v1).
//
// limit_think_status[bid] state machine:
//   0 : still thinking, limit not reached yet
//   1 : budget exhausted, think_end_id has been force-injected
//   2 : thinking finished (response phase); nothing left to do
//
// One thread handles one sequence. For each accepted draft token, if the
// per-sequence think budget (max_think_lens) is exhausted the token is
// replaced with think_end_id and the remaining accepted tokens of this step
// are dropped: accept_num is truncated and step_idx / seq_lens_decoder are
// rolled back accordingly.
__global__ void speculate_limit_thinking_content_length_kernel_v1(
    int64_t* next_tokens,       // [bs, tokens_per_step] accepted tokens (inout)
    const int* max_think_lens,  // [bs] think budget; negative disables limit
    int64_t* step_idx,          // [bs] last accepted step index (inout)
    int* limit_think_status,    // [bs] per-sequence state, see above (inout)
    int* accept_num,            // [bs] accepted-token count this step (inout)
    int* seq_lens_decoder,      // [bs] decoder sequence length (inout)
    const int64_t think_end_id,
    const int tokens_per_step,
    const int bs) {
  // Use the global thread id so the kernel stays correct for multi-block
  // launches; with the current <<<1, 1024>>> launch blockIdx.x is always 0,
  // so this is behaviorally identical to the previous threadIdx.x-only form.
  const int bid = blockIdx.x * blockDim.x + threadIdx.x;
  if (bid >= bs) return;
  const int original_accept_num = accept_num[bid];
  if (original_accept_num <= 0) return;
  // A negative budget (-1 by convention) means thinking is not limited for
  // this sequence.
  const int max_think_len = max_think_lens[bid];
  if (max_think_len < 0) return;
  int current_limit_think_status = limit_think_status[bid];
  // Response phase already reached: the loop below would be a no-op.
  // (The original tested `== 3`, a value this kernel never produces, so
  // sequences that had finished thinking kept re-scanning their tokens every
  // step; `>= 2` is observably identical and skips that wasted work.)
  if (current_limit_think_status >= 2) {
    return;
  }
  int new_accept_num = original_accept_num;
  // Step index of the first token accepted in this speculative step.
  const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
  for (int token_offset = 0; token_offset < original_accept_num;
       token_offset++) {
    const int token_idx = bid * tokens_per_step + token_offset;
    int64_t next_token = next_tokens[token_idx];
    const int64_t current_step = current_base_step + token_offset;
    bool condition_triggered = false;
    // ===================== thinking-phase control =====================
    // status == 0: still thinking — check whether the budget is exhausted.
    if (current_limit_think_status < 1) {
      if (current_step >= max_think_len) {
        // Force-replace the current token with the end-of-thinking token.
        next_token = think_end_id;
        current_limit_think_status = 1;
        condition_triggered = true;  // token was modified: truncate here
      }
    }
    // ===================== end-of-thinking handling ===================
    // Covers two situations:
    //   1. status == 0: the model emitted think_end_id on its own
    //   2. status == 1: think_end_id was force-injected above
    if (current_limit_think_status < 2) {
      if (next_token == think_end_id) {
        // Thinking confirmed over: advance to the response phase.
        current_limit_think_status = 2;
      }
    }
    next_tokens[token_idx] = next_token;
    if (condition_triggered) {
      new_accept_num = token_offset + 1;
      break;
    }
  }
  // Publish the (possibly truncated) results for this sequence.
  const int discarded_tokens = original_accept_num - new_accept_num;
  if (discarded_tokens > 0) {
    step_idx[bid] -= discarded_tokens;
    seq_lens_decoder[bid] -= discarded_tokens;
  }
  accept_num[bid] = new_accept_num;
  limit_think_status[bid] = current_limit_think_status;
}
// Host-side launcher for the v1 thinking-length-limit kernel.
//
// next_tokens is [batch_size, tokens_per_step]; the remaining tensors are
// per-sequence [batch_size] buffers that the kernel updates in place.
void SpeculateLimitThinkingContentLengthV1(
    const paddle::Tensor& next_tokens,
    const paddle::Tensor& max_think_lens,
    const paddle::Tensor& step_idx,
    const paddle::Tensor& limit_think_status,
    const paddle::Tensor& accept_num,
    const paddle::Tensor& seq_lens_decoder,
    const int64_t think_end_id) {
  const int batch_size = next_tokens.shape()[0];
  const int tokens_per_step = next_tokens.shape()[1];
  // The kernel maps one thread to one sequence and is launched with a single
  // block of 1024 threads; larger batches would previously be silently
  // ignored. Fail loudly instead.
  PD_CHECK(batch_size <= 1024,
           "speculate_limit_thinking_content_length_v1 supports at most 1024 "
           "sequences per batch, got %d.",
           batch_size);
  speculate_limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
      const_cast<int64_t*>(next_tokens.data<int64_t>()),
      max_think_lens.data<int>(),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      const_cast<int*>(limit_think_status.data<int>()),
      const_cast<int*>(accept_num.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      think_end_id,
      tokens_per_step,
      batch_size);
}
// Register the custom operator. `next_tokens` is modified in place and is
// surfaced to the framework as `next_tokens_out`.
// NOTE(review): the kernel also mutates step_idx, limit_think_status,
// accept_num and seq_lens_decoder, yet only next_tokens appears in
// SetInplaceMap — presumably the framework tolerates in-place writes to plain
// inputs here; confirm against Paddle's custom-op aliasing rules.
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1)
.Inputs({"next_tokens",
"max_think_lens",
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder"})
.Attrs({"think_end_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
.SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLengthV1));

View File

@@ -0,0 +1,159 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/extension.h"
// status == 0: 正常生成阶段
// status == 1: 替换阶段
// status == 2: 替换结束阶段
// status == 3: 思考结束阶段
// Enforces the thinking-length limit for speculative decoding (v2).
//
// limit_think_status[bid] state machine:
//   0 : normal generation, still thinking
//   1 : replacement phase — force-injecting the end-of-thinking sequence
//   2 : replacement finished (transient; promoted to 3 in the same pass)
//   3 : thinking finished (response phase)
//
// Unlike v1, which injects a single think_end_id, v2 injects a four-token
// sequence (line_break_id, think_end_id, line_break_id, line_break_id —
// i.e. "\n</think>\n\n") over four consecutive steps once the budget is
// reached. One thread handles one sequence; each forced replacement
// truncates the remaining accepted tokens of the step and rolls
// step_idx / seq_lens_decoder back accordingly.
__global__ void speculate_limit_thinking_content_length_kernel_v2(
    int64_t* next_tokens,       // [bs, tokens_per_step] accepted tokens (inout)
    const int* max_think_lens,  // [bs] think budget; negative disables limit
    int64_t* step_idx,          // [bs] last accepted step index (inout)
    int* limit_think_status,    // [bs] per-sequence state, see above (inout)
    int* accept_num,            // [bs] accepted-token count this step (inout)
    int* seq_lens_decoder,      // [bs] decoder sequence length (inout)
    const int64_t think_end_id,
    const int64_t line_break_id,
    const int tokens_per_step,
    const int bs) {
  // Use the global thread id so the kernel stays correct for multi-block
  // launches; with the current <<<1, 1024>>> launch blockIdx.x is always 0,
  // so this is behaviorally identical to the previous threadIdx.x-only form.
  const int bid = blockIdx.x * blockDim.x + threadIdx.x;
  if (bid >= bs) return;
  const int original_accept_num = accept_num[bid];
  if (original_accept_num <= 0) return;
  // A negative budget (-1 by convention) means thinking is not limited for
  // this sequence.
  const int max_think_len = max_think_lens[bid];
  if (max_think_len < 0) return;
  int current_limit_think_status = limit_think_status[bid];
  // Already in the response phase: nothing left to do.
  if (current_limit_think_status == 3) {
    return;
  }
  int new_accept_num = original_accept_num;
  // Step index of the first token accepted in this speculative step.
  const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
  for (int token_offset = 0; token_offset < original_accept_num;
       token_offset++) {
    const int token_idx = bid * tokens_per_step + token_offset;
    int64_t next_token = next_tokens[token_idx];
    const int64_t current_step = current_base_step + token_offset;
    bool condition_triggered = false;
    // ===================== thinking-phase control =====================
    // status == 0: still thinking — check whether the budget is reached.
    // status == 1: mid-replacement — keep injecting the forced sequence.
    // Each forced token truncates the step, so the four steps below are
    // visited one at a time on consecutive decode steps.
    if (current_limit_think_status <= 1) {
      if (current_step == max_think_len) {
        // First forced token: line break.
        next_token = line_break_id;
        current_limit_think_status = 1;
        condition_triggered = true;  // token was modified: truncate here
      } else if (current_step == max_think_len + 1) {
        // Second forced token: end-of-thinking marker.
        next_token = think_end_id;
        current_limit_think_status = 1;
        condition_triggered = true;
      } else if (current_step == max_think_len + 2) {
        // Third forced token: line break.
        next_token = line_break_id;
        current_limit_think_status = 1;
        condition_triggered = true;
      } else if (current_step == max_think_len + 3) {
        // Final forced token: line break. Mark the replacement finished
        // (status 2; promoted to 3 right below). The original comment here
        // claimed the status advanced to 1, which did not match the code.
        next_token = line_break_id;
        current_limit_think_status = 2;
        condition_triggered = true;
      }
    }
    // ===================== end-of-thinking handling ===================
    // Two ways to finish thinking:
    //   1. status == 0: the model emitted think_end_id on its own
    //   2. status == 2: the forced "\n</think>\n\n" sequence is complete
    if (current_limit_think_status == 0) {
      if (next_token == think_end_id) {
        // Thinking confirmed over: advance to the response phase.
        current_limit_think_status = 3;
      }
    }
    if (current_limit_think_status == 2) {
      // Replacement complete: advance to the response phase.
      current_limit_think_status = 3;
    }
    next_tokens[token_idx] = next_token;
    if (condition_triggered) {
      new_accept_num = token_offset + 1;
      break;
    }
  }
  // Publish the (possibly truncated) results for this sequence.
  const int discarded_tokens = original_accept_num - new_accept_num;
  if (discarded_tokens > 0) {
    step_idx[bid] -= discarded_tokens;
    seq_lens_decoder[bid] -= discarded_tokens;
  }
  accept_num[bid] = new_accept_num;
  limit_think_status[bid] = current_limit_think_status;
}
// Host-side launcher for the v2 thinking-length-limit kernel.
//
// next_tokens is [batch_size, tokens_per_step]; the remaining tensors are
// per-sequence [batch_size] buffers that the kernel updates in place.
void SpeculateLimitThinkingContentLengthV2(
    const paddle::Tensor& next_tokens,
    const paddle::Tensor& max_think_lens,
    const paddle::Tensor& step_idx,
    const paddle::Tensor& limit_think_status,
    const paddle::Tensor& accept_num,
    const paddle::Tensor& seq_lens_decoder,
    const int64_t think_end_id,
    const int64_t line_break_id) {
  const int batch_size = next_tokens.shape()[0];
  const int tokens_per_step = next_tokens.shape()[1];
  // The kernel maps one thread to one sequence and is launched with a single
  // block of 1024 threads; larger batches would previously be silently
  // ignored. Fail loudly instead.
  PD_CHECK(batch_size <= 1024,
           "speculate_limit_thinking_content_length_v2 supports at most 1024 "
           "sequences per batch, got %d.",
           batch_size);
  speculate_limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
      const_cast<int64_t*>(next_tokens.data<int64_t>()),
      max_think_lens.data<int>(),
      const_cast<int64_t*>(step_idx.data<int64_t>()),
      const_cast<int*>(limit_think_status.data<int>()),
      const_cast<int*>(accept_num.data<int>()),
      const_cast<int*>(seq_lens_decoder.data<int>()),
      think_end_id,
      line_break_id,
      tokens_per_step,
      batch_size);
}
// Register the custom operator. `next_tokens` is modified in place and is
// surfaced to the framework as `next_tokens_out`.
// NOTE(review): the kernel also mutates step_idx, limit_think_status,
// accept_num and seq_lens_decoder, yet only next_tokens appears in
// SetInplaceMap — presumably the framework tolerates in-place writes to plain
// inputs here; confirm against Paddle's custom-op aliasing rules.
PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2)
.Inputs({"next_tokens",
"max_think_lens",
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
.SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLengthV2));

View File

@@ -38,7 +38,7 @@ __global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
const int seq_len_dec = seq_lens_decoder[tid];
const int seq_len_enc = seq_lens_encoder[tid];
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped
if (step_idx[tid] >= 0) {
if (step_idx[tid] > 0) {
for (int i = 0; i < accept_num[tid]; i++) {
pre_ids_all_now[step_idx[tid] - i] =
accept_tokens_now[accept_num[tid] - 1 - i];