mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[FixBug] compute early stopping with real batch size (#3418)
* [FixBug] compute early stopping with real batch size * update * fix test_sampler
This commit is contained in:
@@ -67,16 +67,17 @@ class RepetitionEarlyStopper(EarlyStopper):
|
|||||||
def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
|
def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
|
||||||
# Get the probability score corresponding to next_tokens in this step
|
# Get the probability score corresponding to next_tokens in this step
|
||||||
next_scores = paddle.index_sample(probs, next_tokens)
|
next_scores = paddle.index_sample(probs, next_tokens)
|
||||||
|
real_bsz = probs.shape[0]
|
||||||
|
|
||||||
# Sliding window: Move left one grid and insert new score
|
# Sliding window: Move left one grid and insert new score
|
||||||
self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
|
self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:]
|
||||||
self.trunc_scores[:, -1:] = next_scores
|
self.trunc_scores[:real_bsz, -1:] = next_scores
|
||||||
|
|
||||||
# Determine which samples need to be terminated: all trunc_scores are greater than threshold
|
# Determine which samples need to be terminated: all trunc_scores are greater than threshold
|
||||||
need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
|
need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
|
||||||
|
|
||||||
# Add the stop flags
|
# Add the stop flags
|
||||||
stop_flags[need_trunc_all] = True
|
stop_flags[need_trunc_all[:real_bsz]] = True
|
||||||
|
|
||||||
# Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
|
# Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
|
||||||
reset_mask = need_trunc_all.tile([1, self.window_size])
|
reset_mask = need_trunc_all.tile([1, self.window_size])
|
||||||
|
@@ -26,7 +26,6 @@ done
|
|||||||
failed_tests_file="failed_tests.log"
|
failed_tests_file="failed_tests.log"
|
||||||
> "$failed_tests_file"
|
> "$failed_tests_file"
|
||||||
disabled_tests=(
|
disabled_tests=(
|
||||||
layers/test_sampler.py
|
|
||||||
layers/test_append_attention.py
|
layers/test_append_attention.py
|
||||||
layers/test_attention.py
|
layers/test_attention.py
|
||||||
operators/test_rejection_top_p_sampling.py
|
operators/test_rejection_top_p_sampling.py
|
||||||
@@ -36,7 +35,6 @@ disabled_tests=(
|
|||||||
operators/test_stop_generation.py
|
operators/test_stop_generation.py
|
||||||
operators/test_air_topp_sampling.py
|
operators/test_air_topp_sampling.py
|
||||||
operators/test_fused_moe.py
|
operators/test_fused_moe.py
|
||||||
layers/test_repetition_early_stopper.py
|
|
||||||
operators/test_stop_generation_multi_ends.py
|
operators/test_stop_generation_multi_ends.py
|
||||||
graph_optimization/test_cuda_graph.py
|
graph_optimization/test_cuda_graph.py
|
||||||
)
|
)
|
||||||
|
@@ -170,7 +170,69 @@ def test_consistency():
|
|||||||
actual = triggered_step_triton[i]
|
actual = triggered_step_triton[i]
|
||||||
assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
|
assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
|
||||||
|
|
||||||
print("Triton vs Normal: All tokens, states, and trigger timings match.")
|
print("[consistency]Triton vs Normal: All tokens, states, and trigger timings match.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_consistency_with_real_batch_size():
|
||||||
|
batch_size = 20
|
||||||
|
real_batch_size = 15
|
||||||
|
vocab_size = 103424
|
||||||
|
window_size = 3000
|
||||||
|
threshold = 0.9
|
||||||
|
eos_token_id = vocab_size
|
||||||
|
max_steps = 10
|
||||||
|
|
||||||
|
fixed_token_id = np.random.randint(0, vocab_size)
|
||||||
|
early_stop_batch_id = np.random.randint(0, real_batch_size)
|
||||||
|
|
||||||
|
trigger_step_flags = [[i, np.random.randint(0, max_steps + 1)] for i in range(batch_size)]
|
||||||
|
trigger_step_flags = dict(trigger_step_flags)
|
||||||
|
cfg = EarlyStopConfig({"enable_early_stop": True, "window_size": window_size, "threshold": threshold})
|
||||||
|
stopper_normal = RepetitionEarlyStopper()
|
||||||
|
stopper_normal.initialize(batch_size, cfg)
|
||||||
|
stopper_triton = RepetitionEarlyStopper()
|
||||||
|
stopper_triton.initialize(batch_size, cfg)
|
||||||
|
|
||||||
|
next_tokens_normal = paddle.randint(0, vocab_size, shape=[real_batch_size, 1], dtype="int64")
|
||||||
|
next_tokens_triton = next_tokens_normal.clone()
|
||||||
|
|
||||||
|
next_tokens_normal[early_stop_batch_id, 0] = fixed_token_id
|
||||||
|
next_tokens_triton[early_stop_batch_id, 0] = fixed_token_id
|
||||||
|
|
||||||
|
stop_flags_normal = paddle.zeros_like(next_tokens_normal)
|
||||||
|
stop_flags_triton = stop_flags_normal.clone()
|
||||||
|
|
||||||
|
triggered_step_normal = [None] * batch_size
|
||||||
|
triggered_step_triton = [None] * batch_size
|
||||||
|
|
||||||
|
for step in range(max_steps):
|
||||||
|
|
||||||
|
flags = [trigger_step_flags[i] for i in range(real_batch_size)]
|
||||||
|
probs_np = simulate_step_probs(real_batch_size, early_stop_batch_id, fixed_token_id, vocab_size, step, flags)
|
||||||
|
probs = paddle.to_tensor(probs_np)
|
||||||
|
|
||||||
|
stopper_normal.process_normal(probs, next_tokens_normal, stop_flags_normal)
|
||||||
|
stopper_triton.process_triton(probs, next_tokens_triton, stop_flags_triton)
|
||||||
|
|
||||||
|
assert np.allclose(stop_flags_normal.numpy(), stop_flags_triton.numpy()), f"stop flags mismatch at step {step}"
|
||||||
|
|
||||||
|
trunc_scores_diff = paddle.abs(stopper_normal.trunc_scores - stopper_triton.trunc_scores)
|
||||||
|
assert paddle.all(trunc_scores_diff < 1e-5), f"trunc_scores mismatch at step {step}"
|
||||||
|
|
||||||
|
out_normal = stop_flags_normal.numpy()
|
||||||
|
out_triton = stop_flags_triton.numpy()
|
||||||
|
for i in range(real_batch_size):
|
||||||
|
if out_normal[i, 0] == eos_token_id and triggered_step_normal[i] is None:
|
||||||
|
triggered_step_normal[i] = step
|
||||||
|
if out_triton[i, 0] == eos_token_id and triggered_step_triton[i] is None:
|
||||||
|
triggered_step_triton[i] = step
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
expected = triggered_step_normal[i]
|
||||||
|
actual = triggered_step_triton[i]
|
||||||
|
assert expected == actual, f"Sample {i} triggered at different steps: {expected} vs {actual}"
|
||||||
|
|
||||||
|
print("[consistency_with_real_batch_size]Triton vs Normal: All tokens, states, and trigger timings match.")
|
||||||
|
|
||||||
|
|
||||||
def test_performance():
|
def test_performance():
|
||||||
@@ -232,4 +294,5 @@ def test_performance():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_repetition_early_stopper()
|
test_repetition_early_stopper()
|
||||||
test_consistency()
|
test_consistency()
|
||||||
|
test_consistency_with_real_batch_size()
|
||||||
test_performance()
|
test_performance()
|
||||||
|
@@ -57,6 +57,7 @@ def _create_default_sampling_metadata(
|
|||||||
bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
|
bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
|
||||||
eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
|
eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
|
||||||
min_p=paddle.randn([batch_size]),
|
min_p=paddle.randn([batch_size]),
|
||||||
|
seed=paddle.to_tensor([[2025]]),
|
||||||
)
|
)
|
||||||
return fake_sampling_metadata
|
return fake_sampling_metadata
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user