mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-30 19:36:42 +08:00
support W4A8 EPLB (#3075)
This commit is contained in:
@@ -276,7 +276,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
|
||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||
|
||||
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
|
||||
up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
|
||||
state_dict,
|
||||
up_gate_proj_expert_weight_key,
|
||||
down_proj_expert_weight_key,
|
||||
|
||||
@@ -343,11 +343,13 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
|
||||
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
|
||||
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
|
||||
layer.load_experts_weight(
|
||||
state_dict,
|
||||
up_gate_proj_expert_weight_key,
|
||||
down_proj_expert_weight_key,
|
||||
)
|
||||
)
|
||||
|
||||
up_gate_proj_weight_scale = []
|
||||
down_proj_weight_scale = []
|
||||
@@ -356,7 +358,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
||||
down_proj_in_scale = []
|
||||
|
||||
if layer.ep_size > 1:
|
||||
for expert_idx in range(layer.num_experts):
|
||||
for expert_idx in ep_rank_to_expert_id_list:
|
||||
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
|
||||
up_gate_proj_in_scale_all_experts.append(scale_tensor)
|
||||
|
||||
@@ -507,7 +509,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
|
||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
|
||||
state_dict,
|
||||
up_gate_proj_expert_weight_key,
|
||||
down_proj_expert_weight_key,
|
||||
|
||||
@@ -71,7 +71,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
|
||||
state_dict,
|
||||
up_gate_proj_expert_weight_key,
|
||||
down_proj_expert_weight_key,
|
||||
|
||||
@@ -88,7 +88,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
||||
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
|
||||
down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
|
||||
|
||||
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
|
||||
up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
|
||||
state_dict,
|
||||
up_gate_proj_expert_weight_key,
|
||||
down_proj_expert_weight_key,
|
||||
|
||||
@@ -238,6 +238,7 @@ class FusedMoE(nn.Layer):
|
||||
self.expert_id_offset + self.num_local_experts,
|
||||
)
|
||||
]
|
||||
ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
|
||||
if self.redundant_table_manger is not None:
|
||||
(
|
||||
ep_rank_to_expert_id_list,
|
||||
@@ -309,7 +310,7 @@ class FusedMoE(nn.Layer):
|
||||
self.fd_config.model_config.model,
|
||||
)
|
||||
)
|
||||
return up_gate_proj_weights, down_proj_weights, logical_expert_ids
|
||||
return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list
|
||||
|
||||
def extract_moe_ffn_weights(self, state_dict: dict):
|
||||
"""
|
||||
@@ -332,7 +333,7 @@ class FusedMoE(nn.Layer):
|
||||
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
|
||||
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
|
||||
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight(
|
||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = self.load_experts_weight(
|
||||
state_dict,
|
||||
up_gate_proj_expert_weight_key,
|
||||
down_proj_expert_weight_key,
|
||||
|
||||
Reference in New Issue
Block a user