mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
122 lines
4.7 KiB
Python
122 lines
4.7 KiB
Python
"""
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
|
|
import unittest
|
|
|
|
import paddle
|
|
|
|
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
|
|
|
|
|
|
class TestMoeRouting(unittest.TestCase):
|
|
def setUp(self):
|
|
paddle.seed(2024)
|
|
print(paddle.device.cuda.get_device_properties())
|
|
print(paddle.__git_commit__)
|
|
|
|
def native_group_topk(
|
|
self,
|
|
gating_output: paddle.Tensor,
|
|
topk: int,
|
|
renormalize: bool,
|
|
num_expert_group: int,
|
|
topk_group: int,
|
|
routed_scaling_factor: float,
|
|
e_score_correction_bias: paddle.Tensor,
|
|
):
|
|
original_scores = paddle.nn.functional.sigmoid(gating_output)
|
|
if len(e_score_correction_bias.shape) == 1:
|
|
e_score_correction_bias = e_score_correction_bias.unsqueeze(0)
|
|
scores = original_scores + e_score_correction_bias
|
|
|
|
num_token, n_experts = scores.shape
|
|
group_scores = scores.reshape([num_token, num_expert_group, -1]).topk(2, axis=-1)[0].sum(axis=-1)
|
|
group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group]
|
|
group_mask = paddle.zeros_like(group_scores) # [n, n_group]
|
|
group_mask.put_along_axis_(group_idx, 1.0, axis=-1) # [n, n_group]
|
|
score_mask = (
|
|
group_mask.unsqueeze(-1)
|
|
.expand([num_token, num_expert_group, n_experts // num_expert_group])
|
|
.reshape([num_token, -1])
|
|
)
|
|
tmp_scores = scores.masked_fill(~score_mask.astype(paddle.bool), float("-inf"))
|
|
|
|
topk_ids = paddle.topk(tmp_scores, topk, axis=1)[1]
|
|
topk_weights = paddle.take_along_axis(original_scores, topk_ids, axis=1)
|
|
|
|
if renormalize:
|
|
topk_weights = topk_weights / paddle.sum(topk_weights, axis=1, keepdim=True)
|
|
|
|
if routed_scaling_factor != 1.0:
|
|
topk_weights = topk_weights * routed_scaling_factor
|
|
|
|
return topk_weights, topk_ids
|
|
|
|
def test_group_topk(self):
|
|
|
|
renormalize = True
|
|
|
|
test_cases = [
|
|
# (num_experts, n_group, topk_group, top_k, routed_scaling_factor)
|
|
(128, 1, 1, 8, 1.0), # glm45-air
|
|
(256, 8, 4, 8, 2.5), # deepseek
|
|
]
|
|
|
|
for case_tuple in test_cases:
|
|
num_experts, n_group, topk_group, top_k, routed_scaling_factor = case_tuple
|
|
for num_tokens in [1, 32, 64, 128]:
|
|
gating_output = paddle.rand([num_tokens, num_experts])
|
|
e_score_correction_bias = paddle.rand([1, num_experts])
|
|
|
|
ref_topk_values, ref_topk_idx = self.native_group_topk(
|
|
gating_output=gating_output,
|
|
topk=top_k,
|
|
renormalize=renormalize,
|
|
num_expert_group=n_group,
|
|
topk_group=topk_group,
|
|
routed_scaling_factor=routed_scaling_factor,
|
|
e_score_correction_bias=e_score_correction_bias,
|
|
)
|
|
|
|
new_score, topk_values, topk_idx = get_moe_scores(
|
|
gating_output=gating_output,
|
|
n_group=n_group,
|
|
topk_group=topk_group,
|
|
top_k=top_k,
|
|
routed_scaling_factor=routed_scaling_factor,
|
|
e_score_correction_bias=e_score_correction_bias,
|
|
renormalize=renormalize,
|
|
)
|
|
|
|
equal_topk_value = paddle.allclose(topk_values, ref_topk_values, atol=1e-03, rtol=1e-03).item()
|
|
equal_topk_ids = paddle.allclose(
|
|
topk_idx.cast("int32"), ref_topk_idx.cast("int32"), atol=0.0, rtol=0.0
|
|
).item()
|
|
print(
|
|
f"Test Case[{case_tuple}], num_tokens = {num_tokens}, equal_topk_value: {equal_topk_value}, equal_topk_ids: {equal_topk_ids}"
|
|
)
|
|
if not equal_topk_value:
|
|
print(f"ref_topk_values = {ref_topk_values}")
|
|
print(f"topk_values = {topk_values}")
|
|
if not equal_topk_ids:
|
|
print(f"ref_topk_idx = {ref_topk_idx}")
|
|
print(f"topk_idx = {topk_idx}")
|
|
assert equal_topk_value and equal_topk_ids
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|