mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
add real gate_correction_bias weight to mock un-balanced dispatch (#4676)
This commit is contained in:
@@ -25,6 +25,396 @@ from fastdeploy.worker.worker_process import init_distributed_environment
|
||||
|
||||
paddle.set_default_dtype("bfloat16")
|
||||
|
||||
gate_correction_bias_real_data = paddle.to_tensor(
|
||||
[
|
||||
32.8339,
|
||||
32.8231,
|
||||
32.8151,
|
||||
32.8131,
|
||||
32.8317,
|
||||
32.8343,
|
||||
32.8356,
|
||||
32.8270,
|
||||
32.8344,
|
||||
32.8342,
|
||||
32.8126,
|
||||
32.8299,
|
||||
32.8282,
|
||||
32.8254,
|
||||
32.8320,
|
||||
32.8280,
|
||||
32.8303,
|
||||
32.8351,
|
||||
32.8364,
|
||||
32.8347,
|
||||
32.8179,
|
||||
32.8349,
|
||||
32.8322,
|
||||
32.8323,
|
||||
32.8360,
|
||||
32.8351,
|
||||
32.8059,
|
||||
32.8352,
|
||||
32.8303,
|
||||
32.8334,
|
||||
32.8283,
|
||||
32.8265,
|
||||
32.8344,
|
||||
32.8307,
|
||||
32.8271,
|
||||
32.8343,
|
||||
32.8326,
|
||||
32.8327,
|
||||
32.8349,
|
||||
32.8356,
|
||||
32.8303,
|
||||
32.8327,
|
||||
32.8310,
|
||||
32.8363,
|
||||
32.8274,
|
||||
32.8335,
|
||||
32.8350,
|
||||
32.8255,
|
||||
32.8298,
|
||||
32.8141,
|
||||
32.8218,
|
||||
32.8362,
|
||||
32.8126,
|
||||
32.7902,
|
||||
32.8314,
|
||||
32.8356,
|
||||
32.8177,
|
||||
32.8333,
|
||||
32.8352,
|
||||
32.8354,
|
||||
32.8334,
|
||||
32.8325,
|
||||
32.7971,
|
||||
32.8319,
|
||||
32.8222,
|
||||
32.8284,
|
||||
32.8288,
|
||||
32.8355,
|
||||
32.8351,
|
||||
32.8356,
|
||||
32.8338,
|
||||
32.8346,
|
||||
32.7737,
|
||||
32.8317,
|
||||
32.8357,
|
||||
32.8345,
|
||||
32.8347,
|
||||
32.8360,
|
||||
32.8289,
|
||||
32.8268,
|
||||
32.8164,
|
||||
32.8324,
|
||||
32.8363,
|
||||
32.8308,
|
||||
32.8352,
|
||||
32.8302,
|
||||
32.8345,
|
||||
32.8298,
|
||||
32.8057,
|
||||
32.8229,
|
||||
32.8355,
|
||||
32.8325,
|
||||
32.8350,
|
||||
32.8357,
|
||||
32.8315,
|
||||
32.8327,
|
||||
32.8263,
|
||||
32.8342,
|
||||
32.8165,
|
||||
32.8349,
|
||||
32.8310,
|
||||
32.8101,
|
||||
32.8101,
|
||||
32.8081,
|
||||
32.8341,
|
||||
32.8313,
|
||||
32.8331,
|
||||
32.8299,
|
||||
32.8320,
|
||||
32.7941,
|
||||
32.8277,
|
||||
32.8287,
|
||||
32.8326,
|
||||
32.8331,
|
||||
32.8360,
|
||||
32.8295,
|
||||
32.8255,
|
||||
32.8330,
|
||||
32.8279,
|
||||
32.8210,
|
||||
32.7921,
|
||||
32.8348,
|
||||
32.8271,
|
||||
32.8297,
|
||||
32.8211,
|
||||
32.8353,
|
||||
32.8339,
|
||||
32.8335,
|
||||
32.8275,
|
||||
32.8245,
|
||||
32.8287,
|
||||
32.8352,
|
||||
32.8318,
|
||||
32.8354,
|
||||
32.8110,
|
||||
32.8347,
|
||||
32.8340,
|
||||
32.8322,
|
||||
32.8341,
|
||||
32.8316,
|
||||
32.8328,
|
||||
32.8341,
|
||||
32.8354,
|
||||
32.8264,
|
||||
32.8362,
|
||||
32.8352,
|
||||
32.8293,
|
||||
32.8292,
|
||||
32.8328,
|
||||
32.8316,
|
||||
32.8329,
|
||||
32.8308,
|
||||
32.8307,
|
||||
32.8170,
|
||||
32.8345,
|
||||
32.8356,
|
||||
32.8176,
|
||||
32.8326,
|
||||
32.8288,
|
||||
32.8355,
|
||||
32.8346,
|
||||
32.8337,
|
||||
32.8049,
|
||||
32.8315,
|
||||
32.8337,
|
||||
32.8352,
|
||||
32.7991,
|
||||
32.8304,
|
||||
32.8348,
|
||||
32.8316,
|
||||
32.8358,
|
||||
32.8279,
|
||||
32.8348,
|
||||
32.8326,
|
||||
32.8215,
|
||||
32.8281,
|
||||
32.8344,
|
||||
32.8309,
|
||||
32.8355,
|
||||
32.8337,
|
||||
32.8276,
|
||||
32.8250,
|
||||
32.8340,
|
||||
32.8322,
|
||||
32.8317,
|
||||
32.8274,
|
||||
32.8363,
|
||||
32.8277,
|
||||
32.8345,
|
||||
32.8342,
|
||||
32.8343,
|
||||
32.8355,
|
||||
32.8326,
|
||||
32.8299,
|
||||
32.8322,
|
||||
32.8351,
|
||||
32.8356,
|
||||
32.7925,
|
||||
32.8362,
|
||||
32.8170,
|
||||
32.8323,
|
||||
32.8335,
|
||||
32.8339,
|
||||
32.8193,
|
||||
32.8340,
|
||||
32.8362,
|
||||
32.8323,
|
||||
32.8328,
|
||||
32.8328,
|
||||
32.8296,
|
||||
32.8297,
|
||||
32.8344,
|
||||
32.8254,
|
||||
32.8341,
|
||||
32.8345,
|
||||
32.7967,
|
||||
32.8228,
|
||||
32.8363,
|
||||
32.8356,
|
||||
32.8317,
|
||||
32.8362,
|
||||
32.8302,
|
||||
32.8356,
|
||||
32.8239,
|
||||
32.8304,
|
||||
32.8323,
|
||||
32.8335,
|
||||
32.8196,
|
||||
32.8354,
|
||||
32.6991,
|
||||
32.8350,
|
||||
32.8337,
|
||||
32.8314,
|
||||
32.8274,
|
||||
32.8232,
|
||||
32.8305,
|
||||
32.8349,
|
||||
32.8246,
|
||||
32.8343,
|
||||
32.8339,
|
||||
32.7849,
|
||||
32.8359,
|
||||
32.8353,
|
||||
32.8352,
|
||||
32.8348,
|
||||
32.8095,
|
||||
32.8301,
|
||||
32.8350,
|
||||
32.8340,
|
||||
32.8353,
|
||||
32.8343,
|
||||
32.8344,
|
||||
32.8312,
|
||||
32.8350,
|
||||
32.8327,
|
||||
32.8231,
|
||||
32.8325,
|
||||
32.8352,
|
||||
32.8352,
|
||||
32.8293,
|
||||
32.8357,
|
||||
32.8337,
|
||||
32.8335,
|
||||
32.8348,
|
||||
32.8321,
|
||||
32.8153,
|
||||
32.8352,
|
||||
32.8265,
|
||||
32.8326,
|
||||
32.8361,
|
||||
32.8357,
|
||||
32.8312,
|
||||
32.8347,
|
||||
32.8152,
|
||||
32.8340,
|
||||
32.8272,
|
||||
32.8352,
|
||||
32.8331,
|
||||
32.8324,
|
||||
32.7952,
|
||||
32.8170,
|
||||
32.8356,
|
||||
32.8360,
|
||||
32.8298,
|
||||
32.8356,
|
||||
32.8331,
|
||||
32.8317,
|
||||
32.8349,
|
||||
32.8269,
|
||||
32.8323,
|
||||
32.8354,
|
||||
32.8350,
|
||||
32.8226,
|
||||
32.8002,
|
||||
32.8205,
|
||||
32.8329,
|
||||
32.8319,
|
||||
32.8297,
|
||||
32.8282,
|
||||
32.8356,
|
||||
32.8303,
|
||||
32.8349,
|
||||
32.8337,
|
||||
32.8247,
|
||||
32.8279,
|
||||
32.8309,
|
||||
32.8225,
|
||||
32.8337,
|
||||
32.8356,
|
||||
32.8105,
|
||||
32.8353,
|
||||
32.8361,
|
||||
32.8297,
|
||||
32.8313,
|
||||
32.8313,
|
||||
32.8363,
|
||||
32.8357,
|
||||
32.8357,
|
||||
32.8363,
|
||||
32.7806,
|
||||
32.8306,
|
||||
32.8347,
|
||||
32.8248,
|
||||
32.8334,
|
||||
32.8356,
|
||||
32.8324,
|
||||
32.8327,
|
||||
32.8284,
|
||||
32.8351,
|
||||
32.8349,
|
||||
32.8351,
|
||||
32.8171,
|
||||
32.8317,
|
||||
32.8363,
|
||||
32.8346,
|
||||
32.8335,
|
||||
32.8307,
|
||||
32.7907,
|
||||
32.8229,
|
||||
32.8346,
|
||||
32.8298,
|
||||
32.8336,
|
||||
32.8313,
|
||||
32.8349,
|
||||
32.8219,
|
||||
32.8354,
|
||||
32.8337,
|
||||
32.8294,
|
||||
32.8306,
|
||||
32.8322,
|
||||
32.8290,
|
||||
32.8333,
|
||||
32.8327,
|
||||
32.8279,
|
||||
32.8283,
|
||||
32.8338,
|
||||
32.8310,
|
||||
32.8351,
|
||||
32.8171,
|
||||
32.8310,
|
||||
32.8323,
|
||||
32.8324,
|
||||
32.8215,
|
||||
32.8314,
|
||||
32.8333,
|
||||
32.8353,
|
||||
32.8184,
|
||||
32.8344,
|
||||
32.8280,
|
||||
32.8352,
|
||||
32.8361,
|
||||
32.8308,
|
||||
32.8271,
|
||||
32.8335,
|
||||
32.8236,
|
||||
32.8350,
|
||||
32.8325,
|
||||
32.8330,
|
||||
32.8228,
|
||||
32.8352,
|
||||
32.8258,
|
||||
32.8343,
|
||||
32.8338,
|
||||
32.8292,
|
||||
],
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
|
||||
class FuseMoEWrapper(paddle.nn.Layer):
|
||||
def __init__(
|
||||
@@ -90,6 +480,7 @@ class FuseMoEWrapper(paddle.nn.Layer):
|
||||
topk_group=4,
|
||||
n_group=8,
|
||||
gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32),
|
||||
# gate_correction_bias = gate_correction_bias_real_data
|
||||
)
|
||||
moe_layer = self.fused_moe
|
||||
|
||||
@@ -179,14 +570,16 @@ class TestFusedMoE(unittest.TestCase):
|
||||
|
||||
nnodes = (ep_size + 7) // 8
|
||||
|
||||
fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes)
|
||||
|
||||
# 这行代码必须保留,否则影响均匀性!
|
||||
paddle.seed(ep_rank + 100)
|
||||
|
||||
fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes)
|
||||
|
||||
moe_cuda_graphs = [None] * 100
|
||||
cache_hidden_states = [None] * 100
|
||||
for idx, num_tokens in enumerate([10, 20, 40, 60, 80, 100, 128, 160, 192, 256]):
|
||||
test_token_nums = [10, 20, 40, 60, 80, 100, 128, 160, 192, 256]
|
||||
# test_token_nums = [1024 * i for i in [1,2,4,8,16,32]]
|
||||
for idx, num_tokens in enumerate(test_token_nums):
|
||||
|
||||
cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user