commit 20839abccf (parent 91dc87f1c5)
Author: bukejiyu
Date:   2025-08-06 14:45:27 +08:00
Committed by: GitHub
30 changed files with 1361 additions and 1087 deletions

@@ -117,13 +117,12 @@ class DeepSeekV3MoE(nn.Layer):
         self.tp_size = fd_config.parallel_config.tensor_parallel_size
 
         weight_key_map = {
-            "gate_weight_key": f"{prefix}.gate.weight",
             "gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
             "up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
             "down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
         }
 
-        self.fused_moe = FusedMoE(
+        self.experts = FusedMoE(
             fd_config=fd_config,
             reduce_results=False,
             moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
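The rename from self.fused_moe to self.experts lines the attribute name up with the checkpoint layout, where expert weights live under {prefix}.experts.{}; gate_weight_key drops out of the map because the router weight is now loaded through the explicit gate layer added in the next hunk. As a rough illustration of how the remaining per-expert templates are consumed, the sketch below expands the {} placeholder per expert index; the helper is hypothetical, not FastDeploy's actual loader:

    # Hypothetical sketch: expand the "{}" slot in each per-expert key
    # template with a concrete expert index at load time.
    def expand_expert_keys(weight_key_map: dict, num_experts: int) -> list:
        keys = []
        for expert_id in range(num_experts):
            for name, template in weight_key_map.items():
                if "expert" in name:  # skip non-expert keys like the bias
                    keys.append(template.format(expert_id))
        return keys

    # e.g. "...experts.{}.down_proj.weight" becomes
    # "...experts.0.down_proj.weight", "...experts.1.down_proj.weight", ...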
@@ -137,6 +136,16 @@ class DeepSeekV3MoE(nn.Layer):
             weight_key_map=weight_key_map,
         )
 
+        self.gate = ReplicatedLinear(
+            fd_config=fd_config,
+            prefix=f"{prefix}.gate",
+            input_size=fd_config.model_config.hidden_size,
+            output_size=fd_config.model_config.n_routed_experts,
+            with_bias=False,
+            skip_quant=True,
+            weight_dtype="float32",
+        )
+
         self.num_shared_experts = fd_config.model_config.n_shared_experts
         shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
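Pulling the router out of FusedMoE into a ReplicatedLinear makes two properties explicit: skip_quant=True exempts the tiny routing projection from whatever quantization the rest of the model uses, and weight_dtype="float32" keeps the routing logits in full precision. A replicated layer also holds the complete [hidden_size, n_routed_experts] weight on every tensor-parallel rank, since each rank needs all routing scores. A minimal Paddle stand-in under those assumptions (SimpleReplicatedGate is illustrative, not the real ReplicatedLinear):

    import paddle
    import paddle.nn as nn

    class SimpleReplicatedGate(nn.Layer):
        # Illustrative stand-in for ReplicatedLinear: the full router weight
        # is kept on every tensor-parallel rank (no sharding) and in float32.
        def __init__(self, hidden_size: int, n_routed_experts: int):
            super().__init__()
            self.weight = self.create_parameter(
                shape=[hidden_size, n_routed_experts],
                dtype="float32",
            )

        def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
            # Cast activations up so the routing logits are computed in
            # float32 regardless of the model's compute dtype.
            return paddle.matmul(hidden_states.cast("float32"), self.weight)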
@@ -149,13 +158,14 @@ class DeepSeekV3MoE(nn.Layer):
 
     def load_state_dict(self, state_dict):
         """ """
-        self.fused_moe.load_state_dict(state_dict)
+        self.gate.load_state_dict(state_dict)
+        self.experts.load_state_dict(state_dict)
         self.shared_experts.load_state_dict(state_dict)
 
     def forward(self, hidden_states: paddle.Tensor):
         """ """
         shared_experts_out = self.shared_experts(hidden_states)
-        moe_out = self.fused_moe(hidden_states)
+        moe_out = self.experts(hidden_states, self.gate)
         moe_out = moe_out + shared_experts_out
         # We do the TP all reduce after the sum of the experts.
         if self.tp_size > 1:
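With the gate passed in as an argument, FusedMoE computes top-k routing from externally produced logits, and the shared-expert and routed-expert outputs are summed before a single tensor-parallel all-reduce rather than reduced separately (hence reduce_results=False on the FusedMoE). A sketch of the resulting data flow, using paddle.distributed.all_reduce as a stand-in for FastDeploy's TP all-reduce helper:

    import paddle
    import paddle.distributed as dist

    def moe_forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        # Shared experts always run; routed experts are selected per token
        # using logits from the explicit float32 gate.
        shared_experts_out = self.shared_experts(hidden_states)
        moe_out = self.experts(hidden_states, self.gate)
        moe_out = moe_out + shared_experts_out
        # Summing first means one collective covers both partial results.
        if self.tp_size > 1:
            dist.all_reduce(moe_out)  # in-place sum across TP ranks
        return moe_out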