FastDeploy/tests/operators/test_fused_moe.py

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test for moe ops"""
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.incubate.nn.functional import swiglu
from fastdeploy.model_executor.ops.gpu import (
fused_expert_moe,
moe_expert_dispatch,
moe_expert_ffn,
moe_expert_reduce,
)
# Set random seeds for reproducibility
paddle.seed(42)
np.random.seed(42)


class Expert(nn.Layer):
    """A single expert layer using SwiGLU activation."""

    def __init__(self, d_model, d_feedforward):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_feedforward * 2)  # *2 for SwiGLU
        self.fc2 = nn.Linear(d_feedforward, d_model)

    def forward(self, x):
        """Forward pass."""
        x = self.fc1(x)
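        # swiglu splits the fc1 output in half along the last dimension and gates
        # one half with silu of the other, hence the 2 * d_feedforward width of fc1.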
        x = swiglu(x)
        return self.fc2(x)


class TestFusedMoeConsistency(unittest.TestCase):
    """Test case for verifying consistency between baseline and fused MoE implementations."""

    @classmethod
    def setUpClass(cls):
        """Class-level setup that runs once before all tests."""
        cls.set_config()
        paddle.set_default_dtype(cls.dtype)

    @classmethod
    def set_config(cls):
        """Set the configuration parameters for the test."""
        cls.dtype = "bfloat16"
        cls.batch_size = 8
        cls.seq_len = 128
        cls.num_experts = 16
        cls.d_model = 8192
        cls.d_feedforward = 128
        cls.top_k = 4
        cls.rtol = 1e-2
        cls.atol = 1e-2

    def setUp(self):
        """Test-level setup that runs before each test."""
        self.init_experts()
        self.prepare_data()

    def init_experts(self):
        """Initialize expert layers and gate weights."""
        self.experts = nn.LayerList([Expert(self.d_model, self.d_feedforward) for _ in range(self.num_experts)])
        # Initialize gate weights
        self.gate = nn.Linear(self.d_model, self.num_experts)
        self.gate_weight = self.gate.weight.cast("float32")

    def prepare_data(self):
        """Prepare input data and expert parameters."""
        # Input tensor
        self.x = paddle.randn([self.batch_size, self.seq_len, self.d_model], dtype=self.dtype)
        # Stack expert parameters for fused operations
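        # Paddle's nn.Linear stores weights as [in_features, out_features], so the
        # stacked tensors are w0: [num_experts, d_model, 2 * d_feedforward] and
        # w1: [num_experts, d_feedforward, d_model]; biases are reshaped to
        # [num_experts, 1, out_features].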
        self.w0 = paddle.stack([e.fc1.weight for e in self.experts]).astype(self.dtype)
        self.b0 = (
            paddle.stack([e.fc1.bias for e in self.experts]).reshape([self.num_experts, 1, -1]).astype(self.dtype)
        )
        self.w1 = paddle.stack([e.fc2.weight for e in self.experts]).astype(self.dtype)
        self.b1 = (
            paddle.stack([e.fc2.bias for e in self.experts]).reshape([self.num_experts, 1, -1]).astype(self.dtype)
        )

    def baseline_forward(self, hidden_states):
        """Baseline implementation processing experts sequentially."""
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape([-1, hidden_dim])
        # Routing computation
        logits = paddle.matmul(hidden_states.cast("float32"), self.gate_weight)
        weights = F.softmax(logits, axis=-1)
        routing_weights, selected_experts = paddle.topk(weights, self.top_k, axis=-1)
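        # Note: the top-k weights are not renormalized here, which matches the
        # renormalization/norm_topk_prob flags being False in the fused and split paths.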
        # Initialize output
        final_hidden_states = paddle.zeros_like(hidden_states)
        expert_mask = paddle.transpose(F.one_hot(selected_experts, num_classes=self.num_experts), [2, 1, 0])
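        # expert_mask has shape [num_experts, top_k, num_tokens]; entry [e, k, t] is 1
        # when token t selected expert e in its k-th top-k slot.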
        # Process each expert
        for expert_id in range(self.num_experts):
            idx, top_x = paddle.where(expert_mask[expert_id])
            if top_x.size == 0:  # Skip if no tokens for this expert
                continue
            current_state = paddle.index_select(hidden_states, top_x, axis=0)
            expert_out = self.experts[expert_id](current_state)
            current_hidden_states = expert_out * routing_weights[top_x, idx].reshape([-1, 1])
            paddle.index_add_(
                x=final_hidden_states,
                index=top_x.squeeze(),
                axis=0,
                value=current_hidden_states.to(hidden_states.dtype),
            )
        return final_hidden_states.reshape([batch_size, seq_len, hidden_dim])

    def fused_forward(self, x):
        """Fused MoE implementation using a single kernel."""
        return fused_expert_moe(
            x,
            self.gate_weight,
            self.w0,
            self.w1,
            self.b0,
            None,  # No bias for second part of SwiGLU
            self.b1,
            None,  # No activation for second linear
            "None",  # No activation type
            self.top_k,
            False,  # Not renormalizing topk
            False,  # Not using expert capacity
        )

    def split_forward(self, hidden_states):
        """Split implementation using separate dispatch/ffn/reduce ops."""
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape([-1, hidden_dim])
        # Routing computation
        logits = paddle.matmul(hidden_states.cast("float32"), self.gate_weight)
        scores = F.softmax(logits, axis=-1)
        # Dispatch tokens to experts
        (
            permute_input,
            tokens_expert_prefix_sum,
            permute_indices_per_token,
            top_k_weights,
            top_k_indices,
            expert_idx_per_token,
        ) = moe_expert_dispatch(hidden_states, scores, None, None, self.top_k, False, topk_only_mode=True)
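        # As the names suggest, permute_input holds the token activations grouped by
        # expert, and tokens_expert_prefix_sum gives each expert's segment boundary;
        # both are consumed by moe_expert_ffn below.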
        # Process through experts
        ffn_out = moe_expert_ffn(
            permute_input,
            tokens_expert_prefix_sum,
            self.w0,
            self.w1,
            self.b0,
            None,
            None,
            None,
            None,
            "none",
            False,
        )
        # Combine results
        output = moe_expert_reduce(
            ffn_out,
            top_k_weights,
            permute_indices_per_token,
            top_k_indices,
            None,
            norm_topk_prob=False,
            routed_scaling_factor=1.0,
        )
        return output.reshape([batch_size, seq_len, hidden_dim])

    def test_consistency(self):
        """Test consistency between all three implementations."""
        # Compute outputs
        base_out = self.baseline_forward(self.x)
        fused_out = self.fused_forward(self.x)
        split_out = self.split_forward(self.x)
        # Convert to float32 for comparison
        base_out = base_out.cast("float32").numpy()
        fused_out = fused_out.cast("float32").numpy()
        split_out = split_out.cast("float32").numpy()
        # Compare baseline vs fused
        np.testing.assert_allclose(
            base_out,
            fused_out,
            rtol=self.rtol,
            atol=self.atol,
            err_msg="Baseline and fused outputs differ",
        )
        # Compare baseline vs split
        np.testing.assert_allclose(
            base_out,
            split_out,
            rtol=self.rtol,
            atol=self.atol,
            err_msg="Baseline and split outputs differ",
        )


if __name__ == "__main__":
    unittest.main()