From 83f97d1196904b4792c62bf862403750abcffa65 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Thu, 16 Oct 2025 13:23:16 +0800
Subject: [PATCH] support speculate_limit_thinking_content_length_v2 (#4428)

* support speculate_limit_thinking_content_length_v2

* fix

* fix import
---
 custom_ops/gpu_ops/cpp_extensions.cc              |  12 ++
 ...culate_limit_thinking_content_length_v2.cu     | 159 ++++++++++++++++++
 .../speculate_set_value_by_flags.cu               |   2 +-
 .../model_executor/layers/sample/sampler.py       |  25 ++-
 4 files changed, 193 insertions(+), 5 deletions(-)
 create mode 100644 custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v2.cu

diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc
index 272e6b2b1..f7228bcfe 100644
--- a/custom_ops/gpu_ops/cpp_extensions.cc
+++ b/custom_ops/gpu_ops/cpp_extensions.cc
@@ -678,6 +678,16 @@ void SpeculateVerify(
     const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
     int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode,
     bool accept_all_drafts);
+void SpeculateLimitThinkingContentLengthV2(
+    const paddle::Tensor& next_tokens,
+    const paddle::Tensor& max_think_lens,
+    const paddle::Tensor& step_idx,
+    const paddle::Tensor& limit_think_status,
+    const paddle::Tensor& accept_num,
+    const paddle::Tensor& seq_lens_decoder,
+    const int64_t think_end_id,
+    const int64_t line_break_id);
+
 void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
                      const paddle::Tensor &seq_lens_decoder,
                      const paddle::Tensor &not_need_stop,
@@ -1245,6 +1255,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
   m.def("speculate_verify",&SpeculateVerify, "speculate_verify function");

+  m.def("speculate_limit_thinking_content_length_v2",&SpeculateLimitThinkingContentLengthV2, "speculate limit thinking content length function");
+
   m.def("speculate_update",&SpeculateUpdate, "Speculate Update Kernel");

   m.def("speculate_set_value_by_flags_and_idx",&SpeculateSetValueByFlagsAndIdx,
         "speculate_set_value_by_flags_and_idx function");

diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v2.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v2.cu
new file mode 100644
index 000000000..e885cfb2a
--- /dev/null
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length_v2.cu
@@ -0,0 +1,159 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "helper.h"
+#include "paddle/extension.h"
+
+// status == 0: normal generation phase
+// status == 1: replacement phase (forcing the thinking section to end)
+// status == 2: replacement-finished phase
+// status == 3: thinking-finished phase
+__global__ void speculate_limit_thinking_content_length_kernel_v2(
+    int64_t* next_tokens,
+    const int* max_think_lens,
+    int64_t* step_idx,
+    int* limit_think_status,
+    int* accept_num,
+    int* seq_lens_decoder,
+    const int64_t think_end_id,
+    const int64_t line_break_id,
+    const int tokens_per_step,
+    const int bs) {
+  int bid = threadIdx.x;
+  if (bid >= bs) return;
+
+  const int original_accept_num = accept_num[bid];
+  if (original_accept_num <= 0) return;
+
+  // If thinking is not enabled for this sequence, return immediately.
+  // The default value is -1, meaning the thinking length is unlimited.
+  const int max_think_len = max_think_lens[bid];
+  if (max_think_len < 0) return;
+  int current_limit_think_status = limit_think_status[bid];
+  // If the sequence is already in the response phase and the stop flag has
+  // been triggered, return immediately; there is nothing left to do.
+  if (current_limit_think_status == 3) {
+    return;
+  }
+
+  int new_accept_num = original_accept_num;
+
+  const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
+
+  for (int token_offset = 0; token_offset < original_accept_num;
+       token_offset++) {
+    const int token_idx = bid * tokens_per_step + token_offset;
+    int64_t next_token = next_tokens[token_idx];
+    const int64_t current_step = current_base_step + token_offset;
+
+    bool condition_triggered = false;
+
+    // ======================= thinking-phase control =======================
+    // Stage 1: still thinking (status == 0), check whether we must force it to end.
+    // Stage 2: replacing (status == 1), check whether the replacement is finished.
+    if (current_limit_think_status <= 1) {
+      // The thinking-length limit is enabled, so check whether it has been reached.
+      if (current_step == max_think_len) {
+        // Force-replace the current token with a line break.
+        next_token = line_break_id;
+        current_limit_think_status = 1;
+        condition_triggered = true;  // The token was modified, so truncate here.
+      } else if (current_step == max_think_len + 1) {
+        // Force-replace the current token with the end-of-thinking token.
+        next_token = think_end_id;
+        current_limit_think_status = 1;
+        condition_triggered = true;  // The token was modified, so truncate here.
+      } else if (current_step == max_think_len + 2) {
+        // Force-replace the current token with a line break.
+        next_token = line_break_id;
+        current_limit_think_status = 1;
+        condition_triggered = true;  // The token was modified, so truncate here.
+      } else if (current_step == max_think_len + 3) {
+        // Force-replace the current token with a line break.
+        next_token = line_break_id;
+        // Advance the status to 2: the forced replacement is finished.
+        current_limit_think_status = 2;
+        condition_triggered = true;  // The token was modified, so truncate here.
+      }
+    }
+
+    // ======================= end-of-thinking handling =======================
+    // Stage 3: check whether the condition to end thinking is met (status == 0 || status == 2).
+    // Two cases are handled here:
+    // 1. status == 0: the model generated the end-of-thinking token on its own.
+    // 2. status == 2: the previous stage force-injected the replacement sequence
+    //    (line break, end-of-thinking token, line break, line break).
+    if (current_limit_think_status == 0) {
+      if (next_token == think_end_id) {
+        // Thinking has ended: advance the status to 3 (response phase).
+        current_limit_think_status = 3;
+      }
+    }
+    if (current_limit_think_status == 2) {
+      // Thinking has ended: advance the status to 3 (response phase).
+      current_limit_think_status = 3;
+    }
+
+    next_tokens[token_idx] = next_token;
+
+    if (condition_triggered) {
+      new_accept_num = token_offset + 1;
+      break;
+    }
+  }
+
+  // Update the global state.
+  int discarded_tokens = original_accept_num - new_accept_num;
+  if (discarded_tokens > 0) {
+    step_idx[bid] -= discarded_tokens;
+    seq_lens_decoder[bid] -= discarded_tokens;
+  }
+
+  accept_num[bid] = new_accept_num;
+  limit_think_status[bid] = current_limit_think_status;
+}
+
+void SpeculateLimitThinkingContentLengthV2(
+    const paddle::Tensor& next_tokens,
+    const paddle::Tensor& max_think_lens,
+    const paddle::Tensor& step_idx,
+    const paddle::Tensor& limit_think_status,
+    const paddle::Tensor& accept_num,
+    const paddle::Tensor& seq_lens_decoder,
+    const int64_t think_end_id,
+    const int64_t line_break_id) {
+  const int batch_size = next_tokens.shape()[0];
+  const int tokens_per_step = next_tokens.shape()[1];
+
+  speculate_limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
+      const_cast<int64_t*>(next_tokens.data<int64_t>()),
+      max_think_lens.data<int>(),
+      const_cast<int64_t*>(step_idx.data<int64_t>()),
+      const_cast<int*>(limit_think_status.data<int>()),
+      const_cast<int*>(accept_num.data<int>()),
+      const_cast<int*>(seq_lens_decoder.data<int>()),
+      think_end_id,
+      line_break_id,
+      tokens_per_step,
+      batch_size);
+}
+
+PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2)
+    .Inputs({"next_tokens",
+             "max_think_lens",
+             "step_idx",
+             "limit_think_status",
+             "accept_num",
+             "seq_lens_decoder"})
+    .Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
+    .Outputs({"next_tokens_out"})
+    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLengthV2));
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu
index dae1d40fc..316604c73 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_value_by_flags.cu
@@ -37,7 +37,7 @@ __global__ void speculate_set_value_by_flag_and_id(int64_t *pre_ids_all,
     const int seq_len_dec = seq_lens_decoder[tid];
     const int seq_len_enc = seq_lens_encoder[tid];
     if (seq_len_dec == 0 && seq_len_enc == 0) return;  // stoped
-    if (step_idx[tid] >= 0) {
+    if (step_idx[tid] > 0) {
       for (int i = 0; i < accept_num[tid]; i++) {
         pre_ids_all_now[step_idx[tid] - i] =
             accept_tokens_now[accept_num[tid] - 1 - i];
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 21c40295a..6a8db178f 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -39,6 +39,13 @@ from fastdeploy.model_executor.layers.sample.ops import (
 from fastdeploy.platforms import current_platform
 from fastdeploy.worker.output import LogprobsTensors, SamplerOutput
+if current_platform.is_cuda():
+    from fastdeploy.model_executor.ops.gpu import (
+        speculate_limit_thinking_content_length_v2,
+        speculate_verify,
+        top_p_candidates,
+    )
+
 def top_p_normalize_probs_paddle(
     probs: paddle.Tensor,
@@ -396,11 +403,9 @@ class SpeculativeSampler(nn.Layer):
         max_model_len: int,
         share_inputs: List[paddle.Tensor],
         accept_all_drafts: bool = False,
+        think_end_id: int = -1,
+        line_break_id: int = -1,
     ) -> paddle.Tensor:
-        """ """
-
-        from fastdeploy.model_executor.ops.gpu import speculate_verify, top_p_candidates
-
         logits = apply_speculative_penalty_multi_scores(
             sampling_metadata.pre_token_ids,
             logits,
@@ -455,6 +460,18 @@ class SpeculativeSampler(nn.Layer):
             accept_all_drafts,
         )

+        if think_end_id > 0 and line_break_id > 0:
+            speculate_limit_thinking_content_length_v2(
+                share_inputs["accept_tokens"],
+                share_inputs["max_think_lens"],
+                share_inputs["step_idx"],
+                share_inputs["limit_think_status"],
+                share_inputs["accept_num"],
+                share_inputs["seq_lens_decoder"],
+                think_end_id,
+                line_break_id,
+            )
+
         return None
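Note for reviewers: below is a minimal CPU-side sketch of the per-sequence state machine that speculate_limit_thinking_content_length_kernel_v2 implements, for readers who prefer to follow the control flow outside CUDA. It is illustrative only and not part of the patch; the concrete token ids THINK_END_ID and LINE_BREAK_ID and the plain Python lists standing in for the share_inputs tensors are assumptions.

THINK_END_ID = 151668   # assumed id of the end-of-thinking token (hypothetical value)
LINE_BREAK_ID = 198     # assumed id of "\n" (hypothetical value)

def limit_thinking_v2(accept_tokens, accept_num, step_idx, seq_len_decoder,
                      max_think_len, status):
    """Mirror of the kernel logic for one sequence.

    status: 0 normal, 1 replacing, 2 replacement done, 3 thinking finished.
    Returns updated (accept_num, step_idx, seq_len_decoder, status).
    """
    if accept_num <= 0 or max_think_len < 0 or status == 3:
        return accept_num, step_idx, seq_len_decoder, status

    base_step = step_idx - accept_num + 1
    new_accept_num = accept_num
    for offset in range(accept_num):
        step = base_step + offset
        token = accept_tokens[offset]
        truncated = False
        if status <= 1:
            # Force the "\n</think>\n\n" sequence once the thinking budget is reached.
            if step == max_think_len:
                token, status, truncated = LINE_BREAK_ID, 1, True
            elif step == max_think_len + 1:
                token, status, truncated = THINK_END_ID, 1, True
            elif step == max_think_len + 2:
                token, status, truncated = LINE_BREAK_ID, 1, True
            elif step == max_think_len + 3:
                token, status, truncated = LINE_BREAK_ID, 2, True
        if status == 0 and token == THINK_END_ID:
            status = 3          # the model closed the thinking section itself
        elif status == 2:
            status = 3          # the forced sequence has been fully injected
        accept_tokens[offset] = token
        if truncated:           # a draft token was overwritten: drop the rest
            new_accept_num = offset + 1
            break

    discarded = accept_num - new_accept_num
    return new_accept_num, step_idx - discarded, seq_len_decoder - discarded, status

With max_think_len = 8, for example, the first verify round that reaches step 8 keeps only the forced line break and discards the remaining accepted drafts (rolling back step_idx and seq_lens_decoder accordingly); the following rounds then inject the end-of-thinking token and two more line breaks one step at a time before normal decoding resumes.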