[XPU] xpu support think length limit (#4539)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* [XPU] xpu support think length limit

* [XPU] xpu c++ code files format

---------

Co-authored-by: ddchenhao66 <dhaochen@163.com>
This commit is contained in:
ddchenhao66
2025-10-23 15:58:11 +08:00
committed by GitHub
parent 2676a918f0
commit 5443b2cffb
8 changed files with 538 additions and 0 deletions

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// Enforce an upper bound on the thinking-phase length (v1 strategy):
// once a sequence's decode step reaches its max_think_len budget, the
// plugin kernel overwrites the sampled token with `think_end_id`.
// next_tokens and limit_think_status are modified in place.
void LimitThinkingContentLengthV1(const paddle::Tensor& next_tokens,
                                  const paddle::Tensor& max_think_lens,
                                  const paddle::Tensor& step_idx,
                                  const paddle::Tensor& limit_think_status,
                                  const int64_t think_end_id) {
  // Resolve the device context for the currently selected XPU.
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto* pool_ctx =
      paddle::experimental::DeviceContextPool::Instance().Get(place);
  const auto* xpu_dev_ctx = static_cast<const phi::XPUContext*>(pool_ctx);

  const int bs = next_tokens.shape()[0];

  // The const_casts follow the standard Paddle inplace-op idiom: the op
  // declares next_tokens -> next_tokens_out as an inplace pair below.
  const int ret =
      baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1(
          xpu_dev_ctx->x_context(),
          const_cast<int64_t*>(next_tokens.data<int64_t>()),
          max_think_lens.data<int>(),
          step_idx.data<int64_t>(),
          const_cast<int*>(limit_think_status.data<int>()),
          think_end_id,
          bs);
  PD_CHECK(ret == 0,
           "baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1 "
           "failed.");
}

PD_BUILD_STATIC_OP(limit_thinking_content_length_v1)
    .Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
    .Attrs({"think_end_id: int64_t"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(LimitThinkingContentLengthV1));

View File

@@ -0,0 +1,54 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/phi/backends/xpu/xpu_context.h>
#include "paddle/extension.h"
#include "paddle/phi/core/enforce.h"
#include "xpu/plugin.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
// Enforce an upper bound on the thinking-phase length (v2 strategy):
// when a sequence's decode step reaches its max_think_len budget, the
// plugin kernel force-injects a "\n</think>\n\n" token sequence using
// `line_break_id` and `think_end_id`. next_tokens and limit_think_status
// are modified in place.
void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
                                  const paddle::Tensor& max_think_lens,
                                  const paddle::Tensor& step_idx,
                                  const paddle::Tensor& limit_think_status,
                                  const int64_t think_end_id,
                                  const int64_t line_break_id) {
  // Resolve the device context for the currently selected XPU.
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto* pool_ctx =
      paddle::experimental::DeviceContextPool::Instance().Get(place);
  const auto* xpu_dev_ctx = static_cast<const phi::XPUContext*>(pool_ctx);

  const int bs = next_tokens.shape()[0];

  // The const_casts follow the standard Paddle inplace-op idiom: the op
  // declares next_tokens -> next_tokens_out as an inplace pair below.
  const int ret =
      baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v2(
          xpu_dev_ctx->x_context(),
          const_cast<int64_t*>(next_tokens.data<int64_t>()),
          max_think_lens.data<int>(),
          step_idx.data<int64_t>(),
          const_cast<int*>(limit_think_status.data<int>()),
          think_end_id,
          line_break_id,
          bs);
  PD_CHECK(ret == 0,
           "baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v2 "
           "failed.");
}

PD_BUILD_STATIC_OP(limit_thinking_content_length_v2)
    .Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
    .Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
    .Outputs({"next_tokens_out"})
    .SetInplaceMap({{"next_tokens", "next_tokens_out"}})
    .SetKernelFn(PD_KERNEL(LimitThinkingContentLengthV2));

View File

@@ -215,6 +215,25 @@ DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
int64_t hidden_size,
bool is_scatter);
// Caps the thinking-content length, v1 strategy: once step_idx reaches
// max_think_lens[i], the sampled token is replaced with think_end_id.
// next_tokens and limit_think_status are updated in place; a negative
// max_think_lens[i] disables the limit for that sequence. Returns 0 on
// success.
DLL_EXPORT int limit_thinking_content_length_kernel_v1(
    api::Context* ctx,
    int64_t* next_tokens,
    const int* max_think_lens,
    const int64_t* step_idx,
    int* limit_think_status,
    const int64_t think_end_id,
    const int bs);
// Caps the thinking-content length, v2 strategy: when the budget is
// exhausted, a "\n</think>\n\n" token sequence (built from line_break_id
// and think_end_id) is force-injected over consecutive steps. In-place on
// next_tokens / limit_think_status; negative max_think_lens[i] disables
// the limit. Returns 0 on success.
DLL_EXPORT int limit_thinking_content_length_kernel_v2(
    api::Context* ctx,
    int64_t* next_tokens,
    const int* max_think_lens,
    const int64_t* step_idx,
    int* limit_think_status,
    const int64_t think_end_id,
    const int64_t line_break_id,
    const int bs);
/*--------------------------------------- MTP begin
 * --------------------------------------------*/

View File

@@ -0,0 +1,73 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
// #include <stdio.h>
// using namespace std;
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
// XPU device kernel (v1 strategy): force-terminate the "thinking" phase
// once a sequence has produced max_think_len tokens, by overwriting the
// sampled token with `think_end_id`.
//
// In-place on next_tokens and limit_think_status. Status codes:
//   0 = still thinking
//   1 = think_end_id was force-injected this step
//   2 = thinking finished (response phase, terminal)
__global__ void limit_thinking_content_length_kernel_v1(
    int64_t* next_tokens,        // [bs] sampled token per sequence (in/out)
    const int* max_think_lens,   // [bs] thinking budget; < 0 disables limit
    const int64_t* step_idx,     // [bs] current decode step per sequence
    int* limit_think_status,     // [bs] per-sequence status (in/out)
    const int64_t think_end_id,  // token id that closes the thinking block
    const int bs) {
  int cid = core_id();
  int ncores = core_num();
  int clusterid = cluster_id();
  int nclusters = cluster_num();  // NOTE(review): unused
  // Only the first cluster works; batch elements are striped across its
  // cores via the i += ncores loop below.
  if (clusterid != 0) return;
  for (int i = cid; i < bs; i += ncores) {
    int max_think_len_lm;
    int limit_think_status_lm;
    int64_t next_token_lm;
    int64_t step_idx_lm;
    // Stage per-sequence values into local memory. The final synchronous
    // GM2LM presumably also completes the preceding async copies — confirm
    // against XTDK memory-transfer semantics.
    GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
    GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
    GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
    GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
    // Thinking disabled for this sequence (default -1 = no limit).
    if (max_think_len_lm < 0) continue;
    // Already in the response phase with the stop flag set; nothing to do.
    if (limit_think_status_lm == 2) continue;
    // ===================== thinking-phase control =====================
    // Stage 1: still thinking (status == 0) — check whether the budget
    // is exhausted and the thinking block must be force-closed.
    if (limit_think_status_lm < 1) {
      if (step_idx_lm >= max_think_len_lm) {
        // Replace the current token with the end-of-thinking token.
        next_token_lm = think_end_id;
        // Advance to 1: "closing the thinking block".
        limit_think_status_lm = 1;
      }
    }
    // ===================== end-of-thinking handling ===================
    // Stage 2 (status < 2) covers both scenarios:
    //   1. status == 0: the model emitted think_end_id by itself
    //   2. status == 1: think_end_id was force-injected above
    if (limit_think_status_lm < 2) {
      if (next_token_lm == think_end_id) {
        // Thinking confirmed over; advance to 2 (response phase).
        limit_think_status_lm = 2;
      }
    }
    // Write the (possibly replaced) token back.
    LM2GM_ASYNC(&next_token_lm, next_tokens + i, sizeof(int64_t));
    // Publish the updated status (synchronous copy).
    LM2GM(&limit_think_status_lm, limit_think_status + i, sizeof(int));
  }
}
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,90 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
// #include <stdio.h>
// using namespace std;
#include "xpu/kernel/xtdk.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu3 {
namespace plugin {
// XPU device kernel (v2 strategy): when a sequence's thinking budget is
// exhausted, force-inject the token sequence "\n</think>\n\n" over the
// next four decode steps, then switch the sequence to the response phase.
//
// In-place on next_tokens and limit_think_status. Status codes:
//   0 = still thinking
//   1 = injecting the forced "\n</think>\n" prefix
//   2 = final "\n" injected this step, thinking over
//   3 = response phase (terminal)
__global__ void limit_thinking_content_length_kernel_v2(
    int64_t* next_tokens,         // [bs] sampled token per sequence (in/out)
    const int* max_think_lens,    // [bs] thinking budget; < 0 disables limit
    const int64_t* step_idx,      // [bs] current decode step per sequence
    int* limit_think_status,      // [bs] per-sequence status (in/out)
    const int64_t think_end_id,   // token id of "</think>"
    const int64_t line_break_id,  // token id of "\n"
    const int bs) {
  int cid = core_id();
  int ncores = core_num();
  int clusterid = cluster_id();
  // Only the first cluster works; batch elements are striped across its
  // cores via the i += ncores loop below.
  if (clusterid != 0) return;
  for (int i = cid; i < bs; i += ncores) {
    int max_think_len_lm;
    int limit_think_status_lm;
    int64_t next_token_lm;
    int64_t step_idx_lm;
    // Stage per-sequence values into local memory; the final synchronous
    // GM2LM completes the transfer group.
    GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
    GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
    GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
    GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
    // Thinking disabled for this sequence (default -1 = no limit).
    if (max_think_len_lm < 0) continue;
    // Already in the response phase; nothing to do.
    if (limit_think_status_lm == 3) continue;
    // ===================== thinking-phase control =====================
    // status 0: still thinking — check whether the budget just ran out.
    // status 1: mid-injection — keep replacing tokens step by step.
    if (limit_think_status_lm <= 1) {
      if (step_idx_lm == max_think_len_lm) {
        next_token_lm = line_break_id;  // inject "\n"
        limit_think_status_lm = 1;
      } else if (step_idx_lm == max_think_len_lm + 1) {
        next_token_lm = think_end_id;  // inject "</think>"
        limit_think_status_lm = 1;
      } else if (step_idx_lm == max_think_len_lm + 2) {
        next_token_lm = line_break_id;  // inject "\n"
        limit_think_status_lm = 1;
      } else if (step_idx_lm == max_think_len_lm + 3) {
        // BUGFIX: this branch previously re-tested `max_think_len_lm + 2`,
        // duplicating the branch above and making this one unreachable —
        // the trailing "\n" of "\n</think>\n\n" was never injected and the
        // status could never advance past 1.
        next_token_lm = line_break_id;  // inject trailing "\n"
        limit_think_status_lm = 2;
      }
    }
    // ===================== end-of-thinking handling ===================
    // Two ways the thinking block can end:
    //   1. status == 0: the model emitted "</think>" by itself
    //   2. status == 2: the forced "\n</think>\n\n" injection completed
    if (limit_think_status_lm == 0) {
      if (next_token_lm == think_end_id) {
        // Thinking confirmed over; advance to 3 (response phase).
        limit_think_status_lm = 3;
      }
    }
    if (limit_think_status_lm == 2) {
      // Injection finished; advance to 3 (response phase).
      limit_think_status_lm = 3;
    }
    // Write the (possibly replaced) token back.
    LM2GM_ASYNC(&next_token_lm, next_tokens + i, sizeof(int64_t));
    // Publish the updated status (synchronous copy).
    LM2GM(&limit_think_status_lm, limit_think_status + i, sizeof(int));
  }
}
} // namespace plugin
} // namespace xpu3

View File

@@ -0,0 +1,95 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
// Forward declaration of the device kernel defined in the .xpu source;
// launched by xpu3_wrapper below.
__attribute__((global)) void limit_thinking_content_length_kernel_v1(
    int64_t* next_tokens,
    const int* max_think_lens,
    const int64_t* step_idx,
    int* limit_think_status,
    const int64_t think_end_id,
    const int bs);
}  // namespace plugin
}  // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
// Launches the v1 thinking-length-limit device kernel on the stream held
// by `ctx`. Always reports success; kernel launch itself is fire-and-forget.
static int xpu3_wrapper(Context* ctx,
                        int64_t* next_tokens,
                        const int* max_think_lens,
                        const int64_t* step_idx,
                        int* limit_think_status,
                        const int64_t think_end_id,
                        const int bs) {
  // Map int64_t to the device-side index representation expected by the
  // kernel's pointer parameters.
  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
  auto limit_thinking_content_length_kernel_v1 =
      xpu3::plugin::limit_thinking_content_length_kernel_v1;
  // Launch config <1, 64, stream>: presumably one cluster with 64 cores —
  // matches the kernel's `if (clusterid != 0) return;` guard.
  limit_thinking_content_length_kernel_v1<<<1, 64, ctx->xpu_stream>>>(
      reinterpret_cast<XPU_INT64*>(next_tokens),
      max_think_lens,
      reinterpret_cast<const XPU_INT64*>(step_idx),
      limit_think_status,
      think_end_id,
      bs);
  return api::SUCCESS;
}
// Public plugin entry point for the v1 thinking-length limiter.
// Validates the context, records parameters for API tracing/dumping, and
// dispatches to the XPU2/XPU3 launcher. Returns api::SUCCESS (0) on
// success; other device types are unimplemented (no CPU fallback).
int limit_thinking_content_length_kernel_v1(Context* ctx,
                                            int64_t* next_tokens,
                                            const int* max_think_lens,
                                            const int64_t* step_idx,
                                            int* limit_think_status,
                                            const int64_t think_end_id,
                                            const int bs) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "limit_thinking_content_length_kernel_v1", int);
  WRAPPER_DUMP_PARAM5(ctx,
                      next_tokens,
                      max_think_lens,
                      step_idx,
                      limit_think_status,
                      think_end_id);
  WRAPPER_DUMP_PARAM1(ctx, bs);
  WRAPPER_DUMP(ctx);
  // No CPU reference implementation is provided.
  if (ctx->dev().type() == api::kCPU) {
    assert(false);
  }
  if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        next_tokens,
                        max_think_lens,
                        step_idx,
                        limit_think_status,
                        think_end_id,
                        bs);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -0,0 +1,99 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <numeric>
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu3 {
namespace plugin {
// Forward declaration of the device kernel defined in the .xpu source;
// launched by xpu3_wrapper below.
__attribute__((global)) void limit_thinking_content_length_kernel_v2(
    int64_t* next_tokens,
    const int* max_think_lens,
    const int64_t* step_idx,
    int* limit_think_status,
    const int64_t think_end_id,
    const int64_t line_break_id,
    const int bs);
}  // namespace plugin
}  // namespace xpu3
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
// Launches the v2 thinking-length-limit device kernel on the stream held
// by `ctx`. Always reports success; kernel launch itself is fire-and-forget.
static int xpu3_wrapper(Context* ctx,
                        int64_t* next_tokens,
                        const int* max_think_lens,
                        const int64_t* step_idx,
                        int* limit_think_status,
                        const int64_t think_end_id,
                        const int64_t line_break_id,
                        const int bs) {
  // Map int64_t to the device-side index representation expected by the
  // kernel's pointer parameters.
  using XPU_INT64 = typename XPUIndexType<int64_t>::type;
  auto limit_thinking_content_length_kernel_v2 =
      xpu3::plugin::limit_thinking_content_length_kernel_v2;
  // Launch config <1, 64, stream>: presumably one cluster with 64 cores —
  // matches the kernel's `if (clusterid != 0) return;` guard.
  limit_thinking_content_length_kernel_v2<<<1, 64, ctx->xpu_stream>>>(
      reinterpret_cast<XPU_INT64*>(next_tokens),
      max_think_lens,
      reinterpret_cast<const XPU_INT64*>(step_idx),
      limit_think_status,
      think_end_id,
      line_break_id,
      bs);
  return api::SUCCESS;
}
// Public plugin entry point for the v2 thinking-length limiter.
// Validates the context, records parameters for API tracing/dumping, and
// dispatches to the XPU2/XPU3 launcher. Returns api::SUCCESS (0) on
// success; other device types are unimplemented (no CPU fallback).
int limit_thinking_content_length_kernel_v2(Context* ctx,
                                            int64_t* next_tokens,
                                            const int* max_think_lens,
                                            const int64_t* step_idx,
                                            int* limit_think_status,
                                            const int64_t think_end_id,
                                            const int64_t line_break_id,
                                            const int bs) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "limit_thinking_content_length_kernel_v2", int);
  WRAPPER_DUMP_PARAM5(ctx,
                      next_tokens,
                      max_think_lens,
                      step_idx,
                      limit_think_status,
                      think_end_id);
  WRAPPER_DUMP_PARAM2(ctx, line_break_id, bs);
  WRAPPER_DUMP(ctx);
  // No CPU reference implementation is provided.
  if (ctx->dev().type() == api::kCPU) {
    assert(false);
  }
  if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
    return xpu3_wrapper(ctx,
                        next_tokens,
                        max_think_lens,
                        step_idx,
                        limit_think_status,
                        think_end_id,
                        line_break_id,
                        bs);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu

View File

@@ -45,6 +45,8 @@ from fastdeploy.model_executor.ops.xpu import (
adjust_batch,
get_infer_param,
get_padding_offset,
limit_thinking_content_length_v1,
limit_thinking_content_length_v2,
recover_decode_task,
set_data_ipc,
share_external_data,
@@ -185,6 +187,8 @@ def xpu_post_process(
share_inputs: Dict[str, paddle.Tensor],
block_size: int = 64,
skip_save_output: bool = False,
think_end_id: int = None,
line_break_id: int = None,
) -> None:
""" """
from fastdeploy.model_executor.ops.xpu import (
@@ -193,6 +197,34 @@ def xpu_post_process(
update_inputs,
)
if think_end_id > 0:
limit_strategy = envs.FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR
max_think_lens = share_inputs["max_think_lens"]
step_idx = share_inputs["step_idx"]
limit_think_status = share_inputs["limit_think_status"]
if limit_strategy == "</think>":
# for ernie4_5_vl
limit_thinking_content_length_v1(
sampled_token_ids,
max_think_lens,
step_idx,
limit_think_status,
think_end_id,
)
elif limit_strategy == "\n</think>\n\n":
# for ernie_x1
assert line_break_id > 0
limit_thinking_content_length_v2(
sampled_token_ids,
max_think_lens,
step_idx,
limit_think_status,
think_end_id,
line_break_id,
)
else:
raise NotImplementedError(f"Not support {limit_strategy=} for limit thinking content length.")
# 1. Set stop value
paddle.assign(
paddle.where(
@@ -431,6 +463,15 @@ class XPUModelRunner(ModelRunnerBase):
position_ids, request.get("max_tokens", 2048)
)
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
else:
# Disable thinking
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
if len(request.output_token_ids) == 0:
input_ids = request.prompt_token_ids
else:
@@ -566,6 +607,15 @@ class XPUModelRunner(ModelRunnerBase):
)
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None:
# Enable thinking
self.share_inputs["max_think_lens"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
else:
# Disable thinking
self.share_inputs["max_think_lens"][idx : idx + 1, :] = -1
self.share_inputs["limit_think_status"][idx : idx + 1, :] = 0
def get_attr_from_request(request, attr, default_value=None):
res = request.get(attr, default_value)
if res is not None:
@@ -712,6 +762,10 @@ class XPUModelRunner(ModelRunnerBase):
self.share_inputs["system_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["system_ids"] = paddle.full([max_num_seqs, 1], -1, dtype="int32")
# Initialize thinking related buffers
self.share_inputs["max_think_lens"] = paddle.full(shape=[max_num_seqs, 1], fill_value=-1, dtype="int32")
self.share_inputs["limit_think_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
@@ -1111,6 +1165,8 @@ class XPUModelRunner(ModelRunnerBase):
share_inputs=self.share_inputs,
block_size=self.cache_config.block_size,
skip_save_output=is_dummy_run,
think_end_id=self.model_config.think_end_id,
line_break_id=self.model_config.line_break_id,
)
# 7. Update 'infer_seed' and step_paddle()