mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
support qwen3moe name_mapping (#2820)
This commit is contained in:
@@ -303,7 +303,103 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
|
||||
|
||||
def get_name_mappings_to_training(self):
|
||||
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
|
||||
pass
|
||||
# Prepare placeholders
|
||||
place_holders = ["weight"]
|
||||
|
||||
# Initialize mapping dictionary
|
||||
infer_to_train = {}
|
||||
|
||||
infer_base_name = "model"
|
||||
train_base_name = "model"
|
||||
# Static mappings (non-layer specific)
|
||||
static_mappings = {
|
||||
f"{infer_base_name}.embeddings.word_embeddings.weight":
|
||||
f"{train_base_name}.embed_tokens.weight",
|
||||
f"{infer_base_name}.norm.ln_weight": f"{train_base_name}.norm.weight",
|
||||
"lm_head.out_linear.weight": "lm_head.weight"
|
||||
}
|
||||
infer_to_train.update(static_mappings)
|
||||
|
||||
infer_base_name = infer_base_name + ".layers"
|
||||
train_base_name = train_base_name + ".layers"
|
||||
|
||||
# Helper function to add layer mappings
|
||||
def _add_layer_mappings(layer_idx, is_moe_layer=False):
|
||||
# Handle special case for layer 0's input layernorm and attn o_proj
|
||||
for ph in place_holders:
|
||||
infer_key = f"{infer_base_name}.{layer_idx}.input_layernorm.ln_{ph}"
|
||||
train_key = f"{train_base_name}.{layer_idx}.input_layernorm.{ph}"
|
||||
infer_to_train[infer_key] = train_key
|
||||
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.o_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.o_proj.{ph}"
|
||||
|
||||
# qwen q_norm/k_norm
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.q_norm.ln_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.q_norm.{ph}"
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.k_norm.ln_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.k_norm.{ph}"
|
||||
|
||||
# qwen qkv proj
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.self_attn.qkv_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.self_attn.qkv_proj.{ph}"
|
||||
|
||||
# Post-attention layernorm
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.post_attention_layernorm.ln_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.post_attention_layernorm.{ph}"
|
||||
|
||||
if not is_moe_layer:
|
||||
# FFN mappings
|
||||
for ph in place_holders:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_up_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
|
||||
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.down_proj.linear_{ph}"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.down_proj.{ph}"
|
||||
else:
|
||||
# MoE specific mappings
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.gate_weight"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.gate.weight"
|
||||
|
||||
if self.fd_config.moe_config.moe_use_aux_free:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.fused_moe.gate_correction_bias"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.moe_statics.e_score_correction_bias"
|
||||
|
||||
# Support shared experts
|
||||
if self.fd_config.model_config.get(
|
||||
"moe_num_shared_experts", 0) > 0:
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.gate_up_proj.linear_weight"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.shared_experts.up_gate_proj.weight"
|
||||
infer_to_train[f"{infer_base_name}.{layer_idx}.mlp.shared_experts.down_proj.linear_weight"] = \
|
||||
f"{train_base_name}.{layer_idx}.mlp.shared_experts.down_proj.weight"
|
||||
|
||||
# MoE experts mappings
|
||||
for expert_idx in range(self.fd_config.moe_config.num_experts):
|
||||
for ph in place_holders:
|
||||
# FFN1 (up_gate_proj)
|
||||
ffn1_key = f"{infer_base_name}.{layer_idx}.mlp.moe_ffn1_weight"
|
||||
if ffn1_key not in infer_to_train:
|
||||
infer_to_train[ffn1_key] = []
|
||||
infer_to_train[ffn1_key].append(
|
||||
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.{ph}"
|
||||
)
|
||||
|
||||
# FFN2 (down_proj)
|
||||
ffn2_key = f"{infer_base_name}.{layer_idx}.mlp.moe_ffn2_weight"
|
||||
if ffn2_key not in infer_to_train:
|
||||
infer_to_train[ffn2_key] = []
|
||||
infer_to_train[ffn2_key].append(
|
||||
f"{train_base_name}.{layer_idx}.mlp.experts.{expert_idx}.down_proj.{ph}"
|
||||
)
|
||||
|
||||
# Process MoE layers
|
||||
for layer_idx in range(self.fd_config.model_config.num_layers):
|
||||
_add_layer_mappings(layer_idx, is_moe_layer=True)
|
||||
|
||||
return infer_to_train
|
||||
|
||||
|
||||
class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
|
||||
|
Reference in New Issue
Block a user