# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for MoE (mixture-of-experts) ops."""

import unittest

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.incubate.nn.functional import swiglu

from fastdeploy.model_executor.ops.gpu import (
    fused_expert_moe,
    moe_expert_dispatch,
    moe_expert_ffn,
    moe_expert_reduce,
)
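
# NOTE: the ops imported above are FastDeploy's custom GPU operators, so this
# test assumes a CUDA-capable device and a GPU build of FastDeploy.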

# Set random seeds for reproducibility
paddle.seed(42)
np.random.seed(42)


class Expert(nn.Layer):
    """A single expert layer using SwiGLU activation."""

    def __init__(self, d_model, d_feedforward):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_feedforward * 2)  # *2: SwiGLU splits the output into gate and value halves
        self.fc2 = nn.Linear(d_feedforward, d_model)

    def forward(self, x):
        """Project up, apply SwiGLU, then project back down."""
        x = self.fc1(x)
        x = swiglu(x)
        return self.fc2(x)


class TestFusedMoeConsistency(unittest.TestCase):
    """Verify that the fused and split MoE implementations match the per-expert baseline."""

    @classmethod
    def setUpClass(cls):
        """Class-level setup that runs once before all tests."""
        cls.set_config()
        paddle.set_default_dtype(cls.dtype)

    @classmethod
    def set_config(cls):
        """Set the configuration parameters for the test."""
        cls.dtype = "bfloat16"
        cls.batch_size = 8
        cls.seq_len = 128
        cls.num_experts = 16
        cls.d_model = 8192
        cls.d_feedforward = 128
        cls.top_k = 4
        # Relatively loose tolerances to accommodate bfloat16 rounding differences
        cls.rtol = 1e-2
        cls.atol = 1e-2

    def setUp(self):
        """Test-level setup that runs before each test."""
        self.init_experts()
        self.prepare_data()

    def init_experts(self):
        """Initialize expert layers and gate weights."""
        self.experts = nn.LayerList([Expert(self.d_model, self.d_feedforward) for _ in range(self.num_experts)])

        # Initialize gate weights
        self.gate = nn.Linear(self.d_model, self.num_experts)
        self.gate_weight = self.gate.weight.cast("float32")

    def prepare_data(self):
        """Prepare input data and expert parameters."""
        # Input tensor
        self.x = paddle.randn([self.batch_size, self.seq_len, self.d_model], dtype=self.dtype)

        # Stack expert parameters for fused operations
        self.w0 = paddle.stack([e.fc1.weight for e in self.experts]).astype(self.dtype)
        self.b0 = (
            paddle.stack([e.fc1.bias for e in self.experts]).reshape([self.num_experts, 1, -1]).astype(self.dtype)
        )
        self.w1 = paddle.stack([e.fc2.weight for e in self.experts]).astype(self.dtype)
        self.b1 = (
            paddle.stack([e.fc2.bias for e in self.experts]).reshape([self.num_experts, 1, -1]).astype(self.dtype)
        )

    def baseline_forward(self, hidden_states):
        """Baseline implementation processing experts sequentially."""
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape([-1, hidden_dim])

        # Routing computation
        logits = paddle.matmul(hidden_states.cast("float32"), self.gate_weight)
        weights = F.softmax(logits, axis=-1)
        routing_weights, selected_experts = paddle.topk(weights, self.top_k, axis=-1)

        # Initialize output
        final_hidden_states = paddle.zeros_like(hidden_states)
        expert_mask = paddle.transpose(F.one_hot(selected_experts, num_classes=self.num_experts), [2, 1, 0])

        # Process each expert
        for expert_id in range(self.num_experts):
            idx, top_x = paddle.where(expert_mask[expert_id])
            if top_x.size == 0:  # Skip if no tokens for this expert
                continue

            current_state = paddle.index_select(hidden_states, top_x, axis=0)
            expert_out = self.experts[expert_id](current_state)

            current_hidden_states = expert_out * routing_weights[top_x, idx].reshape([-1, 1])
            paddle.index_add_(
                x=final_hidden_states,
                index=top_x.squeeze(),
                axis=0,
                value=current_hidden_states.to(hidden_states.dtype),
            )

        return final_hidden_states.reshape([batch_size, seq_len, hidden_dim])

    def fused_forward(self, x):
        """Fused MoE implementation using a single kernel."""
        return fused_expert_moe(
            x,
            self.gate_weight,
            self.w0,
            self.w1,
            self.b0,
            None,  # No bias for second part of SwiGLU
            self.b1,
            None,  # No activation for second linear
            "None",  # No activation type
            self.top_k,
            False,  # Not renormalizing topk
            False,  # Not using expert capacity
        )

    def split_forward(self, hidden_states):
        """Split implementation using separate dispatch/ffn/reduce ops."""
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape([-1, hidden_dim])

        # Routing computation
        logits = paddle.matmul(hidden_states.cast("float32"), self.gate_weight)
        scores = F.softmax(logits, axis=-1)

        # Dispatch tokens to experts
        (
            permute_input,
            tokens_expert_prefix_sum,
            permute_indices_per_token,
            top_k_weights,
            top_k_indices,
        ) = moe_expert_dispatch(hidden_states, scores, None, self.top_k, False, topk_only_mode=True)

        # Process through experts
        ffn_out = moe_expert_ffn(
            permute_input,
            tokens_expert_prefix_sum,
            self.w0,
            self.w1,
            self.b0,
            None,
            None,
            None,
            None,
            "none",
            False,
        )

        # Combine results
        output = moe_expert_reduce(
            ffn_out,
            top_k_weights,
            permute_indices_per_token,
            top_k_indices,
            None,
            norm_topk_prob=False,
            routed_scaling_factor=1.0,
        )

        return output.reshape([batch_size, seq_len, hidden_dim])

    def test_consistency(self):
        """Test consistency between all three implementations."""
        # Compute outputs
        base_out = self.baseline_forward(self.x)
        fused_out = self.fused_forward(self.x)
        split_out = self.split_forward(self.x)

        # Convert to float32 for comparison
        base_out = base_out.cast("float32").numpy()
        fused_out = fused_out.cast("float32").numpy()
        split_out = split_out.cast("float32").numpy()

        # Compare baseline vs fused
        np.testing.assert_allclose(
            base_out,
            fused_out,
            rtol=self.rtol,
            atol=self.atol,
            err_msg="Baseline and fused outputs differ",
        )

        # Compare baseline vs split
        np.testing.assert_allclose(
            base_out,
            split_out,
            rtol=self.rtol,
            atol=self.atol,
            err_msg="Baseline and split outputs differ",
        )


if __name__ == "__main__":
    unittest.main()