polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -13,17 +13,20 @@
# limitations under the License.
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
class TestRejectionTopPSampling(unittest.TestCase):
def setUp(self):
"""Initialize common test data"""
self.batch_size = 10
self.vocab_size = 103424
paddle.seed(2023)
# Generate test data once for all tests
self.pre_norm_prob_np = np.random.rand(self.batch_size, self.vocab_size).astype(np.float32)
self.paddle_pre_norm_prob = paddle.to_tensor(self.pre_norm_prob_np)
@@ -32,12 +35,12 @@ class TestRejectionTopPSampling(unittest.TestCase):
def test_top_p_sampling_reject_case1(self):
"""Test with fixed top_p=0.8 and different random seeds"""
top_p_paddle = paddle.full((self.batch_size,), 0.8)
# Test with different seeds
for seed in [1024, 2033, 2033]:
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, seed)
self._validate_samples(samples)
# Basic validation
self.assertTrue(paddle.all(samples >= 0))
self.assertTrue(paddle.all(samples < self.vocab_size))
@@ -46,9 +49,9 @@ class TestRejectionTopPSampling(unittest.TestCase):
"""Test with varying top_p values across batch"""
top_p_paddle = paddle.uniform(shape=[self.batch_size], min=0.1, max=1.0)
samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, -1)
self._validate_samples(samples)
# Additional check that we're getting different results for different top_p
unique_samples = len(paddle.unique(samples))
print(f"Unique samples: {unique_samples}")
@@ -58,9 +61,10 @@ class TestRejectionTopPSampling(unittest.TestCase):
"""Common validation for all test cases"""
self.assertTrue(paddle.all(samples >= 0))
self.assertTrue(paddle.all(samples < self.vocab_size))
# Check dtype
self.assertEqual(samples.dtype, paddle.int64)
if __name__ == "__main__":
unittest.main()
unittest.main()