【Inference Optimize】Update wint2 weight n-dim reorder (#3042)

This commit is contained in:
AIbin
2025-07-28 16:31:56 +08:00
committed by GitHub
parent bddf403576
commit ec52d39e68

View File

@@ -135,6 +135,17 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
up_gate_proj_code_zp = paddle.stack(up_gate_proj_code_zp, axis=0)
down_proj_code_zp = paddle.stack(down_proj_code_zp, axis=0)
# Here we pre-arrange the n-dim weight matrix
w1_shape = up_gate_proj_weight.shape
up_gate_proj_weight = up_gate_proj_weight.reshape([w1_shape[0], w1_shape[1] // 16, 16, w1_shape[2] // 8, 8])
up_gate_proj_weight = paddle.transpose(up_gate_proj_weight, perm=[0, 3, 1, 4, 2])
up_gate_proj_weight = up_gate_proj_weight.reshape(w1_shape)
w2_shape = down_proj_weight.shape
down_proj_weight = down_proj_weight.reshape([w2_shape[0], w2_shape[1] // 16, 16, w2_shape[2] // 8, 8])
down_proj_weight = paddle.transpose(down_proj_weight, perm=[0, 3, 1, 4, 2])
down_proj_weight = down_proj_weight.reshape(w2_shape)
name_tensor_map = {
"up_gate_proj_weight": up_gate_proj_weight,
"down_proj_weight": down_proj_weight,