support W4A8 EPLB (#3075)

This commit is contained in:
Yuan Xiaolan
2025-07-30 14:34:12 +08:00
committed by GitHub
parent 159767717d
commit 35935da9e5
5 changed files with 14 additions and 11 deletions

View File

@@ -276,7 +276,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,

View File

@@ -343,11 +343,13 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
)
)
up_gate_proj_weight_scale = []
down_proj_weight_scale = []
@@ -356,7 +358,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
down_proj_in_scale = []
if layer.ep_size > 1:
for expert_idx in range(layer.num_experts):
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
up_gate_proj_in_scale_all_experts.append(scale_tensor)
@@ -507,7 +509,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,

View File

@@ -71,7 +71,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,

View File

@@ -88,7 +88,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,

View File

@@ -238,6 +238,7 @@ class FusedMoE(nn.Layer):
self.expert_id_offset + self.num_local_experts,
)
]
ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
if self.redundant_table_manger is not None:
(
ep_rank_to_expert_id_list,
@@ -309,7 +310,7 @@ class FusedMoE(nn.Layer):
self.fd_config.model_config.model,
)
)
return up_gate_proj_weights, down_proj_weights, logical_expert_ids
return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list
def extract_moe_ffn_weights(self, state_dict: dict):
"""
@@ -332,7 +333,7 @@ class FusedMoE(nn.Layer):
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight(
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = self.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,