mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[FixBug] compute early stopping with real batch size (#3418)
* [FixBug] compute early stopping with real batch size * update * fix test_sampler
This commit is contained in:
@@ -67,16 +67,17 @@ class RepetitionEarlyStopper(EarlyStopper):
|
||||
def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
|
||||
# Get the probability score corresponding to next_tokens in this step
|
||||
next_scores = paddle.index_sample(probs, next_tokens)
|
||||
real_bsz = probs.shape[0]
|
||||
|
||||
# Sliding window: Move left one grid and insert new score
|
||||
self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
|
||||
self.trunc_scores[:, -1:] = next_scores
|
||||
self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:]
|
||||
self.trunc_scores[:real_bsz, -1:] = next_scores
|
||||
|
||||
# Determine which samples need to be terminated: all trunc_scores are greater than threshold
|
||||
need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
|
||||
|
||||
# Add the stop flags
|
||||
stop_flags[need_trunc_all] = True
|
||||
stop_flags[need_trunc_all[:real_bsz]] = True
|
||||
|
||||
# Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
|
||||
reset_mask = need_trunc_all.tile([1, self.window_size])
|
||||
|
@@ -26,7 +26,6 @@ done
|
||||
failed_tests_file="failed_tests.log"
|
||||
> "$failed_tests_file"
|
||||
disabled_tests=(
|
||||
layers/test_sampler.py
|
||||
layers/test_append_attention.py
|
||||
layers/test_attention.py
|
||||
operators/test_rejection_top_p_sampling.py
|
||||
@@ -36,7 +35,6 @@ disabled_tests=(
|
||||
operators/test_stop_generation.py
|
||||
operators/test_air_topp_sampling.py
|
||||
operators/test_fused_moe.py
|
||||
layers/test_repetition_early_stopper.py
|
||||
operators/test_stop_generation_multi_ends.py
|
||||
graph_optimization/test_cuda_graph.py
|
||||
)
|
||||
|
@@ -170,7 +170,69 @@ def test_consistency():
|
||||
actual = triggered_step_triton[i]
|
||||
assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
|
||||
|
||||
print("Triton vs Normal: All tokens, states, and trigger timings match.")
|
||||
print("[consistency]Triton vs Normal: All tokens, states, and trigger timings match.")
|
||||
|
||||
|
||||
def test_consistency_with_real_batch_size():
|
||||
batch_size = 20
|
||||
real_batch_size = 15
|
||||
vocab_size = 103424
|
||||
window_size = 3000
|
||||
threshold = 0.9
|
||||
eos_token_id = vocab_size
|
||||
max_steps = 10
|
||||
|
||||
fixed_token_id = np.random.randint(0, vocab_size)
|
||||
early_stop_batch_id = np.random.randint(0, real_batch_size)
|
||||
|
||||
trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
|
||||
trigger_step_flags = dict(trigger_step_flags)
|
||||
cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
|
||||
stopper_normal = RepetitionEarlyStopper()
|
||||
stopper_normal.initialize(batch_size, cfg)
|
||||
stopper_triton = RepetitionEarlyStopper()
|
||||
stopper_triton.initialize(batch_size, cfg)
|
||||
|
||||
next_tokens_normal = paddle.randint(0, vocab_size, shape=[real_batch_size, 1], dtype="int64")
|
||||
next_tokens_triton = next_tokens_normal.clone()
|
||||
|
||||
next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
|
||||
next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id
|
||||
|
||||
stop_flags_normal = paddle.zeros_like(next_tokens_normal)
|
||||
stop_flags_triton = stop_flags_normal.clone()
|
||||
|
||||
triggered_step_normal = [None] * batch_size
|
||||
triggered_step_triton = [None] * batch_size
|
||||
|
||||
for step in range(max_steps):
|
||||
|
||||
flags = [trigger_step_flags[i] for i in range(real_batch_size)]
|
||||
probs_np = simulate_step_probs(real_batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
|
||||
probs = paddle.to_tensor(probs_np)
|
||||
|
||||
stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
|
||||
stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)
|
||||
|
||||
assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"
|
||||
|
||||
trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
|
||||
assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"
|
||||
|
||||
out_normal = stop_flags_normal.numpy()
|
||||
out_triton = stop_flags_triton.numpy()
|
||||
for i in range(real_batch_size):
|
||||
if out_normal[i, 0] == eos_token_id and triggered_step_normal[i] is None:
|
||||
triggered_step_normal[i] = step
|
||||
if out_triton[i, 0] == eos_token_id and triggered_step_triton[i] is None:
|
||||
triggered_step_triton[i] = step
|
||||
|
||||
for i in range(batch_size):
|
||||
expected = triggered_step_normal[i]
|
||||
actual = triggered_step_triton[i]
|
||||
assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
|
||||
|
||||
print("[consistency_with_real_batch_size]Triton vs Normal: All tokens, states, and trigger timings match.")
|
||||
|
||||
|
||||
def test_performance():
|
||||
@@ -232,4 +294,5 @@ def test_performance():
|
||||
if __name__ == "__main__":
|
||||
test_repetition_early_stopper()
|
||||
test_consistency()
|
||||
test_consistency_with_real_batch_size()
|
||||
test_performance()
|
||||
|
@@ -57,6 +57,7 @@ def _create_default_sampling_metadata(
|
||||
bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
|
||||
eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
|
||||
min_p=paddle.randn([batch_size]),
|
||||
seed=paddle.to_tensor([[2025]]),
|
||||
)
|
||||
return fake_sampling_metadata
|
||||
|
||||
|
Reference in New Issue
Block a user