support trainer_degree in name_mapping (#2935)

This commit is contained in:
gaoziyuan
2025-07-21 14:12:55 +08:00
committed by GitHub
parent bce2c6cd7c
commit 95a214ae43
2 changed files with 13 additions and 19 deletions

View File

@@ -63,9 +63,9 @@ class RolloutModel(nn.Layer):
model.eval() model.eval()
return model return model
def get_name_mappings_to_training(self) -> Dict[str, str]: def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Get parameter name mappings between rollout and training models.""" """Get parameter name mappings between rollout and training models."""
return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})() return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree)
def get_quantization_infer_keys(self) -> Dict[str, str]: def get_quantization_infer_keys(self) -> Dict[str, str]:
"""Get parameter name mappings between rollout and training models.""" """Get parameter name mappings between rollout and training models."""
@@ -108,9 +108,6 @@ class BaseRLModel(nn.Layer):
# Skip weight scale parameters in mapping. Train and infer have same key. # Skip weight scale parameters in mapping. Train and infer have same key.
self.infer_to_train_mapping[key] = key self.infer_to_train_mapping[key] = key
if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
self.infer_to_train_mapping.pop("lm_head.linear.weight")
def get_quantization_infer_keys(self) -> list[str]: def get_quantization_infer_keys(self) -> list[str]:
"""Get quantization infer keys""" """Get quantization infer keys"""
quant_weight_key = [] quant_weight_key = []
@@ -143,7 +140,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
"""name""" """name"""
return "Ernie4_5_MoeForCausalLMRL" return "Ernie4_5_MoeForCausalLMRL"
def get_name_mappings_to_training(self) -> Dict[str, str]: def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!).""" """Generate mapping between inference and training parameter for RL(donot delete!)."""
# Prepare placeholders # Prepare placeholders
place_holders = ["weight"] place_holders = ["weight"]
@@ -187,8 +184,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
assert isinstance(self.fd_config.model_config.moe_layer_start_index, int) assert isinstance(self.fd_config.model_config.moe_layer_start_index, int)
# Process MoE layers # Process MoE layers
for layer_idx in range( for layer_idx in range(
self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.num_hidden_layers
self.fd_config.model_config.num_hidden_layers,
): ):
_add_layer_mappings(layer_idx) _add_layer_mappings(layer_idx)
@@ -216,7 +212,7 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
"""name""" """name"""
return "Ernie4_5_VLMoeForConditionalGenerationRL" return "Ernie4_5_VLMoeForConditionalGenerationRL"
def get_name_mappings_to_training(self) -> Dict[str, str]: def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!).""" """Generate mapping between inference and training parameter for RL(donot delete!)."""
# Prepare placeholders # Prepare placeholders
place_holders = ["weight"] place_holders = ["weight"]
@@ -249,10 +245,7 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
expert_mappings = defaultdict(list) expert_mappings = defaultdict(list)
for expert_idx in _generate_ranges( for expert_idx in _generate_ranges(
expert_start, expert_start, total_moe_num, expert_num_per_rank * 2, expert_num_per_rank
total_moe_num,
expert_num_per_rank * 2,
expert_num_per_rank,
): ):
for ph in place_holders: for ph in place_holders:
expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"].append( expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"].append(
@@ -284,9 +277,9 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
assert isinstance(self.fd_config.model_config.moe_num_experts, list) assert isinstance(self.fd_config.model_config.moe_num_experts, list)
total_moe_num = sum(self.fd_config.model_config.moe_num_experts) total_moe_num = sum(self.fd_config.model_config.moe_num_experts)
rollout_model_degree = self.fd_config.parallel_config.tensor_parallel_size if not trainer_degree:
expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // rollout_model_degree trainer_degree = self.fd_config.parallel_config.tensor_parallel_size
expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree
# Process MoE layers # Process MoE layers
for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index): for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index):
_add_expert_mappings(layer_idx, "text", expert_start=0) _add_expert_mappings(layer_idx, "text", expert_start=0)
@@ -317,7 +310,7 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM, BaseRLModel):
"""name""" """name"""
return "Qwen2ForCausalLMRL" return "Qwen2ForCausalLMRL"
def get_name_mappings_to_training(self) -> Dict[str, str]: def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!).""" """Generate mapping between inference and training parameter for RL(donot delete!)."""
# Prepare placeholders # Prepare placeholders
place_holders = ["weight"] place_holders = ["weight"]
@@ -361,7 +354,7 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
"""name""" """name"""
return "Qwen3MoeForCausalLMRL" return "Qwen3MoeForCausalLMRL"
def get_name_mappings_to_training(self) -> Dict[str, str]: def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!).""" """Generate mapping between inference and training parameter for RL(donot delete!)."""
# Prepare placeholders # Prepare placeholders
place_holders = ["weight"] place_holders = ["weight"]
@@ -431,5 +424,5 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel):
"""name""" """name"""
return "Qwen3ForCausalLMRL" return "Qwen3ForCausalLMRL"
def get_name_mappings_to_training(self) -> Dict[str, str]: def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
pass pass

View File

@@ -1008,6 +1008,7 @@ ernie.layers.27.post_attention_layernorm.weight
ernie.norm.weight ernie.norm.weight
lm_head.linear.weight lm_head.linear.weight
ernie.embed_tokens.embeddings.weight:ernie.embed_tokens.weight ernie.embed_tokens.embeddings.weight:ernie.embed_tokens.weight
lm_head.linear.weight:lm_head.weight
ernie.layers.1.mlp.text_fused_moe.gate_weight:ernie.layers.1.mlp.gate.weight ernie.layers.1.mlp.text_fused_moe.gate_weight:ernie.layers.1.mlp.gate.weight
ernie.layers.1.mlp.text_fused_moe.gate_correction_bias:ernie.layers.1.mlp.moe_statics.e_score_correction_bias ernie.layers.1.mlp.text_fused_moe.gate_correction_bias:ernie.layers.1.mlp.moe_statics.e_score_correction_bias
ernie.layers.1.mlp.text_fused_moe.up_gate_proj_weight:['ernie.layers.1.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.95.up_gate_proj.weight'] ernie.layers.1.mlp.text_fused_moe.up_gate_proj_weight:['ernie.layers.1.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.23.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.95.up_gate_proj.weight']