support W4A8 EPLB (#3075)

This commit is contained in:
Yuan Xiaolan
2025-07-30 14:34:12 +08:00
committed by GitHub
parent 159767717d
commit 35935da9e5
5 changed files with 14 additions and 11 deletions

View File

@@ -276,7 +276,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None) up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None) down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight( up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
state_dict, state_dict,
up_gate_proj_expert_weight_key, up_gate_proj_expert_weight_key,
down_proj_expert_weight_key, down_proj_expert_weight_key,

View File

@@ -343,10 +343,12 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None) up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None) down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight( up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
state_dict, layer.load_experts_weight(
up_gate_proj_expert_weight_key, state_dict,
down_proj_expert_weight_key, up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
)
) )
up_gate_proj_weight_scale = [] up_gate_proj_weight_scale = []
@@ -356,7 +358,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
down_proj_in_scale = [] down_proj_in_scale = []
if layer.ep_size > 1: if layer.ep_size > 1:
for expert_idx in range(layer.num_experts): for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]) scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
up_gate_proj_in_scale_all_experts.append(scale_tensor) up_gate_proj_in_scale_all_experts.append(scale_tensor)
@@ -507,7 +509,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None) up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None) down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight( up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
state_dict, state_dict,
up_gate_proj_expert_weight_key, up_gate_proj_expert_weight_key,
down_proj_expert_weight_key, down_proj_expert_weight_key,

View File

@@ -71,7 +71,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None) up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None) down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight( up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
state_dict, state_dict,
up_gate_proj_expert_weight_key, up_gate_proj_expert_weight_key,
down_proj_expert_weight_key, down_proj_expert_weight_key,

View File

@@ -88,7 +88,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None) up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None) down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight( up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
state_dict, state_dict,
up_gate_proj_expert_weight_key, up_gate_proj_expert_weight_key,
down_proj_expert_weight_key, down_proj_expert_weight_key,

View File

@@ -238,6 +238,7 @@ class FusedMoE(nn.Layer):
self.expert_id_offset + self.num_local_experts, self.expert_id_offset + self.num_local_experts,
) )
] ]
ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
if self.redundant_table_manger is not None: if self.redundant_table_manger is not None:
( (
ep_rank_to_expert_id_list, ep_rank_to_expert_id_list,
@@ -309,7 +310,7 @@ class FusedMoE(nn.Layer):
self.fd_config.model_config.model, self.fd_config.model_config.model,
) )
) )
return up_gate_proj_weights, down_proj_weights, logical_expert_ids return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list
def extract_moe_ffn_weights(self, state_dict: dict): def extract_moe_ffn_weights(self, state_dict: dict):
""" """
@@ -332,7 +333,7 @@ class FusedMoE(nn.Layer):
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none." assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none." assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight( up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = self.load_experts_weight(
state_dict, state_dict,
up_gate_proj_expert_weight_key, up_gate_proj_expert_weight_key,
down_proj_expert_weight_key, down_proj_expert_weight_key,