fix op tests (#3398)

This commit is contained in:
gaoziyuan
2025-08-14 16:45:25 +08:00
committed by GitHub
parent 2e7831185f
commit 0ea8712018
2 changed files with 6 additions and 5 deletions

View File

@@ -35,10 +35,11 @@ class TestRejectionTopPSampling(unittest.TestCase):
def test_top_p_sampling_reject_case1(self):
"""Test with fixed top_p=0.8 and different random seeds"""
top_p_paddle = paddle.full((self.batch_size,), 0.8)
top_k_paddle = paddle.full((self.batch_size,), 20).cast("int64")
# Test with different seeds
for seed in [1024, 2033, 2033]:
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, seed)
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, top_k_paddle, seed)
self._validate_samples(samples)
# Basic validation
@@ -48,13 +49,12 @@ class TestRejectionTopPSampling(unittest.TestCase):
def test_top_p_sampling_reject_case2(self):
"""Test with varying top_p values across batch"""
top_p_paddle = paddle.uniform(shape=[self.batch_size], min=0.1, max=1.0)
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, -1)
top_k_paddle = paddle.full((self.batch_size,), 20).cast("int64")
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, top_k_paddle, -1)
self._validate_samples(samples)
# Additional check that we're getting different results for different top_p
unique_samples = len(paddle.unique(samples))
print(f"Unique samples: {unique_samples}")
self.assertGreater(unique_samples, 1) # Should have some diversity
def _validate_samples(self, samples):