[BugFix] fix speculate_limit_thinking_content_length (#5590)

* fix speculate_limit_thinking_content_length

* update
This commit is contained in:
Yuanle Liu
2025-12-16 20:31:45 +08:00
committed by GitHub
parent 7140939c51
commit 867803ae10
5 changed files with 1 additions and 57 deletions

View File

@@ -1029,7 +1029,6 @@ void SpeculateLimitThinkingContentLengthV1(
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& eos_token_ids,
const int64_t think_end_id);
@@ -1040,7 +1039,6 @@ void SpeculateLimitThinkingContentLengthV2(
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const int64_t think_end_id,
const int64_t line_break_id);

View File

@@ -22,7 +22,6 @@ __global__ void speculate_limit_thinking_content_length_kernel_v1(
const int64_t* eos_token_ids,
int* limit_think_status,
int* accept_num,
int* seq_lens_decoder,
bool* stop_flags,
const int64_t think_end_id,
const int tokens_per_step,
@@ -106,7 +105,6 @@ __global__ void speculate_limit_thinking_content_length_kernel_v1(
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
accept_num[bid] = new_accept_num;
@@ -119,7 +117,6 @@ void SpeculateLimitThinkingContentLengthV1(
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const paddle::Tensor& eos_token_ids,
const int64_t think_end_id) {
@@ -134,7 +131,6 @@ void SpeculateLimitThinkingContentLengthV1(
eos_token_ids.data<int64_t>(),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<bool*>(stop_flags.data<bool>()),
think_end_id,
tokens_per_step,
@@ -148,7 +144,6 @@ PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1)
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder",
"stop_flags",
"eos_token_ids"})
.Attrs({"think_end_id: int64_t"})

View File

@@ -25,7 +25,6 @@ __global__ void speculate_limit_thinking_content_length_kernel_v2(
int64_t* step_idx,
int* limit_think_status,
int* accept_num,
int* seq_lens_decoder,
const bool* stop_flags,
const int64_t think_end_id,
const int64_t line_break_id,
@@ -115,7 +114,6 @@ __global__ void speculate_limit_thinking_content_length_kernel_v2(
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}
accept_num[bid] = new_accept_num;
@@ -128,7 +126,6 @@ void SpeculateLimitThinkingContentLengthV2(
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& stop_flags,
const int64_t think_end_id,
const int64_t line_break_id) {
@@ -141,7 +138,6 @@ void SpeculateLimitThinkingContentLengthV2(
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
stop_flags.data<bool>(),
think_end_id,
line_break_id,
@@ -155,7 +151,6 @@ PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v2)
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder",
"stop_flags"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.Outputs({"next_tokens_out"})

View File

@@ -144,7 +144,6 @@ def speculate_limit_thinking_content_length(
step_idx: paddle.Tensor,
limit_think_status: paddle.Tensor,
accept_num: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
stop_flags: paddle.Tensor,
eos_token_ids: paddle.Tensor,
think_end_id: int,
@@ -158,7 +157,6 @@ def speculate_limit_thinking_content_length(
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids, # 处理由于模型效果问题导致思考过程中输出eos token的问题
think_end_id,
@@ -172,7 +170,6 @@ def speculate_limit_thinking_content_length(
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -451,7 +448,6 @@ def post_process_specualate(
step_idx=share_inputs["step_idx"],
limit_think_status=share_inputs["limit_think_status"],
accept_num=share_inputs["accept_num"],
seq_lens_decoder=share_inputs["seq_lens_decoder"],
think_end_id=think_end_id,
line_break_id=line_break_id,
)

View File

@@ -36,7 +36,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx = paddle.to_tensor([5, 8], dtype="int64")
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
accept_num = paddle.to_tensor([3, 2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5, 8], dtype="int32")
stop_flags = paddle.to_tensor([False, False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
@@ -48,7 +47,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
@@ -75,7 +73,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx = paddle.to_tensor([12], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([4], dtype="int32")
seq_lens_decoder = paddle.to_tensor([12], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
@@ -87,7 +84,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
@@ -99,9 +95,8 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
assert next_tokens.numpy()[0, 1] == 999 # Token at step 10, replaced with think_end_id
assert accept_num.numpy()[0] == 2 # Only accept first 2 tokens
assert limit_think_status.numpy()[0] == 2 # Status updated to 2
# step_idx and seq_lens_decoder should be adjusted
# step_idx should be adjusted
assert step_idx.numpy()[0] == 10 # 12 - (4-2) = 10
assert seq_lens_decoder.numpy()[0] == 10 # 12 - (4-2) = 10
def test_model_naturally_generates_think_end_id(self):
"""Test when model naturally generates think_end_id in accepted tokens"""
@@ -110,7 +105,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx = paddle.to_tensor([5], dtype="int64") # step 3-5
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([3], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
@@ -122,7 +116,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
@@ -140,7 +133,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx = paddle.to_tensor([100], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([3], dtype="int32")
seq_lens_decoder = paddle.to_tensor([100], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
@@ -152,7 +144,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
@@ -170,7 +161,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx = paddle.to_tensor([10], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([0], dtype="int32") # No tokens accepted
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
@@ -182,7 +172,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
@@ -199,7 +188,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx = paddle.to_tensor([10], dtype="int64")
limit_think_status = paddle.to_tensor([3], dtype="int32") # Terminal status
accept_num = paddle.to_tensor([2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
@@ -211,7 +199,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
@@ -228,7 +215,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx = paddle.to_tensor([9], dtype="int64") # base step = 9-2+1 = 8
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([9], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
@@ -239,7 +225,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
@@ -258,7 +243,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx = paddle.to_tensor([6, 8, 50], dtype="int64")
limit_think_status = paddle.to_tensor([0, 0, 0], dtype="int32")
accept_num = paddle.to_tensor([3, 3, 2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([6, 8, 50], dtype="int32")
stop_flags = paddle.to_tensor([False, False, False], dtype="bool")
eos_token_ids = paddle.to_tensor([[2], [2]], dtype="int64")
think_end_id = 999
@@ -270,7 +254,6 @@ class TestSpeculateLimitThinkingContentLengthV1(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
eos_token_ids,
think_end_id,
@@ -302,7 +285,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([5, 8], dtype="int64")
limit_think_status = paddle.to_tensor([0, 0], dtype="int32")
accept_num = paddle.to_tensor([3, 2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5, 8], dtype="int32")
stop_flags = paddle.to_tensor([False, False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -314,7 +296,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -335,7 +316,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([12], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([5], dtype="int32")
seq_lens_decoder = paddle.to_tensor([12], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -347,7 +327,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -358,7 +337,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
assert limit_think_status.numpy()[0] == 1
assert accept_num.numpy()[0] == 1 # Truncated after 1st token
assert step_idx.numpy()[0] == 8 # 12 - (5-1)
assert seq_lens_decoder.numpy()[0] == 8
def test_injection_sequence_steps(self):
"""Test each step of the injection sequence: \n, </think>, \n, \n"""
@@ -371,7 +349,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([5], dtype="int64") # base_step = 5-1+1 = 5
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([1], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
speculate_limit_thinking_content_length_v2(
@@ -380,7 +357,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -393,7 +369,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([6], dtype="int64") # base_step = 6
limit_think_status = paddle.to_tensor([1], dtype="int32")
accept_num = paddle.to_tensor([1], dtype="int32")
seq_lens_decoder = paddle.to_tensor([6], dtype="int32")
speculate_limit_thinking_content_length_v2(
next_tokens,
@@ -401,7 +376,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -414,7 +388,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([7], dtype="int64")
limit_think_status = paddle.to_tensor([1], dtype="int32")
accept_num = paddle.to_tensor([1], dtype="int32")
seq_lens_decoder = paddle.to_tensor([7], dtype="int32")
speculate_limit_thinking_content_length_v2(
next_tokens,
@@ -422,7 +395,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -435,7 +407,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([8], dtype="int64")
limit_think_status = paddle.to_tensor([1], dtype="int32")
accept_num = paddle.to_tensor([1], dtype="int32")
seq_lens_decoder = paddle.to_tensor([8], dtype="int32")
speculate_limit_thinking_content_length_v2(
next_tokens,
@@ -443,7 +414,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -458,7 +428,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([5], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([3], dtype="int32")
seq_lens_decoder = paddle.to_tensor([5], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -470,7 +439,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -486,7 +454,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([10], dtype="int64")
limit_think_status = paddle.to_tensor([2], dtype="int32")
accept_num = paddle.to_tensor([1], dtype="int32")
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -498,7 +465,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -514,7 +480,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([100], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([2], dtype="int32")
seq_lens_decoder = paddle.to_tensor([100], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -526,7 +491,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -543,7 +507,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([10], dtype="int64")
limit_think_status = paddle.to_tensor([0], dtype="int32")
accept_num = paddle.to_tensor([0], dtype="int32")
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -555,7 +518,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,
@@ -572,7 +534,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx = paddle.to_tensor([10], dtype="int64")
limit_think_status = paddle.to_tensor([3], dtype="int32")
accept_num = paddle.to_tensor([1], dtype="int32")
seq_lens_decoder = paddle.to_tensor([10], dtype="int32")
stop_flags = paddle.to_tensor([False], dtype="bool")
think_end_id = 999
line_break_id = 888
@@ -584,7 +545,6 @@ class TestSpeculateLimitThinkingContentLengthV2(unittest.TestCase):
step_idx,
limit_think_status,
accept_num,
seq_lens_decoder,
stop_flags,
think_end_id,
line_break_id,