[FixBug] compute early stopping with real batch size (#3418)

* [FixBug] compute early stopping with real batch size

* update

* fix test_sampler
This commit is contained in:
Zero Rains
2025-08-19 13:09:21 +08:00
committed by GitHub
parent 3a7a20d191
commit 8b12c80f90
4 changed files with 69 additions and 6 deletions

View File

@@ -67,16 +67,17 @@ class RepetitionEarlyStopper(EarlyStopper):
def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
# Get the probability score corresponding to next_tokens in this step
next_scores = paddle.index_sample(probs, next_tokens)
real_bsz = probs.shape[0]
# Sliding window: Move left one grid and insert new score
self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
self.trunc_scores[:, -1:] = next_scores
self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:]
self.trunc_scores[:real_bsz, -1:] = next_scores
# Determine which samples need to be terminated: all trunc_scores are greater than threshold
need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
# Add the stop flags
stop_flags[need_trunc_all] = True
stop_flags[need_trunc_all[:real_bsz]] = True
# Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
reset_mask = need_trunc_all.tile([1, self.window_size])

View File

@@ -26,7 +26,6 @@ done
failed_tests_file="failed_tests.log"
> "$failed_tests_file"
disabled_tests=(
layers/test_sampler.py
layers/test_append_attention.py
layers/test_attention.py
operators/test_rejection_top_p_sampling.py
@@ -36,7 +35,6 @@ disabled_tests=(
operators/test_stop_generation.py
operators/test_air_topp_sampling.py
operators/test_fused_moe.py
layers/test_repetition_early_stopper.py
operators/test_stop_generation_multi_ends.py
graph_optimization/test_cuda_graph.py
)

View File

@@ -170,7 +170,69 @@ def test_consistency():
actual = triggered_step_triton[i]
assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
print("Triton vs Normal: All tokens, states, and trigger timings match.")
print("[consistency]Triton vs Normal: All tokens, states, and trigger timings match.")
def test_consistency_with_real_batch_size():
batch_size = 20
real_batch_size = 15
vocab_size = 103424
window_size = 3000
threshold = 0.9
eos_token_id = vocab_size
max_steps = 10
fixed_token_id = np.random.randint(0, vocab_size)
early_stop_batch_id = np.random.randint(0, real_batch_size)
trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
trigger_step_flags = dict(trigger_step_flags)
cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
stopper_normal = RepetitionEarlyStopper()
stopper_normal.initialize(batch_size, cfg)
stopper_triton = RepetitionEarlyStopper()
stopper_triton.initialize(batch_size, cfg)
next_tokens_normal = paddle.randint(0, vocab_size, shape=[real_batch_size, 1], dtype="int64")
next_tokens_triton = next_tokens_normal.clone()
next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id
stop_flags_normal = paddle.zeros_like(next_tokens_normal)
stop_flags_triton = stop_flags_normal.clone()
triggered_step_normal = [None] * batch_size
triggered_step_triton = [None] * batch_size
for step in range(max_steps):
flags = [trigger_step_flags[i] for i in range(real_batch_size)]
probs_np = simulate_step_probs(real_batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
probs = paddle.to_tensor(probs_np)
stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)
assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"
trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"
out_normal = stop_flags_normal.numpy()
out_triton = stop_flags_triton.numpy()
for i in range(real_batch_size):
if out_normal[i, 0] == eos_token_id and triggered_step_normal[i] is None:
triggered_step_normal[i] = step
if out_triton[i, 0] == eos_token_id and triggered_step_triton[i] is None:
triggered_step_triton[i] = step
for i in range(batch_size):
expected = triggered_step_normal[i]
actual = triggered_step_triton[i]
assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
print("[consistency_with_real_batch_size]Triton vs Normal: All tokens, states, and trigger timings match.")
def test_performance():
@@ -232,4 +294,5 @@ def test_performance():
if __name__ == "__main__":
test_repetition_early_stopper()
test_consistency()
test_consistency_with_real_batch_size()
test_performance()

View File

@@ -57,6 +57,7 @@ def _create_default_sampling_metadata(
bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
min_p=paddle.randn([batch_size]),
seed=paddle.to_tensor([[2025]]),
)
return fake_sampling_metadata