mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] fix thinking bug where output only contains reasoning_content (#4760)
Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
@@ -25,27 +25,38 @@ void LimitThinkingContentLengthV1(const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& max_think_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& limit_think_status,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const paddle::Tensor& eos_token_ids,
|
||||
const int64_t think_end_id) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
|
||||
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
|
||||
|
||||
const int batch_size = next_tokens.shape()[0];
|
||||
const int eos_token_id_len = eos_token_ids.shape()[0];
|
||||
int r = baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1(
|
||||
xpu_ctx->x_context(),
|
||||
const_cast<int64_t*>(next_tokens.data<int64_t>()),
|
||||
max_think_lens.data<int>(),
|
||||
step_idx.data<int64_t>(),
|
||||
eos_token_ids.data<int64_t>(),
|
||||
const_cast<int*>(limit_think_status.data<int>()),
|
||||
const_cast<bool*>(stop_flags.data<bool>()),
|
||||
think_end_id,
|
||||
batch_size);
|
||||
batch_size,
|
||||
eos_token_id_len);
|
||||
PD_CHECK(r == 0,
|
||||
"baidu::xpu::api::plugin::limit_thinking_content_length_kernel_v1 "
|
||||
"failed.");
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(limit_thinking_content_length_v1)
|
||||
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
|
||||
.Inputs({"next_tokens",
|
||||
"max_think_lens",
|
||||
"step_idx",
|
||||
"limit_think_status",
|
||||
"stop_flags",
|
||||
"eos_token_ids"})
|
||||
.Attrs({"think_end_id: int64_t"})
|
||||
.Outputs({"next_tokens_out"})
|
||||
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
|
||||
|
||||
@@ -25,6 +25,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
|
||||
const paddle::Tensor& max_think_lens,
|
||||
const paddle::Tensor& step_idx,
|
||||
const paddle::Tensor& limit_think_status,
|
||||
const paddle::Tensor& stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int64_t line_break_id) {
|
||||
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
|
||||
@@ -38,6 +39,7 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
|
||||
max_think_lens.data<int>(),
|
||||
step_idx.data<int64_t>(),
|
||||
const_cast<int*>(limit_think_status.data<int>()),
|
||||
stop_flags.data<bool>(),
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
batch_size);
|
||||
@@ -47,7 +49,11 @@ void LimitThinkingContentLengthV2(const paddle::Tensor& next_tokens,
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(limit_thinking_content_length_v2)
|
||||
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
|
||||
.Inputs({"next_tokens",
|
||||
"max_think_lens",
|
||||
"step_idx",
|
||||
"limit_think_status",
|
||||
"stop_flags"})
|
||||
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
|
||||
.Outputs({"next_tokens_out"})
|
||||
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
|
||||
|
||||
@@ -220,9 +220,12 @@ DLL_EXPORT int limit_thinking_content_length_kernel_v1(
|
||||
int64_t* next_tokens,
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
const int64_t* eos_token_ids,
|
||||
int* limit_think_status,
|
||||
bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int bs);
|
||||
const int bs,
|
||||
const int eos_token_id_len);
|
||||
|
||||
DLL_EXPORT int limit_thinking_content_length_kernel_v2(
|
||||
api::Context* ctx,
|
||||
@@ -230,6 +233,7 @@ DLL_EXPORT int limit_thinking_content_length_kernel_v2(
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
int* limit_think_status,
|
||||
const bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int64_t line_break_id,
|
||||
const int bs);
|
||||
|
||||
@@ -10,43 +10,68 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
template <typename T>
|
||||
static inline __device__ bool is_in_end(const T id,
|
||||
const T* end_ids,
|
||||
const int length) {
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (id == end_ids[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
__global__ void limit_thinking_content_length_kernel_v1(
|
||||
int64_t* next_tokens,
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
const int64_t* eos_token_ids,
|
||||
int* limit_think_status,
|
||||
bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int bs) {
|
||||
const int bs,
|
||||
const int eos_token_id_len) {
|
||||
int cid = core_id();
|
||||
int ncores = core_num();
|
||||
int clusterid = cluster_id();
|
||||
int nclusters = cluster_num();
|
||||
if (clusterid != 0) return;
|
||||
__simd__ __local__ int64_t eos_token_ids_lm[256];
|
||||
|
||||
for (int i = cid; i < bs; i += ncores) {
|
||||
int max_think_len_lm;
|
||||
int limit_think_status_lm;
|
||||
int64_t next_token_lm;
|
||||
int64_t step_idx_lm;
|
||||
bool stop_flags_lm;
|
||||
GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
|
||||
GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
|
||||
GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
|
||||
GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool));
|
||||
GM2LM_ASYNC(
|
||||
eos_token_ids, eos_token_ids_lm, sizeof(int64_t) * eos_token_id_len);
|
||||
GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
|
||||
|
||||
// 如果该序列未启用思考功能,则直接返回,默认值为 -1,表示不限制思考长度
|
||||
if (max_think_len_lm < 0) continue;
|
||||
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行.
|
||||
if (limit_think_status_lm == 2) continue;
|
||||
if (limit_think_status_lm == 2 && stop_flags_lm) continue;
|
||||
|
||||
// ======================= 思考阶段控制 =======================
|
||||
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束
|
||||
if (limit_think_status_lm < 1) {
|
||||
// 当开启思考长度控制时,检查是否超时
|
||||
if (step_idx_lm >= max_think_len_lm) {
|
||||
if ((step_idx_lm >= max_think_len_lm) ||
|
||||
is_in_end(next_token_lm, eos_token_ids_lm, eos_token_id_len)) {
|
||||
// 强制将当前token替换为结束思考的token
|
||||
next_token_lm = think_end_id;
|
||||
// 将状态推进到 1, 表示 "正在结束思考"
|
||||
limit_think_status_lm = 1;
|
||||
if (stop_flags_lm) {
|
||||
stop_flags_lm = false;
|
||||
LM2GM(&stop_flags_lm, stop_flags + i, sizeof(bool));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ __global__ void limit_thinking_content_length_kernel_v2(
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
int* limit_think_status,
|
||||
const bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int64_t line_break_id,
|
||||
const int bs) {
|
||||
@@ -29,15 +30,17 @@ __global__ void limit_thinking_content_length_kernel_v2(
|
||||
int limit_think_status_lm;
|
||||
int64_t next_token_lm;
|
||||
int64_t step_idx_lm;
|
||||
bool stop_flags_lm;
|
||||
GM2LM_ASYNC(next_tokens + i, &next_token_lm, sizeof(int64_t));
|
||||
GM2LM_ASYNC(step_idx + i, &step_idx_lm, sizeof(int64_t));
|
||||
GM2LM_ASYNC(stop_flags + i, &stop_flags_lm, sizeof(bool));
|
||||
GM2LM_ASYNC(max_think_lens + i, &max_think_len_lm, sizeof(int));
|
||||
GM2LM(limit_think_status + i, &limit_think_status_lm, sizeof(int));
|
||||
|
||||
// 如果该序列未启用思考功能,则直接返回,默认值为 -1,表示不限制思考长度
|
||||
if (max_think_len_lm < 0) continue;
|
||||
// 如果在回复阶段, 且已经触发停止标志, 则直接返回, 无需多余执行.
|
||||
if (limit_think_status_lm == 3) continue;
|
||||
if (limit_think_status_lm == 3 && stop_flags_lm) continue;
|
||||
|
||||
// ======================= 思考阶段控制 =======================
|
||||
// 阶段 1: 仍在思考 (status == 0), 检查是否需要强制结束
|
||||
|
||||
@@ -24,10 +24,12 @@ __attribute__((global)) void limit_thinking_content_length_kernel_v1(
|
||||
int64_t* next_tokens,
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
const int64_t* eos_token_ids,
|
||||
int* limit_think_status,
|
||||
bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int bs);
|
||||
|
||||
const int bs,
|
||||
const int eos_token_id_len);
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
|
||||
@@ -36,13 +38,58 @@ namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
// CPU reference path for limit_thinking_content_length_kernel_v1.
//
// Walks the `bs` sequences of the batch and updates `next_tokens`,
// `limit_think_status` and `stop_flags` in place:
//   * a sequence whose `max_think_lens` entry is negative is left untouched;
//   * a sequence whose status is already 2 and whose stop flag is set is left
//     untouched;
//   * while the status is below 1, reaching `max_think_lens[bid]` steps or
//     sampling one of the `eos_token_ids` replaces the token with
//     `think_end_id`, moves the status to 1 and clears a raised stop flag;
//   * while the status is below 2, a token equal to `think_end_id` moves the
//     status to 2.
//
// Params:
//   ctx                 unused by this path.
//   next_tokens         [in/out] sampled token per sequence.
//   max_think_lens      per-sequence thinking budget; < 0 skips the sequence.
//   step_idx            per-sequence step counter compared to the budget.
//   eos_token_ids       `eos_token_id_len` token ids treated as end tokens.
//   limit_think_status  [in/out] per-sequence state (0, 1 or 2 as above).
//   stop_flags          [in/out] per-sequence stop flag.
//   think_end_id        token id written when thinking is forced to end.
//   bs                  number of sequences.
//   eos_token_id_len    number of entries in `eos_token_ids`.
//
// Returns api::SUCCESS.
static int cpu_wrapper(Context* ctx,
                       int64_t* next_tokens,
                       const int* max_think_lens,
                       const int64_t* step_idx,
                       const int64_t* eos_token_ids,
                       int* limit_think_status,
                       bool* stop_flags,
                       const int64_t think_end_id,
                       const int bs,
                       const int eos_token_id_len) {
  // True when `token_id` equals one of the first `length` ids in `end_ids`.
  auto is_in_end = [](int64_t token_id, const int64_t* end_ids, int length) {
    for (int i = 0; i < length; i++) {
      if (token_id == end_ids[i]) {
        return true;
      }
    }
    return false;
  };

  for (int bid = 0; bid < bs; bid++) {
    const int max_think_len = max_think_lens[bid];
    if (max_think_len < 0) continue;

    int current_limit_think_status = limit_think_status[bid];
    if (current_limit_think_status == 2 && stop_flags[bid]) continue;

    int64_t next_token = next_tokens[bid];
    const int64_t step = step_idx[bid];

    if (current_limit_think_status < 1) {
      if (step >= max_think_len ||
          is_in_end(next_token, eos_token_ids, eos_token_id_len)) {
        next_token = think_end_id;
        current_limit_think_status = 1;
        // The end token that raised the stop flag has just been overwritten
        // with think_end_id, so the flag must be lowered again; otherwise the
        // sequence stays stopped with nothing generated after the thinking
        // part. The XPU kernel limit_thinking_content_length_kernel_v1
        // performs the same reset.
        if (stop_flags[bid]) {
          stop_flags[bid] = false;
        }
      }
    }

    if (current_limit_think_status < 2) {
      if (next_token == think_end_id) {
        current_limit_think_status = 2;
      }
    }

    next_tokens[bid] = next_token;
    limit_think_status[bid] = current_limit_think_status;
  }
  return api::SUCCESS;
}
|
||||
static int xpu3_wrapper(Context* ctx,
|
||||
int64_t* next_tokens,
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
const int64_t* eos_token_ids,
|
||||
int* limit_think_status,
|
||||
bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int bs) {
|
||||
const int bs,
|
||||
const int eos_token_id_len) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto limit_thinking_content_length_kernel_v1 =
|
||||
xpu3::plugin::limit_thinking_content_length_kernel_v1;
|
||||
@@ -50,9 +97,12 @@ static int xpu3_wrapper(Context* ctx,
|
||||
reinterpret_cast<XPU_INT64*>(next_tokens),
|
||||
max_think_lens,
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
reinterpret_cast<const XPU_INT64*>(eos_token_ids),
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
bs);
|
||||
bs,
|
||||
eos_token_id_len);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
@@ -60,31 +110,46 @@ int limit_thinking_content_length_kernel_v1(Context* ctx,
|
||||
int64_t* next_tokens,
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
const int64_t* eos_token_ids,
|
||||
int* limit_think_status,
|
||||
bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int bs) {
|
||||
const int bs,
|
||||
const int eos_token_id_len) {
|
||||
WRAPPER_CHECK_CTX(ctx);
|
||||
WRAPPER_DUMP_FUNCTION_T1(ctx, "limit_thinking_content_length_kernel_v1", int);
|
||||
WRAPPER_DUMP_PARAM5(ctx,
|
||||
next_tokens,
|
||||
max_think_lens,
|
||||
step_idx,
|
||||
limit_think_status,
|
||||
think_end_id);
|
||||
WRAPPER_DUMP_PARAM1(ctx, bs);
|
||||
eos_token_ids,
|
||||
limit_think_status);
|
||||
WRAPPER_DUMP_PARAM4(ctx, stop_flags, think_end_id, bs, eos_token_id_len);
|
||||
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
assert(false);
|
||||
return cpu_wrapper(ctx,
|
||||
next_tokens,
|
||||
max_think_lens,
|
||||
step_idx,
|
||||
eos_token_ids,
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
bs,
|
||||
eos_token_id_len);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
next_tokens,
|
||||
max_think_lens,
|
||||
step_idx,
|
||||
eos_token_ids,
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
bs);
|
||||
bs,
|
||||
eos_token_id_len);
|
||||
}
|
||||
WRAPPER_UNIMPLEMENTED(ctx);
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ __attribute__((global)) void limit_thinking_content_length_kernel_v2(
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
int* limit_think_status,
|
||||
const bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int64_t line_break_id,
|
||||
const int bs);
|
||||
@@ -37,11 +38,60 @@ namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
// CPU reference path for limit_thinking_content_length_kernel_v2.
//
// For each of the `bs` sequences, updates `next_tokens` and
// `limit_think_status` in place:
//   * a negative `max_think_lens` entry leaves the sequence untouched;
//   * a sequence with status 3 and a raised stop flag is left untouched;
//   * while the status is 0 or 1, the four consecutive steps starting at
//     `max_think_lens[seq]` overwrite the token with, in order,
//     line_break_id, think_end_id, line_break_id, line_break_id; the first
//     three leave the status at 1 and the fourth sets it to 2;
//   * a status of 0 with a token equal to `think_end_id`, or a status of 2,
//     becomes status 3.
//
// `ctx` is unused and `stop_flags` is only read. Returns api::SUCCESS.
static int cpu_wrapper(Context* ctx,
                       int64_t* next_tokens,
                       const int* max_think_lens,
                       const int64_t* step_idx,
                       int* limit_think_status,
                       const bool* stop_flags,
                       const int64_t think_end_id,
                       const int64_t line_break_id,
                       const int bs) {
  for (int seq = 0; seq < bs; ++seq) {
    const int think_budget = max_think_lens[seq];
    if (think_budget < 0) {
      continue;
    }

    int status = limit_think_status[seq];
    if (stop_flags[seq] && status == 3) {
      continue;
    }

    int64_t token = next_tokens[seq];
    const int64_t cur_step = step_idx[seq];

    if (status <= 1) {
      // Offsets 0..3 past the budget are tried in order; the first one that
      // matches the current step decides the injected token and new status.
      for (int offset = 0; offset < 4; ++offset) {
        if (cur_step != think_budget + offset) {
          continue;
        }
        token = (offset == 1) ? think_end_id : line_break_id;
        status = (offset == 3) ? 2 : 1;
        break;
      }
    }

    // Either the think-end token appeared while status was still 0, or the
    // forced sequence just finished (status 2): both end up in status 3.
    const bool ended_by_token = (status == 0) && (token == think_end_id);
    if (ended_by_token || status == 2) {
      status = 3;
    }

    next_tokens[seq] = token;
    limit_think_status[seq] = status;
  }
  return api::SUCCESS;
}
|
||||
static int xpu3_wrapper(Context* ctx,
|
||||
int64_t* next_tokens,
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
int* limit_think_status,
|
||||
const bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int64_t line_break_id,
|
||||
const int bs) {
|
||||
@@ -53,6 +103,7 @@ static int xpu3_wrapper(Context* ctx,
|
||||
max_think_lens,
|
||||
reinterpret_cast<const XPU_INT64*>(step_idx),
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
bs);
|
||||
@@ -64,6 +115,7 @@ int limit_thinking_content_length_kernel_v2(Context* ctx,
|
||||
const int* max_think_lens,
|
||||
const int64_t* step_idx,
|
||||
int* limit_think_status,
|
||||
const bool* stop_flags,
|
||||
const int64_t think_end_id,
|
||||
const int64_t line_break_id,
|
||||
const int bs) {
|
||||
@@ -74,11 +126,19 @@ int limit_thinking_content_length_kernel_v2(Context* ctx,
|
||||
max_think_lens,
|
||||
step_idx,
|
||||
limit_think_status,
|
||||
think_end_id);
|
||||
WRAPPER_DUMP_PARAM2(ctx, line_break_id, bs);
|
||||
stop_flags);
|
||||
WRAPPER_DUMP_PARAM3(ctx, think_end_id, line_break_id, bs);
|
||||
WRAPPER_DUMP(ctx);
|
||||
if (ctx->dev().type() == api::kCPU) {
|
||||
assert(false);
|
||||
return cpu_wrapper(ctx,
|
||||
next_tokens,
|
||||
max_think_lens,
|
||||
step_idx,
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
bs);
|
||||
}
|
||||
if (ctx->dev().type() == api::kXPU2 || ctx->dev().type() == api::kXPU3) {
|
||||
return xpu3_wrapper(ctx,
|
||||
@@ -86,6 +146,7 @@ int limit_thinking_content_length_kernel_v2(Context* ctx,
|
||||
max_think_lens,
|
||||
step_idx,
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
bs);
|
||||
|
||||
@@ -202,6 +202,8 @@ def xpu_post_process(
|
||||
max_think_lens = share_inputs["max_think_lens"]
|
||||
step_idx = share_inputs["step_idx"]
|
||||
limit_think_status = share_inputs["limit_think_status"]
|
||||
stop_flags = share_inputs["stop_flags"]
|
||||
eos_token_ids = share_inputs["eos_token_id"]
|
||||
if limit_strategy == "</think>":
|
||||
# for ernie-45-vl
|
||||
limit_thinking_content_length_v1(
|
||||
@@ -209,6 +211,8 @@ def xpu_post_process(
|
||||
max_think_lens,
|
||||
step_idx,
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
eos_token_ids, # 处理由于模型效果问题导致思考过程中输出eos token的问题
|
||||
think_end_id,
|
||||
)
|
||||
elif limit_strategy == "\n</think>\n\n":
|
||||
@@ -219,6 +223,7 @@ def xpu_post_process(
|
||||
max_think_lens,
|
||||
step_idx,
|
||||
limit_think_status,
|
||||
stop_flags,
|
||||
think_end_id,
|
||||
line_break_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user