[XPU] fix thinking bug where output only contains reasoning_content (#4760)

Co-authored-by: ddchenhao66 <dhaochen@163.com>
This commit is contained in:
ddchenhao66
2025-11-04 12:47:34 +08:00
committed by GitHub
parent ffa57dbfac
commit 78a1451eb7
8 changed files with 201 additions and 21 deletions

View File

@@ -25,27 +25,38 @@ void LimitThinkingContentLengthV1(const paddle::Tensor& next_tokens,
const paddle::Tensor& max_think_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& stop_flags,
const paddle::Tensor& eos_token_ids,
const int64_t think_end_id) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
const int batch_size = next_tokens.shape()[0];
const int eos_token_id_len = eos_token_ids.shape()[0];
int r = baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1(
xpu_ctx->x_context(),
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),
eos_token_ids.data<int64_t>(),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<bool*>(stop_flags.data<bool>()),
think_end_id,
batch_size);
batch_size,
eos_token_id_len);
PD_CHECK(r == 0,
"baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1 "
"failed.");
}
PD_BUILD_STATIC_OP(limit_thinking_content_length_v1)
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
.Inputs({"next_tokens",
"max_think_lens",
"step_idx",
"limit_think_status",
"stop_flags",
"eos_token_ids"})
.Attrs({"think_end_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})

View File

@@ -25,6 +25,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
const paddle::Tensor& max_think_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& stop_flags,
const int64_t think_end_id,
const int64_t line_break_id) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
@@ -38,6 +39,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
max_think_lens.data<int>(),
step_idx.data<int64_t>(),
const_cast<int*>(limit_think_status.data<int>()),
stop_flags.data<bool>(),
think_end_id,
line_break_id,
batch_size);
@@ -47,7 +49,11 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
}
PD_BUILD_STATIC_OP(limit_thinking_content_length_v2)
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
.Inputs({"next_tokens",
"max_think_lens",
"step_idx",
"limit_think_status",
"stop_flags"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})

View File

@@ -220,9 +220,12 @@ DLL_EXPORT int limit_thinking_content_length_kernel_v1(
int64_t* next_tokens,
const int* max_think_lens,
const int64_t* step_idx,
const int64_t* eos_token_ids,
int* limit_think_status,
bool* stop_flags,
const int64_t think_end_id,
const int bs);
const int bs,
const int eos_token_id_len);
DLL_EXPORT int limit_thinking_content_length_kernel_v2(
api::Context* ctx,
@@ -230,6 +233,7 @@ DLL_EXPORT int limit_thinking_content_length_kernel_v2(
const int* max_think_lens,
const int64_t* step_idx,
int* limit_think_status,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
const int bs);

View File

@@ -10,43 +10,68 @@
namespace xpu3 {
namespace plugin {
// Returns true iff `id` occurs among the first `length` entries of `end_ids`.
template <typename T>
static inline __device__ bool is_in_end(const T id,
                                        const T* end_ids,
                                        const int length) {
  const T* cursor = end_ids;
  const T* const sentinel = end_ids + length;
  while (cursor != sentinel) {
    if (*cursor == id) {
      return true;
    }
    ++cursor;
  }
  return false;
}
__global__ void limit_thinking_content_length_kernel_v1(
int64_t* next_tokens,
const int* max_think_lens,
const int64_t* step_idx,
const int64_t* eos_token_ids,
int* limit_think_status,
bool* stop_flags,
const int64_t think_end_id,
const int bs) {
const int bs,
const int eos_token_id_len) {
int cid = core_id();
int ncores = core_num();
int clusterid = cluster_id();
int nclusters = cluster_num();
if (clusterid != 0) return;
__simd__ __local__ int64_t eos_token_ids_lm[256];
for (int i = cid; i < bs; i += ncores) {
int max_think_len_lm;
int limit_think_status_lm;
int64_t next_token_lm;
int64_t step_idx_lm;
bool stop_flags_lm;
GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool));
GM2LM_ASYNC(
eos_token_ids, eos_token_ids_lm, sizeof(int64_t) * eos_token_id_len);
GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
// 如果该序列未启用思考功能,则直接返回,默认值为 -1表示不限制思考长度
if (max_think_len_lm < 0) continue;
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行.
if (limit_think_status_lm == 2) continue;
if (limit_think_status_lm == 2 && stop_flags_lm) continue;
// ======================= 思考阶段控制 =======================
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束
if (limit_think_status_lm < 1) {
// 当开启思考长度控制时,检查是否超时
if (step_idx_lm >= max_think_len_lm) {
if ((step_idx_lm >= max_think_len_lm) ||
is_in_end(next_token_lm, eos_token_ids_lm, eos_token_id_len)) {
// 强制将当前token替换为结束思考的token
next_token_lm = think_end_id;
// 将状态推进到 1, 表示 "正在结束思考"
limit_think_status_lm = 1;
if (stop_flags_lm) {
stop_flags_lm = false;
LM2GM(&stop_flags_lm, stop_flags + i, sizeof(bool));
}
}
}

View File

@@ -15,6 +15,7 @@ __global__ void limit_thinking_content_length_kernel_v2(
const int* max_think_lens,
const int64_t* step_idx,
int* limit_think_status,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
const int bs) {
@@ -29,15 +30,17 @@ __global__ void limit_thinking_content_length_kernel_v2(
int limit_think_status_lm;
int64_t next_token_lm;
int64_t step_idx_lm;
bool stop_flags_lm;
GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool));
GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
// 如果该序列未启用思考功能,则直接返回,默认值为 -1表示不限制思考长度
if (max_think_len_lm < 0) continue;
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行.
if (limit_think_status_lm == 3) continue;
if (limit_think_status_lm == 3 && stop_flags_lm) continue;
// ======================= 思考阶段控制 =======================
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束

View File

@@ -24,10 +24,12 @@ __attribute__((global)) void limit_thinking_content_length_kernel_v1(
int64_t* next_tokens,
const int* max_think_lens,
const int64_t* step_idx,
const int64_t* eos_token_ids,
int* limit_think_status,
bool* stop_flags,
const int64_t think_end_id,
const int bs);
const int bs,
const int eos_token_id_len);
} // namespace plugin
} // namespace xpu3
@@ -36,13 +38,58 @@ namespace xpu {
namespace api {
namespace plugin {
// CPU reference implementation of limit_thinking_content_length_kernel_v1.
// Mirrors the XPU kernel: while a sequence is still thinking (status 0), a
// step that exhausts the thinking budget OR emits an EOS token is forcibly
// rewritten to `think_end_id`, so decoding leaves the thinking phase instead
// of terminating with only reasoning content.
//
// limit_think_status values: 0 = thinking, 1 = ending thinking, 2 = replying.
static int cpu_wrapper(Context* ctx,
                       int64_t* next_tokens,
                       const int* max_think_lens,
                       const int64_t* step_idx,
                       const int64_t* eos_token_ids,
                       int* limit_think_status,
                       bool* stop_flags,
                       const int64_t think_end_id,
                       const int bs,
                       const int eos_token_id_len) {
  // Linear membership test over the EOS token id list.
  auto is_in_end = [](int64_t token_id, const int64_t* end_ids, int length) {
    for (int i = 0; i < length; i++) {
      if (token_id == end_ids[i]) {
        return true;
      }
    }
    return false;
  };
  for (int bid = 0; bid < bs; bid++) {
    const int max_think_len = max_think_lens[bid];
    // A negative budget (-1) disables thinking-length limiting for this row.
    if (max_think_len < 0) continue;
    int current_limit_think_status = limit_think_status[bid];
    // Already replying and stopped: nothing left to do for this sequence.
    if (current_limit_think_status == 2 && stop_flags[bid]) continue;
    int64_t next_token = next_tokens[bid];
    const int64_t step = step_idx[bid];
    // Phase 1: still thinking -- force-close the thinking section when the
    // budget is exhausted or the model emits an EOS token mid-thought.
    if (current_limit_think_status < 1) {
      if (step >= max_think_len ||
          is_in_end(next_token, eos_token_ids, eos_token_id_len)) {
        next_token = think_end_id;
        current_limit_think_status = 1;
        // An EOS sampled during thinking may already have raised the stop
        // flag; clear it so generation continues into the reply phase
        // (matches the XPU kernel's stop_flags reset).
        if (stop_flags[bid]) {
          stop_flags[bid] = false;
        }
      }
    }
    // Phase 2: once think_end_id is seen (forced or natural), the sequence
    // enters the reply phase.
    if (current_limit_think_status < 2) {
      if (next_token == think_end_id) {
        current_limit_think_status = 2;
      }
    }
    next_tokens[bid] = next_token;
    limit_think_status[bid] = current_limit_think_status;
  }
  return api::SUCCESS;
}
static int xpu3_wrapper(Context* ctx,
int64_t* next_tokens,
const int* max_think_lens,
const int64_t* step_idx,
const int64_t* eos_token_ids,
int* limit_think_status,
bool* stop_flags,
const int64_t think_end_id,
const int bs) {
const int bs,
const int eos_token_id_len) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto limit_thinking_content_length_kernel_v1 =
xpu3::plugin::limit_thinking_content_length_kernel_v1;
@@ -50,9 +97,12 @@ static int xpu3_wrapper(Context* ctx,
reinterpret_cast<XPU_INT64*>(next_tokens),
max_think_lens,
reinterpret_cast<const XPU_INT64*>(step_idx),
reinterpret_cast<const XPU_INT64*>(eos_token_ids),
limit_think_status,
stop_flags,
think_end_id,
bs);
bs,
eos_token_id_len);
return api::SUCCESS;
}
@@ -60,31 +110,46 @@ int limit_thinking_content_length_kernel_v1(Context* ctx,
int64_t* next_tokens,
const int* max_think_lens,
const int64_t* step_idx,
const int64_t* eos_token_ids,
int* limit_think_status,
bool* stop_flags,
const int64_t think_end_id,
const int bs) {
const int bs,
const int eos_token_id_len) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "limit_thinking_content_length_kernel_v1", int);
WRAPPER_DUMP_PARAM5(ctx,
next_tokens,
max_think_lens,
step_idx,
limit_think_status,
think_end_id);
WRAPPER_DUMP_PARAM1(ctx, bs);
eos_token_ids,
limit_think_status);
WRAPPER_DUMP_PARAM4(ctx, stop_flags, think_end_id, bs, eos_token_id_len);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
assert(false);
return cpu_wrapper(ctx,
next_tokens,
max_think_lens,
step_idx,
eos_token_ids,
limit_think_status,
stop_flags,
think_end_id,
bs,
eos_token_id_len);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
next_tokens,
max_think_lens,
step_idx,
eos_token_ids,
limit_think_status,
stop_flags,
think_end_id,
bs);
bs,
eos_token_id_len);
}
WRAPPER_UNIMPLEMENTED(ctx);
}

View File

@@ -25,6 +25,7 @@ __attribute__((global)) void limit_thinking_content_length_kernel_v2(
const int* max_think_lens,
const int64_t* step_idx,
int* limit_think_status,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
const int bs);
@@ -37,11 +38,60 @@ namespace xpu {
namespace api {
namespace plugin {
// CPU reference implementation of limit_thinking_content_length_kernel_v2.
// Once a sequence's thinking budget runs out, it overwrites the next four
// sampled tokens with line_break / think_end / line_break / line_break
// (the "\n</think>\n\n" closing sequence), then marks thinking finished.
//
// limit_think_status values: 0 = thinking, 1 and 2 = injecting the closing
// sequence, 3 = thinking finished.
static int cpu_wrapper(Context* ctx,
                       int64_t* next_tokens,
                       const int* max_think_lens,
                       const int64_t* step_idx,
                       int* limit_think_status,
                       const bool* stop_flags,
                       const int64_t think_end_id,
                       const int64_t line_break_id,
                       const int bs) {
  for (int bid = 0; bid < bs; bid++) {
    const int think_budget = max_think_lens[bid];
    // A negative budget disables thinking-length limiting for this row.
    if (think_budget < 0) {
      continue;
    }
    int status = limit_think_status[bid];
    // Thinking already finished and the sequence has stopped: skip.
    if (status == 3 && stop_flags[bid]) {
      continue;
    }
    int64_t token = next_tokens[bid];
    // While the budget boundary is being crossed, overwrite the sampled
    // token with the corresponding element of the closing sequence.
    if (status <= 1) {
      switch (step_idx[bid] - think_budget) {
        case 0:
          token = line_break_id;
          status = 1;
          break;
        case 1:
          token = think_end_id;
          status = 1;
          break;
        case 2:
          token = line_break_id;
          status = 1;
          break;
        case 3:
          token = line_break_id;
          status = 2;  // last injected token
          break;
        default:
          break;
      }
    }
    // The model closed its thinking naturally before the budget ran out.
    if (status == 0 && token == think_end_id) {
      status = 3;
    }
    // The injected closing sequence is complete.
    if (status == 2) {
      status = 3;
    }
    next_tokens[bid] = token;
    limit_think_status[bid] = status;
  }
  return api::SUCCESS;
}
static int xpu3_wrapper(Context* ctx,
int64_t* next_tokens,
const int* max_think_lens,
const int64_t* step_idx,
int* limit_think_status,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
const int bs) {
@@ -53,6 +103,7 @@ static int xpu3_wrapper(Context* ctx,
max_think_lens,
reinterpret_cast<const XPU_INT64*>(step_idx),
limit_think_status,
stop_flags,
think_end_id,
line_break_id,
bs);
@@ -64,6 +115,7 @@ int limit_thinking_content_length_kernel_v2(Context* ctx,
const int* max_think_lens,
const int64_t* step_idx,
int* limit_think_status,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
const int bs) {
@@ -74,11 +126,19 @@ int limit_thinking_content_length_kernel_v2(Context* ctx,
max_think_lens,
step_idx,
limit_think_status,
think_end_id);
WRAPPER_DUMP_PARAM2(ctx, line_break_id, bs);
stop_flags);
WRAPPER_DUMP_PARAM3(ctx, think_end_id, line_break_id, bs);
WRAPPER_DUMP(ctx);
if (ctx->dev().type() == api::kCPU) {
assert(false);
return cpu_wrapper(ctx,
next_tokens,
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
think_end_id,
line_break_id,
bs);
}
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
return xpu3_wrapper(ctx,
@@ -86,6 +146,7 @@ int limit_thinking_content_length_kernel_v2(Context* ctx,
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
think_end_id,
line_break_id,
bs);

View File

@@ -202,6 +202,8 @@ def xpu_post_process(
max_think_lens = share_inputs["max_think_lens"]
step_idx = share_inputs["step_idx"]
limit_think_status = share_inputs["limit_think_status"]
stop_flags = share_inputs["stop_flags"]
eos_token_ids = share_inputs["eos_token_id"]
if limit_strategy == "</think>":
# for ernie-45-vl
limit_thinking_content_length_v1(
@@ -209,6 +211,8 @@ def xpu_post_process(
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
eos_token_ids, # 处理由于模型效果问题导致思考过程中输出eos token的问题
think_end_id,
)
elif limit_strategy == "\n</think>\n\n":
@@ -219,6 +223,7 @@ def xpu_post_process(
max_think_lens,
step_idx,
limit_think_status,
stop_flags,
think_end_id,
line_break_id,
)