mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-30 19:36:42 +08:00
support W4A8 EPLB (#3075)
This commit is contained in:
@@ -276,7 +276,7 @@ class GCUWeightOnlyMoEMethod(GCUFusedMoeMethod):
|
|||||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
|
|||||||
@@ -343,10 +343,12 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
|||||||
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
|
up_gate_proj_expert_in_scale_key = layer.weight_key_map.get("up_gate_proj_expert_in_scale_key", None)
|
||||||
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
|
down_proj_expert_in_scale_key = layer.weight_key_map.get("down_proj_expert_in_scale_key", None)
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
|
||||||
state_dict,
|
layer.load_experts_weight(
|
||||||
up_gate_proj_expert_weight_key,
|
state_dict,
|
||||||
down_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
|
down_proj_expert_weight_key,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
up_gate_proj_weight_scale = []
|
up_gate_proj_weight_scale = []
|
||||||
@@ -356,7 +358,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
|
|||||||
down_proj_in_scale = []
|
down_proj_in_scale = []
|
||||||
|
|
||||||
if layer.ep_size > 1:
|
if layer.ep_size > 1:
|
||||||
for expert_idx in range(layer.num_experts):
|
for expert_idx in ep_rank_to_expert_id_list:
|
||||||
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
|
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
|
||||||
up_gate_proj_in_scale_all_experts.append(scale_tensor)
|
up_gate_proj_in_scale_all_experts.append(scale_tensor)
|
||||||
|
|
||||||
@@ -507,7 +509,7 @@ class CutlassWeightOnlyMoEMethod(CutlassMoEMethod):
|
|||||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
|||||||
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
up_gate_proj_expert_weight_scale_key = layer.weight_key_map.get("up_gate_proj_expert_weight_scale_key", None)
|
||||||
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
down_proj_expert_weight_scale_key = layer.weight_key_map.get("down_proj_expert_weight_scale_key", None)
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids = layer.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = layer.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class CutlassWint2FusedMoeMethod(Wint2MoeMethod):
|
|||||||
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
|
up_gate_proj_expert_code_zp_key = layer.weight_key_map.get("up_gate_proj_expert_code_zp_key", None)
|
||||||
down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
|
down_proj_expert_code_zp_key = layer.weight_key_map.get("down_proj_expert_code_zp_key", None)
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights, _ = layer.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, _, _ = layer.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
|
|||||||
@@ -238,6 +238,7 @@ class FusedMoE(nn.Layer):
|
|||||||
self.expert_id_offset + self.num_local_experts,
|
self.expert_id_offset + self.num_local_experts,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
ep_rank_to_expert_id_list = [i for i in range(self.num_experts)]
|
||||||
if self.redundant_table_manger is not None:
|
if self.redundant_table_manger is not None:
|
||||||
(
|
(
|
||||||
ep_rank_to_expert_id_list,
|
ep_rank_to_expert_id_list,
|
||||||
@@ -309,7 +310,7 @@ class FusedMoE(nn.Layer):
|
|||||||
self.fd_config.model_config.model,
|
self.fd_config.model_config.model,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return up_gate_proj_weights, down_proj_weights, logical_expert_ids
|
return up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list
|
||||||
|
|
||||||
def extract_moe_ffn_weights(self, state_dict: dict):
|
def extract_moe_ffn_weights(self, state_dict: dict):
|
||||||
"""
|
"""
|
||||||
@@ -332,7 +333,7 @@ class FusedMoE(nn.Layer):
|
|||||||
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
|
assert up_gate_proj_expert_weight_key is not None, "up_gate_proj_expert_weight_key should not be none."
|
||||||
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
|
assert down_proj_expert_weight_key is not None, "down_proj_expert_weight_key should not be none."
|
||||||
|
|
||||||
up_gate_proj_weights, down_proj_weights, logical_expert_ids = self.load_experts_weight(
|
up_gate_proj_weights, down_proj_weights, logical_expert_ids, _ = self.load_experts_weight(
|
||||||
state_dict,
|
state_dict,
|
||||||
up_gate_proj_expert_weight_key,
|
up_gate_proj_expert_weight_key,
|
||||||
down_proj_expert_weight_key,
|
down_proj_expert_weight_key,
|
||||||
|
|||||||
Reference in New Issue
Block a user