support w4afp8 mtp (#5429)

This commit is contained in:
Sunny-bot1
2025-12-08 20:24:00 +08:00
committed by GitHub
parent 438c9f785a
commit 364197c4b5
2 changed files with 8 additions and 4 deletions

View File

@@ -930,11 +930,13 @@ class CutlassW4AFP8MoEMethod(CutlassMoEMethod):
Paddle cutlass load weight process. Paddle cutlass load weight process.
""" """
if not layer.is_quantized: if not layer.is_quantized:
prefix_layer_name = layer.fd_config.model_config.prefix_layer_name
logger.info( logger.info(
f"Rotating ernie.layers.{layer.layer_idx}.mlp.experts.[{layer.ep_rank * layer.num_local_experts},{layer.ep_rank * layer.num_local_experts + layer.num_local_experts}).down_proj.weight..." f"Rotating ernie.{prefix_layer_name}.{layer.layer_idx}.mlp.experts.[{layer.ep_rank * layer.num_local_experts},{layer.ep_rank * layer.num_local_experts + layer.num_local_experts}).down_proj.weight..."
) )
rotate_model( rotate_model(
state_dict, state_dict,
prefix_layer_name,
layer.layer_idx, layer.layer_idx,
layer.num_local_experts, layer.num_local_experts,
layer.hidden_size, layer.hidden_size,

View File

@@ -141,7 +141,9 @@ def get_orthogonal_matrix(size, mode="hadamard", device="cuda"):
raise ValueError(f"Unknown mode {mode}") raise ValueError(f"Unknown mode {mode}")
def rotate_model(state_dict, layer_idx, moe_num_experts=48, hidden_size=7168, moe_intermediate_size=3584, ep_rank=0): def rotate_model(
state_dict, prefix_layer_name, layer_idx, moe_num_experts, hidden_size, moe_intermediate_size, ep_rank=0
):
with paddle.no_grad(): with paddle.no_grad():
# collect hadamard rotation matrix [moe_intermediate_size, moe_intermediate_size] # collect hadamard rotation matrix [moe_intermediate_size, moe_intermediate_size]
Q_ffn2, moe_block_size = get_orthogonal_matrix(size=moe_intermediate_size, mode="hadamard_ffn2") Q_ffn2, moe_block_size = get_orthogonal_matrix(size=moe_intermediate_size, mode="hadamard_ffn2")
@@ -149,7 +151,7 @@ def rotate_model(state_dict, layer_idx, moe_num_experts=48, hidden_size=7168, mo
expert_list = [ expert_list = [
get_tensor( get_tensor(
state_dict[ state_dict[
f"ernie.layers.{layer_idx}.mlp.experts.{ep_rank * moe_num_experts + expert_idx}.down_proj.weight" f"ernie.{prefix_layer_name}.{layer_idx}.mlp.experts.{ep_rank * moe_num_experts + expert_idx}.down_proj.weight"
] ]
) )
for expert_idx in range(moe_num_experts) for expert_idx in range(moe_num_experts)
@@ -159,7 +161,7 @@ def rotate_model(state_dict, layer_idx, moe_num_experts=48, hidden_size=7168, mo
for expert_idx in range(moe_num_experts): for expert_idx in range(moe_num_experts):
rotated_weight = new_moe_weight[:, expert_idx * hidden_size : (expert_idx + 1) * hidden_size] rotated_weight = new_moe_weight[:, expert_idx * hidden_size : (expert_idx + 1) * hidden_size]
expert_idx_local = ep_rank * moe_num_experts + expert_idx expert_idx_local = ep_rank * moe_num_experts + expert_idx
state_dict[f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx_local}.down_proj.weight"] = ( state_dict[f"ernie.{prefix_layer_name}.{layer_idx}.mlp.experts.{expert_idx_local}.down_proj.weight"] = (
rotated_weight.cpu() rotated_weight.cpu()
) )
del moe_weight, new_moe_weight, rotated_weight del moe_weight, new_moe_weight, rotated_weight