Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 16:22:57 +08:00)
adapter qwen3 moe attr for init (#3066)
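In short, the diff below switches the MoE-layer construction in Qwen3DecoderLayer and the expert-count handling in Qwen3MoePretrainedModel from the old config attributes moe_num_experts and moe_topk to num_experts and num_experts_per_tok, the attribute names used by the upstream Hugging Face Qwen3-MoE configuration, so a stock Qwen3-MoE config can drive initialization directly.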
@@ -97,29 +97,26 @@ class Qwen3DecoderLayer(nn.Layer):
         prefix: str = "",
     ) -> None:
         super().__init__()
         layer_id = int(prefix.split(sep=".")[-1])
 
         self.self_attn = Qwen3Attention(
             fd_config=fd_config,
             layer_id=layer_id,
             prefix=f"{prefix}.self_attn",
         )
 
         weight_key_map = {
             "gate_weight_key": f"{prefix}.mlp.gate.weight",
             "up_gate_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.up_gate_proj.weight",
             "down_proj_expert_weight_key": f"{prefix}.mlp.experts.{{}}.down_proj.weight",
         }
 
-        if (
-            fd_config.model_config.moe_num_experts is not None
-            and layer_id >= fd_config.model_config.moe_layer_start_index
-        ):
+        if fd_config.model_config.num_experts is not None and layer_id >= fd_config.model_config.moe_layer_start_index:
             self.mlp = FusedMoE(
                 fd_config,
                 moe_intermediate_size=fd_config.model_config.moe_intermediate_size,
-                num_experts=fd_config.model_config.moe_num_experts,
-                top_k=fd_config.model_config.moe_topk,
+                num_experts=fd_config.model_config.num_experts,
+                top_k=fd_config.model_config.num_experts_per_tok,
                 layer_idx=layer_id,
                 weight_key_map=weight_key_map,
             )
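To make the rewritten condition concrete, here is a minimal runnable sketch of the per-layer choice between a fused MoE block and a dense MLP. SimpleConfig and pick_mlp_kind are hypothetical stand-ins, not FastDeploy APIs; only the attribute names num_experts, num_experts_per_tok, and moe_layer_start_index come from the hunk above.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SimpleConfig:
        # Attribute names follow the diff; the default values are illustrative only.
        num_experts: Optional[int] = 128   # was moe_num_experts
        num_experts_per_tok: int = 8       # was moe_topk
        moe_layer_start_index: int = 0     # first decoder layer that uses MoE

    def pick_mlp_kind(config: SimpleConfig, layer_id: int) -> str:
        # Same predicate shape as the rewritten line in Qwen3DecoderLayer.__init__.
        if config.num_experts is not None and layer_id >= config.moe_layer_start_index:
            return "FusedMoE"
        return "DenseMLP"

    print(pick_mlp_kind(SimpleConfig(), layer_id=0))                   # FusedMoE
    print(pick_mlp_kind(SimpleConfig(num_experts=None), layer_id=5))   # DenseMLP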
@@ -386,12 +383,12 @@ class Qwen3MoePretrainedModel(PretrainedModel):
             return final_actions
 
         num_experts = 0
-        if isinstance(config.moe_num_experts, list):
-            num_experts = sum(config.moe_num_experts)
-        elif isinstance(config.moe_num_experts, int):
-            num_experts = config.moe_num_experts
+        if isinstance(config.num_experts, list):
+            num_experts = sum(config.num_experts)
+        elif isinstance(config.num_experts, int):
+            num_experts = config.num_experts
         else:
-            raise ValueError(f"Not support type of num_experts [{type(config.moe_num_experts)}]")
+            raise ValueError(f"Not support type of num_experts [{type(config.num_experts)}]")
 
         mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers, num_experts)
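The second hunk normalizes config.num_experts, which may be a single int or a per-group list of expert counts, into one total before computing the tensor-parallel split mappings. A standalone sketch of that normalization follows; the helper name total_num_experts is hypothetical.

    def total_num_experts(num_experts) -> int:
        # Mirrors the isinstance chain in the hunk above.
        if isinstance(num_experts, list):
            return sum(num_experts)    # e.g. [64, 64] -> 128 experts in total
        if isinstance(num_experts, int):
            return num_experts
        raise ValueError(f"Unsupported type of num_experts [{type(num_experts)}]")

    assert total_num_experts(128) == 128
    assert total_num_experts([64, 64]) == 128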