diff --git a/test/operators/test_fused_moe.py b/test/operators/test_fused_moe.py index ce78e05c1..74548e0d9 100644 --- a/test/operators/test_fused_moe.py +++ b/test/operators/test_fused_moe.py @@ -165,7 +165,8 @@ class TestFusedMoeConsistency(unittest.TestCase): permute_indices_per_token, top_k_weights, top_k_indices, - ) = moe_expert_dispatch(hidden_states, scores, None, self.top_k, False, topk_only_mode=True) + expert_idx_per_token, + ) = moe_expert_dispatch(hidden_states, scores, None, None, self.top_k, False, topk_only_mode=True) # Process through experts ffn_out = moe_expert_ffn( diff --git a/test/operators/test_rejection_top_p_sampling.py b/test/operators/test_rejection_top_p_sampling.py index f034763c4..22213dbfb 100644 --- a/test/operators/test_rejection_top_p_sampling.py +++ b/test/operators/test_rejection_top_p_sampling.py @@ -35,10 +35,11 @@ class TestRejectionTopPSampling(unittest.TestCase): def test_top_p_sampling_reject_case1(self): """Test with fixed top_p=0.8 and different random seeds""" top_p_paddle = paddle.full((self.batch_size,), 0.8) + top_k_paddle = paddle.full((self.batch_size,), 20).cast("int64") # Test with different seeds for seed in [1024, 2033, 2033]: - samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, seed) + samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, top_k_paddle, seed) self._validate_samples(samples) # Basic validation @@ -48,13 +49,12 @@ class TestRejectionTopPSampling(unittest.TestCase): def test_top_p_sampling_reject_case2(self): """Test with varying top_p values across batch""" top_p_paddle = paddle.uniform(shape=[self.batch_size], min=0.1, max=1.0) - samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, -1) - + top_k_paddle = paddle.full((self.batch_size,), 20).cast("int64") + samples = rejection_top_p_sampling(self.paddle_norm_prob, top_p_paddle, top_k_paddle, -1) self._validate_samples(samples) # Additional check that we're getting different results for different top_p unique_samples = len(paddle.unique(samples)) - print(f"Unique samples: {unique_samples}") self.assertGreater(unique_samples, 1) # Should have some diversity def _validate_samples(self, samples):