update (#2978)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

This commit is contained in:
bukejiyu
2025-07-24 00:16:42 +08:00
committed by GitHub
parent 85a78d695d
commit bfeb664ab8
3 changed files with 57 additions and 28 deletions

View File

@@ -228,19 +228,53 @@ class FusedMoE(nn.Layer):
if is_ffn_merged:
for i in range(self.num_local_experts):
expert_idx = self.expert_id_offset + i
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
up_gate_proj_weights.append(
get_tensor(state_dict.pop(up_gate_proj_expert_weight_key.format(expert_idx)))
get_tensor(
state_dict.pop(up_gate_proj_expert_weight_key_name)
if up_gate_proj_expert_weight_key_name in state_dict
else up_gate_proj_expert_weight_key_name,
self.fd_config.parallel_config.model_name_or_path,
)
)
down_proj_weights.append(
get_tensor(
state_dict.pop(down_proj_expert_weight_key_name)
if down_proj_expert_weight_key_name in state_dict
else down_proj_expert_weight_key_name,
self.fd_config.parallel_config.model_name_or_path,
)
)
down_proj_weights.append(get_tensor(state_dict.pop(down_proj_expert_weight_key.format(expert_idx))))
else:
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
up_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")
for j in range(self.num_local_experts):
expert_idx = self.expert_id_offset + j
gate = get_tensor(state_dict.pop(gate_expert_weight_key.format(expert_idx)))
up = get_tensor(state_dict.pop(up_expert_weight_key.format(expert_idx)))
gate_expert_weight_key_name = gate_expert_weight_key.format(expert_idx)
up_expert_weight_key_name = up_expert_weight_key.format(expert_idx)
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
gate = get_tensor(
state_dict.pop(gate_expert_weight_key_name)
if gate_expert_weight_key_name in state_dict
else gate_expert_weight_key_name,
self.fd_config.parallel_config.model_name_or_path,
)
up = get_tensor(
state_dict.pop(up_expert_weight_key_name)
if up_expert_weight_key_name in state_dict
else up_expert_weight_key_name,
self.fd_config.parallel_config.model_name_or_path,
)
up_gate_proj_weights.append(paddle.concat([gate, up], axis=-1))
down_proj_weights.append(get_tensor(state_dict.pop(down_proj_expert_weight_key.format(expert_idx))))
down_proj_weights.append(
get_tensor(
state_dict.pop(down_proj_expert_weight_key_name)
if down_proj_expert_weight_key_name in state_dict
else down_proj_expert_weight_key_name,
self.fd_config.parallel_config.model_name_or_path,
)
)
return up_gate_proj_weights, down_proj_weights
def extract_moe_ffn_weights(self, state_dict: dict):