mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
【Inference Optimize】Update wint2 weight n-dim reorder (#3042)
This commit is contained in:
@@ -135,6 +135,17 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
up_gate_proj_code_zp = paddle.stack(up_gate_proj_code_zp, axis=0)
|
||||
down_proj_code_zp = paddle.stack(down_proj_code_zp, axis=0)
|
||||
|
||||
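# Pre-arrange the weight layout (n-dim reorder): each weight of shape [s0, s1, s2] is viewed as
# [s0, s1 // 16, 16, s2 // 8, 8], transposed with perm [0, 3, 1, 4, 2], then reshaped back to [s0, s1, s2].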
|
||||
w1_shape = up_gate_proj_weight.shape
|
||||
up_gate_proj_weight = up_gate_proj_weight.reshape([w1_shape[0], w1_shape[1] // 16, 16, w1_shape[2] // 8, 8])
|
||||
up_gate_proj_weight = paddle.transpose(up_gate_proj_weight, perm=[0, 3, 1, 4, 2])
|
||||
up_gate_proj_weight = up_gate_proj_weight.reshape(w1_shape)
|
||||
|
||||
w2_shape = down_proj_weight.shape
|
||||
down_proj_weight = down_proj_weight.reshape([w2_shape[0], w2_shape[1] // 16, 16, w2_shape[2] // 8, 8])
|
||||
down_proj_weight = paddle.transpose(down_proj_weight, perm=[0, 3, 1, 4, 2])
|
||||
down_proj_weight = down_proj_weight.reshape(w2_shape)
|
||||
|
||||
name_tensor_map = {
|
||||
"up_gate_proj_weight": up_gate_proj_weight,
|
||||
"down_proj_weight": down_proj_weight,
|
||||
|
Reference in New Issue
Block a user