mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
This commit is contained in:
@@ -228,19 +228,53 @@ class FusedMoE(nn.Layer):
|
||||
if is_ffn_merged:
|
||||
for i in range(self.num_local_experts):
|
||||
expert_idx = self.expert_id_offset + i
|
||||
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
|
||||
up_gate_proj_expert_weight_key_name = up_gate_proj_expert_weight_key.format(expert_idx)
|
||||
up_gate_proj_weights.append(
|
||||
get_tensor(state_dict.pop(up_gate_proj_expert_weight_key.format(expert_idx)))
|
||||
get_tensor(
|
||||
state_dict.pop(up_gate_proj_expert_weight_key_name)
|
||||
if up_gate_proj_expert_weight_key_name in state_dict
|
||||
else up_gate_proj_expert_weight_key_name,
|
||||
self.fd_config.parallel_config.model_name_or_path,
|
||||
)
|
||||
)
|
||||
down_proj_weights.append(
|
||||
get_tensor(
|
||||
state_dict.pop(down_proj_expert_weight_key_name)
|
||||
if down_proj_expert_weight_key_name in state_dict
|
||||
else down_proj_expert_weight_key_name,
|
||||
self.fd_config.parallel_config.model_name_or_path,
|
||||
)
|
||||
)
|
||||
down_proj_weights.append(get_tensor(state_dict.pop(down_proj_expert_weight_key.format(expert_idx))))
|
||||
else:
|
||||
gate_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "gate_proj")
|
||||
up_expert_weight_key = up_gate_proj_expert_weight_key.replace("up_gate_proj", "up_proj")
|
||||
for j in range(self.num_local_experts):
|
||||
expert_idx = self.expert_id_offset + j
|
||||
gate = get_tensor(state_dict.pop(gate_expert_weight_key.format(expert_idx)))
|
||||
up = get_tensor(state_dict.pop(up_expert_weight_key.format(expert_idx)))
|
||||
gate_expert_weight_key_name = gate_expert_weight_key.format(expert_idx)
|
||||
up_expert_weight_key_name = up_expert_weight_key.format(expert_idx)
|
||||
down_proj_expert_weight_key_name = down_proj_expert_weight_key.format(expert_idx)
|
||||
gate = get_tensor(
|
||||
state_dict.pop(gate_expert_weight_key_name)
|
||||
if gate_expert_weight_key_name in state_dict
|
||||
else gate_expert_weight_key_name,
|
||||
self.fd_config.parallel_config.model_name_or_path,
|
||||
)
|
||||
up = get_tensor(
|
||||
state_dict.pop(up_expert_weight_key_name)
|
||||
if up_expert_weight_key_name in state_dict
|
||||
else up_expert_weight_key_name,
|
||||
self.fd_config.parallel_config.model_name_or_path,
|
||||
)
|
||||
up_gate_proj_weights.append(paddle.concat([gate, up], axis=-1))
|
||||
down_proj_weights.append(get_tensor(state_dict.pop(down_proj_expert_weight_key.format(expert_idx))))
|
||||
down_proj_weights.append(
|
||||
get_tensor(
|
||||
state_dict.pop(down_proj_expert_weight_key_name)
|
||||
if down_proj_expert_weight_key_name in state_dict
|
||||
else down_proj_expert_weight_key_name,
|
||||
self.fd_config.parallel_config.model_name_or_path,
|
||||
)
|
||||
)
|
||||
return up_gate_proj_weights, down_proj_weights
|
||||
|
||||
def extract_moe_ffn_weights(self, state_dict: dict):
|
||||
|
Reference in New Issue
Block a user