mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 08:37:06 +08:00
qwen3_moe (#3084)
This commit is contained in:
@@ -117,13 +117,12 @@ class DeepSeekV3MoE(nn.Layer):
|
||||
self.tp_size = fd_config.parallel_config.tensor_parallel_size
|
||||
|
||||
weight_key_map = {
|
||||
"gate_weight_key": f"{prefix}.gate.weight",
|
||||
"gate_correction_bias_key": f"{prefix}.gate.e_score_correction_bias",
|
||||
"up_gate_proj_expert_weight_key": f"{prefix}.experts.{{}}.up_gate_proj.weight",
|
||||
"down_proj_expert_weight_key": f"{prefix}.experts.{{}}.down_proj.weight",
|
||||
}
|
||||
|
||||
self.fused_moe = FusedMoE(
|
||||
self.experts = FusedMoE(
|
||||
fd_config=fd_config,
|
||||
reduce_results=False,
|
||||
moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
|
||||
@@ -137,6 +136,16 @@ class DeepSeekV3MoE(nn.Layer):
|
||||
weight_key_map=weight_key_map,
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
fd_config=fd_config,
|
||||
prefix=f"{prefix}.gate",
|
||||
input_size=fd_config.model_config.hidden_size,
|
||||
output_size=fd_config.model_config.n_routed_experts,
|
||||
with_bias=False,
|
||||
skip_quant=True,
|
||||
weight_dtype="float32",
|
||||
)
|
||||
|
||||
self.num_shared_experts = fd_config.model_config.n_shared_experts
|
||||
shared_experts_intermediate_size = self.num_shared_experts * fd_config.model_config.moe_intermediate_size
|
||||
|
||||
@@ -149,13 +158,14 @@ class DeepSeekV3MoE(nn.Layer):
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
""" """
|
||||
self.fused_moe.load_state_dict(state_dict)
|
||||
self.gate.load_state_dict(state_dict)
|
||||
self.experts.load_state_dict(state_dict)
|
||||
self.shared_experts.load_state_dict(state_dict)
|
||||
|
||||
def forward(self, hidden_states: paddle.Tensor):
|
||||
""" """
|
||||
shared_experts_out = self.shared_experts(hidden_states)
|
||||
moe_out = self.fused_moe(hidden_states)
|
||||
moe_out = self.experts(hidden_states, self.gate)
|
||||
moe_out = moe_out + shared_experts_out
|
||||
# We do the TP all-reduce after the sum of experts.
|
||||
if self.tp_size > 1:
|
||||
|
Reference in New Issue
Block a user