Fix eplb part3 (#3206)

* fix_eplb

* fix eplb part3
Authored by xiaoxiaohehe001 on 2025-08-05 10:58:17 +08:00; committed by GitHub.
parent 869626b0f4
commit 794ab9705f
2 changed files with 47 additions and 9 deletions

View File

@@ -345,9 +345,7 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
up_gate_proj_weights, down_proj_weights, logical_expert_ids, ep_rank_to_expert_id_list = (
layer.load_experts_weight(
state_dict,
up_gate_proj_expert_weight_key,
down_proj_expert_weight_key,
state_dict, up_gate_proj_expert_weight_key, down_proj_expert_weight_key, is_rearrange
)
)
@@ -357,22 +355,62 @@ class CutlassW4A8MoEMethod(CutlassMoEMethod):
up_gate_proj_in_scale = []
down_proj_in_scale = []
if isinstance(state_dict, list):
state_dict = dict(state_dict)
if layer.ep_size > 1:
for expert_idx in ep_rank_to_expert_id_list:
scale_tensor = get_tensor(state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)])
scale_tensor = get_tensor(
get_tensor(
state_dict[up_gate_proj_expert_in_scale_key.format(expert_idx)]
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
up_gate_proj_in_scale_all_experts.append(scale_tensor)
for expert_idx in logical_expert_ids:
up_gate_proj_weight_scale.append(
get_tensor(state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx)))
get_tensor(
(
state_dict.pop(up_gate_proj_expert_weight_scale_key.format(expert_idx))
if up_gate_proj_expert_weight_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_weight_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
down_proj_weight_scale.append(
get_tensor(state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx)))
get_tensor(
(
state_dict.pop(down_proj_expert_weight_scale_key.format(expert_idx))
if down_proj_expert_weight_scale_key.format(expert_idx) in state_dict
else down_proj_expert_weight_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
up_gate_proj_in_scale.append(
get_tensor(state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx)))
get_tensor(
(
state_dict.pop(up_gate_proj_expert_in_scale_key.format(expert_idx))
if up_gate_proj_expert_in_scale_key.format(expert_idx) in state_dict
else up_gate_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
down_proj_in_scale.append(
get_tensor(
(
state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))
if down_proj_expert_in_scale_key.format(expert_idx) in state_dict
else down_proj_expert_in_scale_key.format(expert_idx)
),
layer.fd_config.model_config.model,
)
)
down_proj_in_scale.append(get_tensor(state_dict.pop(down_proj_expert_in_scale_key.format(expert_idx))))
up_gate_proj_weight = paddle.stack(up_gate_proj_weights, axis=0)
down_proj_weight = paddle.stack(down_proj_weights, axis=0)