[FixBug] compute early stopping with real batch size (#3418)

* [FixBug] compute early stopping with real batch size

* update

* fix test_sampler
This commit is contained in:
Zero Rains
2025-08-19 13:09:21 +08:00
committed by GitHub
parent 3a7a20d191
commit 8b12c80f90
4 changed files with 69 additions and 6 deletions

View File

@@ -67,16 +67,17 @@ class RepetitionEarlyStopper(EarlyStopper):
def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
# Get the probability score corresponding to next_tokens in this step
next_scores = paddle.index_sample(probs, next_tokens)
real_bsz = probs.shape[0]
# Sliding window: Move left one grid and insert new score
self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
self.trunc_scores[:, -1:] = next_scores
self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:]
self.trunc_scores[:real_bsz, -1:] = next_scores
# Determine which samples need to be terminated: all trunc_scores are greater than threshold
need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
# Add the stop flags
stop_flags[need_trunc_all] = True
stop_flags[need_trunc_all[:real_bsz]] = True
# Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
reset_mask = need_trunc_all.tile([1, self.window_size])