mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
[FixBug] compute early stopping with real batch size (#3418)
* [FixBug] compute early stopping with real batch size
* update
* fix test_sampler
This commit is contained in:
@@ -67,16 +67,17 @@ class RepetitionEarlyStopper(EarlyStopper):
|
||||
def process_normal(self, probs: paddle.Tensor, next_tokens: paddle.Tensor, stop_flags: paddle.Tensor):
|
||||
# Get the probability score corresponding to next_tokens in this step
|
||||
next_scores = paddle.index_sample(probs, next_tokens)
|
||||
real_bsz = probs.shape[0]
|
||||
|
||||
# Sliding window: shift the scores one slot to the left and insert the new score in the last slot
|
||||
self.trunc_scores[:, :-1] = self.trunc_scores[:, 1:]
|
||||
self.trunc_scores[:, -1:] = next_scores
|
||||
self.trunc_scores[:real_bsz, :-1] = self.trunc_scores[:real_bsz, 1:]
|
||||
self.trunc_scores[:real_bsz, -1:] = next_scores
|
||||
|
||||
# Determine which samples need to be terminated: all trunc_scores are greater than threshold
|
||||
need_trunc_all = paddle.all(self.trunc_scores > self.threshold, axis=-1).unsqueeze(-1)
|
||||
|
||||
# Add the stop flags
|
||||
stop_flags[need_trunc_all] = True
|
||||
stop_flags[need_trunc_all[:real_bsz]] = True
|
||||
|
||||
# Reset trunc_scores of truncated samples to 0 to avoid false triggering in the next step
|
||||
reset_mask = need_trunc_all.tile([1, self.window_size])
|
||||
|
Reference in New Issue
Block a user