# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import rejection_top_p_sampling
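

# Note: rejection_top_p_sampling is expected to draw one token id per batch
# row from the top-p "nucleus" of that row's distribution. Rejection-style
# samplers typically avoid a full-vocabulary sort by drawing from the
# unrestricted distribution and re-drawing whenever the candidate falls
# outside the current nucleus estimate (assumed behavior of this kernel,
# not stated in this file).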

class TestRejectionTopPSampling(unittest.TestCase):
    def setUp(self):
        """Initialize common test data."""
        self.batch_size = 10
        self.vocab_size = 103424
        paddle.seed(2023)

        # Generate test data once for all tests.
        self.pre_norm_prob_np = np.random.rand(self.batch_size, self.vocab_size).astype(np.float32)
        self.paddle_pre_norm_prob = paddle.to_tensor(self.pre_norm_prob_np)
        # Normalize each row into a proper probability distribution.
        self.paddle_norm_prob = self.paddle_pre_norm_prob / self.paddle_pre_norm_prob.sum(axis=-1, keepdim=True)
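
    def _naive_top_p_sample(self, probs_row, top_p, rng):
        """Sort-based top-p reference (an illustrative sketch added here, not
        part of the original test): keep the smallest high-probability prefix
        whose mass reaches top_p, renormalize it, and sample from that
        nucleus. The helper name and the numpy Generator argument (e.g.
        rng = np.random.default_rng(0)) are assumptions.
        """
        order = np.argsort(-probs_row)  # token ids, most probable first
        cumulative = np.cumsum(probs_row[order])
        # Smallest prefix whose cumulative mass reaches top_p (>= 1 token).
        cutoff = int(np.searchsorted(cumulative, top_p)) + 1
        kept = order[:cutoff]
        kept_probs = probs_row[kept].astype(np.float64)
        kept_probs /= kept_probs.sum()  # renormalize the nucleus
        return int(rng.choice(kept, p=kept_probs))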

    def test_top_p_sampling_reject_case1(self):
        """Test with fixed top_p=0.8 and different random seeds."""
        top_p_paddle = paddle.full((self.batch_size,), 0.8)

        # Each distinct seed should still produce valid token ids.
        for seed in [1024, 2033, 2048]:
            samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, seed)
            self._validate_samples(samples)

    def test_top_p_sampling_reject_case2(self):
        """Test with varying top_p values across the batch."""
        top_p_paddle = paddle.uniform(shape=[self.batch_size], min=0.1, max=1.0)
        # A seed of -1 presumably requests non-deterministic sampling.
        samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, -1)

        self._validate_samples(samples)

        # Differing top_p values should yield some diversity across the batch.
        unique_samples = len(paddle.unique(samples))
        print(f"Unique samples: {unique_samples}")
        self.assertGreater(unique_samples, 1)

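    def test_samples_stay_inside_nucleus(self):
        """Sketch of an extra property check (added for illustration, not in
        the original test): a rejection-style top-p sampler should never
        return a token outside the top-p nucleus. Assumes the kernel gives
        that guarantee; the test name and the seed 42 are made up here.
        """
        top_p = 0.8
        top_p_paddle = paddle.full((self.batch_size,), top_p)
        samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, 42)

        probs = self.paddle_norm_prob.numpy()
        for row, token in enumerate(samples.numpy().reshape(-1)):
            order = np.argsort(-probs[row])
            cumulative = np.cumsum(probs[row][order])
            cutoff = int(np.searchsorted(cumulative, top_p)) + 1
            self.assertIn(int(token), set(order[:cutoff].tolist()))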

    def _validate_samples(self, samples):
        """Common validation for all test cases."""
        self.assertTrue(paddle.all(samples >= 0))
        self.assertTrue(paddle.all(samples < self.vocab_size))

        # Check dtype
        self.assertEqual(samples.dtype, paddle.int64)


if __name__ == "__main__":
    unittest.main()
