mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
polish code with new pre-commit rule (#2923)
This commit is contained in:
@@ -13,17 +13,20 @@
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
|
||||
|
||||
|
||||
class TestRejectionTopPSampling(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Initialize common test data"""
|
||||
self.batch_size = 10
|
||||
self.vocab_size = 103424
|
||||
paddle.seed(2023)
|
||||
|
||||
|
||||
# Generate test data once for all tests
|
||||
self.pre_norm_prob_np = np.random.rand(self.batch_size, self.vocab_size).astype(np.float32)
|
||||
self.paddle_pre_norm_prob = paddle.to_tensor(self.pre_norm_prob_np)
|
||||
@@ -32,12 +35,12 @@ class TestRejectionTopPSampling(unittest.TestCase):
|
||||
def test_top_p_sampling_reject_case1(self):
|
||||
"""Test with fixed top_p=0.8 and different random seeds"""
|
||||
top_p_paddle = paddle.full((self.batch_size,), 0.8)
|
||||
|
||||
|
||||
# Test with different seeds
|
||||
for seed in [1024, 2033, 2033]:
|
||||
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, seed)
|
||||
self._validate_samples(samples)
|
||||
|
||||
|
||||
# Basic validation
|
||||
self.assertTrue(paddle.all(samples >= 0))
|
||||
self.assertTrue(paddle.all(samples < self.vocab_size))
|
||||
@@ -46,9 +49,9 @@ class TestRejectionTopPSampling(unittest.TestCase):
|
||||
"""Test with varying top_p values across batch"""
|
||||
top_p_paddle = paddle.uniform(shape=[self.batch_size], min=0.1, max=1.0)
|
||||
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, -1)
|
||||
|
||||
|
||||
self._validate_samples(samples)
|
||||
|
||||
|
||||
# Additional check that we're getting different results for different top_p
|
||||
unique_samples = len(paddle.unique(samples))
|
||||
print(f"Unique samples: {unique_samples}")
|
||||
@@ -58,9 +61,10 @@ class TestRejectionTopPSampling(unittest.TestCase):
|
||||
"""Common validation for all test cases"""
|
||||
self.assertTrue(paddle.all(samples >= 0))
|
||||
self.assertTrue(paddle.all(samples < self.vocab_size))
|
||||
|
||||
|
||||
# Check dtype
|
||||
self.assertEqual(samples.dtype, paddle.int64)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user