add real gate_correction_bias weight to mock un-balanced dispatch (#4676)

周周周
2025-10-30 15:13:21 +08:00
committed by GitHub
parent f1de348cbf
commit 8b9c9463cd


@@ -25,6 +25,396 @@ from fastdeploy.worker.worker_process import init_distributed_environment
paddle.set_default_dtype("bfloat16")
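# 384 real per-expert gate_correction_bias values; loading these instead of
# zeros mocks the un-balanced expert dispatch this test targets (see commit title).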
gate_correction_bias_real_data = paddle.to_tensor(
[
32.8339, 32.8231, 32.8151, 32.8131, 32.8317, 32.8343, 32.8356, 32.8270,
32.8344, 32.8342, 32.8126, 32.8299, 32.8282, 32.8254, 32.8320, 32.8280,
32.8303, 32.8351, 32.8364, 32.8347, 32.8179, 32.8349, 32.8322, 32.8323,
32.8360, 32.8351, 32.8059, 32.8352, 32.8303, 32.8334, 32.8283, 32.8265,
32.8344, 32.8307, 32.8271, 32.8343, 32.8326, 32.8327, 32.8349, 32.8356,
32.8303, 32.8327, 32.8310, 32.8363, 32.8274, 32.8335, 32.8350, 32.8255,
32.8298, 32.8141, 32.8218, 32.8362, 32.8126, 32.7902, 32.8314, 32.8356,
32.8177, 32.8333, 32.8352, 32.8354, 32.8334, 32.8325, 32.7971, 32.8319,
32.8222, 32.8284, 32.8288, 32.8355, 32.8351, 32.8356, 32.8338, 32.8346,
32.7737, 32.8317, 32.8357, 32.8345, 32.8347, 32.8360, 32.8289, 32.8268,
32.8164, 32.8324, 32.8363, 32.8308, 32.8352, 32.8302, 32.8345, 32.8298,
32.8057, 32.8229, 32.8355, 32.8325, 32.8350, 32.8357, 32.8315, 32.8327,
32.8263, 32.8342, 32.8165, 32.8349, 32.8310, 32.8101, 32.8101, 32.8081,
32.8341, 32.8313, 32.8331, 32.8299, 32.8320, 32.7941, 32.8277, 32.8287,
32.8326, 32.8331, 32.8360, 32.8295, 32.8255, 32.8330, 32.8279, 32.8210,
32.7921, 32.8348, 32.8271, 32.8297, 32.8211, 32.8353, 32.8339, 32.8335,
32.8275, 32.8245, 32.8287, 32.8352, 32.8318, 32.8354, 32.8110, 32.8347,
32.8340, 32.8322, 32.8341, 32.8316, 32.8328, 32.8341, 32.8354, 32.8264,
32.8362, 32.8352, 32.8293, 32.8292, 32.8328, 32.8316, 32.8329, 32.8308,
32.8307, 32.8170, 32.8345, 32.8356, 32.8176, 32.8326, 32.8288, 32.8355,
32.8346, 32.8337, 32.8049, 32.8315, 32.8337, 32.8352, 32.7991, 32.8304,
32.8348, 32.8316, 32.8358, 32.8279, 32.8348, 32.8326, 32.8215, 32.8281,
32.8344, 32.8309, 32.8355, 32.8337, 32.8276, 32.8250, 32.8340, 32.8322,
32.8317, 32.8274, 32.8363, 32.8277, 32.8345, 32.8342, 32.8343, 32.8355,
32.8326, 32.8299, 32.8322, 32.8351, 32.8356, 32.7925, 32.8362, 32.8170,
32.8323, 32.8335, 32.8339, 32.8193, 32.8340, 32.8362, 32.8323, 32.8328,
32.8328, 32.8296, 32.8297, 32.8344, 32.8254, 32.8341, 32.8345, 32.7967,
32.8228, 32.8363, 32.8356, 32.8317, 32.8362, 32.8302, 32.8356, 32.8239,
32.8304, 32.8323, 32.8335, 32.8196, 32.8354, 32.6991, 32.8350, 32.8337,
32.8314, 32.8274, 32.8232, 32.8305, 32.8349, 32.8246, 32.8343, 32.8339,
32.7849, 32.8359, 32.8353, 32.8352, 32.8348, 32.8095, 32.8301, 32.8350,
32.8340, 32.8353, 32.8343, 32.8344, 32.8312, 32.8350, 32.8327, 32.8231,
32.8325, 32.8352, 32.8352, 32.8293, 32.8357, 32.8337, 32.8335, 32.8348,
32.8321, 32.8153, 32.8352, 32.8265, 32.8326, 32.8361, 32.8357, 32.8312,
32.8347, 32.8152, 32.8340, 32.8272, 32.8352, 32.8331, 32.8324, 32.7952,
32.8170, 32.8356, 32.8360, 32.8298, 32.8356, 32.8331, 32.8317, 32.8349,
32.8269, 32.8323, 32.8354, 32.8350, 32.8226, 32.8002, 32.8205, 32.8329,
32.8319, 32.8297, 32.8282, 32.8356, 32.8303, 32.8349, 32.8337, 32.8247,
32.8279, 32.8309, 32.8225, 32.8337, 32.8356, 32.8105, 32.8353, 32.8361,
32.8297, 32.8313, 32.8313, 32.8363, 32.8357, 32.8357, 32.8363, 32.7806,
32.8306, 32.8347, 32.8248, 32.8334, 32.8356, 32.8324, 32.8327, 32.8284,
32.8351, 32.8349, 32.8351, 32.8171, 32.8317, 32.8363, 32.8346, 32.8335,
32.8307, 32.7907, 32.8229, 32.8346, 32.8298, 32.8336, 32.8313, 32.8349,
32.8219, 32.8354, 32.8337, 32.8294, 32.8306, 32.8322, 32.8290, 32.8333,
32.8327, 32.8279, 32.8283, 32.8338, 32.8310, 32.8351, 32.8171, 32.8310,
32.8323, 32.8324, 32.8215, 32.8314, 32.8333, 32.8353, 32.8184, 32.8344,
32.8280, 32.8352, 32.8361, 32.8308, 32.8271, 32.8335, 32.8236, 32.8350,
32.8325, 32.8330, 32.8228, 32.8352, 32.8258, 32.8343, 32.8338, 32.8292,
],
dtype="float32",
)
class FuseMoEWrapper(paddle.nn.Layer):
def __init__(
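
For context, here is a minimal sketch (not part of the commit) of why a non-uniform correction bias un-balances dispatch, assuming the bias is added to the gate scores before top-k selection; the token count, score function, and variable names below are illustrative assumptions:

import paddle

# Illustration only: adding a per-expert bias to the gate scores before
# top-k tilts routing toward the experts with the larger bias values.
num_tokens, num_experts, top_k = 256, 384, 8
paddle.seed(0)
scores = paddle.nn.functional.sigmoid(paddle.rand([num_tokens, num_experts], dtype="float32"))
_, uniform_ids = paddle.topk(scores, k=top_k, axis=-1)
_, biased_ids = paddle.topk(scores + gate_correction_bias_real_data, k=top_k, axis=-1)
# Per-expert token load: near-uniform without the bias, skewed with it.
uniform_load = paddle.bincount(uniform_ids.flatten(), minlength=num_experts)
biased_load = paddle.bincount(biased_ids.flatten(), minlength=num_experts)
print(int(uniform_load.max()), int(biased_load.max()))
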
@@ -90,6 +480,7 @@ class FuseMoEWrapper(paddle.nn.Layer):
topk_group=4,
n_group=8,
gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32),
# gate_correction_bias=gate_correction_bias_real_data,
)
moe_layer = self.fused_moe
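
The n_group=8 / topk_group=4 arguments above configure group-limited expert selection. A minimal sketch of the assumed semantics (DeepSeek-style group-limited routing; the actual fused kernel differs in details and this helper is hypothetical):

import paddle

def group_limited_topk(scores, top_k=8, n_group=8, topk_group=4):
    # Assumed semantics, for illustration only: experts are split into
    # n_group groups, only the topk_group best-scoring groups remain
    # eligible, then top_k experts are picked from those groups.
    num_tokens, num_experts = scores.shape
    grouped = scores.reshape([num_tokens, n_group, num_experts // n_group])
    group_scores = grouped.max(axis=-1)  # each group scored by its best expert
    _, kept_groups = paddle.topk(group_scores, k=topk_group, axis=-1)
    mask = paddle.put_along_axis(paddle.zeros_like(group_scores), kept_groups, 1.0, axis=-1)
    masked = paddle.where(mask.unsqueeze(-1) > 0, grouped, paddle.full_like(grouped, -1e9))
    return paddle.topk(masked.reshape([num_tokens, num_experts]), k=top_k, axis=-1)
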
@@ -179,14 +570,16 @@ class TestFusedMoE(unittest.TestCase):
nnodes = (ep_size + 7) // 8
# This line must be kept, otherwise it affects the dispatch uniformity!
paddle.seed(ep_rank + 100)
fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes)
moe_cuda_graphs = [None] * 100
cache_hidden_states = [None] * 100
test_token_nums = [10, 20, 40, 60, 80, 100, 128, 160, 192, 256]
# test_token_nums = [1024 * i for i in [1,2,4,8,16,32]]
for idx, num_tokens in enumerate(test_token_nums):
cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16)
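
The rest of the loop body is outside this hunk. Presumably each token count gets its own captured CUDA graph that is later replayed on the cached, fixed-shape hidden states; a hypothetical sketch using Paddle's CUDAGraph utility:

from paddle.device.cuda.graphs import CUDAGraph

# Hypothetical continuation (not shown in the hunk): capture one graph per
# static shape, then replay it instead of re-launching the kernels.
out = fused_moe(cache_hidden_states[idx])  # warm-up run at this shape
graph = CUDAGraph()
graph.capture_begin()
out = fused_moe(cache_hidden_states[idx])  # launches recorded into the graph
graph.capture_end()
moe_cuda_graphs[idx] = graph
graph.replay()  # reuses the captured buffers for cache_hidden_states[idx]
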