mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
fix op tests (#3398)
This commit is contained in:
@@ -35,10 +35,11 @@ class TestRejectionTopPSampling(unittest.TestCase):
|
||||
def test_top_p_sampling_reject_case1(self):
|
||||
"""Test with fixed top_p=0.8 and different random seeds"""
|
||||
top_p_paddle = paddle.full((self.batch_size,), 0.8)
|
||||
top_k_paddle = paddle.full((self.batch_size,), 20).cast("int64")
|
||||
|
||||
# Test with different seeds
|
||||
for seed in [1024, 2033, 2033]:
|
||||
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, seed)
|
||||
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, top_k_paddle, seed)
|
||||
self._validate_samples(samples)
|
||||
|
||||
# Basic validation
|
||||
@@ -48,13 +49,12 @@ class TestRejectionTopPSampling(unittest.TestCase):
|
||||
def test_top_p_sampling_reject_case2(self):
|
||||
"""Test with varying top_p values across batch"""
|
||||
top_p_paddle = paddle.uniform(shape=[self.batch_size], min=0.1, max=1.0)
|
||||
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, -1)
|
||||
|
||||
top_k_paddle = paddle.full((self.batch_size,), 20).cast("int64")
|
||||
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, top_k_paddle, -1)
|
||||
self._validate_samples(samples)
|
||||
|
||||
# Additional check that we're getting different results for different top_p
|
||||
unique_samples = len(paddle.unique(samples))
|
||||
print(f"Unique samples: {unique_samples}")
|
||||
self.assertGreater(unique_samples, 1) # Should have some diversity
|
||||
|
||||
def _validate_samples(self, samples):
|
||||
|
Reference in New Issue
Block a user