From 95a214ae439fe0806466db4d3242c0badada132c Mon Sep 17 00:00:00 2001 From: gaoziyuan <88373061+gzy19990617@users.noreply.github.com> Date: Mon, 21 Jul 2025 14:12:55 +0800 Subject: [PATCH] support trainer_degree in name_mapping (#2935) --- fastdeploy/rl/rollout_model.py | 31 +++++++++++------------------ test/ci_use/EB_VL_Lite/baseline.txt | 1 + 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index 41e9589e7..241c76df0 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -63,9 +63,9 @@ class RolloutModel(nn.Layer): model.eval() return model - def get_name_mappings_to_training(self) -> Dict[str, str]: + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Get parameter name mappings between rollout and training models.""" - return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})() + return getattr(self.rollout_model, "get_name_mappings_to_training", lambda: {})(trainer_degree) def get_quantization_infer_keys(self) -> Dict[str, str]: """Get parameter name mappings between rollout and training models.""" @@ -108,9 +108,6 @@ class BaseRLModel(nn.Layer): # Skip weight scale parameters in mapping. Train and infer have same key. self.infer_to_train_mapping[key] = key - if getattr(self.fd_config.model_config, "tie_word_embeddings", False): - self.infer_to_train_mapping.pop("lm_head.linear.weight") - def get_quantization_infer_keys(self) -> list[str]: """Get quantization infer keys""" quant_weight_key = [] @@ -143,7 +140,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel): """name""" return "Ernie4_5_MoeForCausalLMRL" - def get_name_mappings_to_training(self) -> Dict[str, str]: + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" # Prepare placeholders place_holders = ["weight"] @@ -187,8 +184,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel): assert isinstance(self.fd_config.model_config.moe_layer_start_index, int) # Process MoE layers for layer_idx in range( - self.fd_config.model_config.moe_layer_start_index, - self.fd_config.model_config.num_hidden_layers, + self.fd_config.model_config.moe_layer_start_index, self.fd_config.model_config.num_hidden_layers ): _add_layer_mappings(layer_idx) @@ -216,7 +212,7 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener """name""" return "Ernie4_5_VLMoeForConditionalGenerationRL" - def get_name_mappings_to_training(self) -> Dict[str, str]: + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" # Prepare placeholders place_holders = ["weight"] @@ -249,10 +245,7 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener expert_mappings = defaultdict(list) for expert_idx in _generate_ranges( - expert_start, - total_moe_num, - expert_num_per_rank * 2, - expert_num_per_rank, + expert_start, total_moe_num, expert_num_per_rank * 2, expert_num_per_rank ): for ph in place_holders: expert_mappings[f"{base_name}.{layer_idx}.mlp.{moe_tag}_fused_moe.up_gate_proj_weight"].append( @@ -284,9 +277,9 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener assert isinstance(self.fd_config.model_config.moe_num_experts, list) total_moe_num = 
sum(self.fd_config.model_config.moe_num_experts) - rollout_model_degree = self.fd_config.parallel_config.tensor_parallel_size - expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // rollout_model_degree - + if not trainer_degree: + trainer_degree = self.fd_config.parallel_config.tensor_parallel_size + expert_num_per_rank = self.fd_config.model_config.moe_num_experts[0] // trainer_degree # Process MoE layers for layer_idx in range(text_moe_layer_start_index, text_moe_layer_end_index): _add_expert_mappings(layer_idx, "text", expert_start=0) @@ -317,7 +310,7 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM, BaseRLModel): """name""" return "Qwen2ForCausalLMRL" - def get_name_mappings_to_training(self) -> Dict[str, str]: + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" # Prepare placeholders place_holders = ["weight"] @@ -361,7 +354,7 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel): """name""" return "Qwen3MoeForCausalLMRL" - def get_name_mappings_to_training(self) -> Dict[str, str]: + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" # Prepare placeholders place_holders = ["weight"] @@ -431,5 +424,5 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel): """name""" return "Qwen3ForCausalLMRL" - def get_name_mappings_to_training(self) -> Dict[str, str]: + def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: pass diff --git a/test/ci_use/EB_VL_Lite/baseline.txt b/test/ci_use/EB_VL_Lite/baseline.txt index 537c3d823..bc1298e07 100644 --- a/test/ci_use/EB_VL_Lite/baseline.txt +++ b/test/ci_use/EB_VL_Lite/baseline.txt @@ -1008,6 +1008,7 @@ ernie.layers.27.post_attention_layernorm.weight ernie.norm.weight lm_head.linear.weight ernie.embed_tokens.embeddings.weight:ernie.embed_tokens.weight +lm_head.linear.weight:lm_head.weight ernie.layers.1.mlp.text_fused_moe.gate_weight:ernie.layers.1.mlp.gate.weight ernie.layers.1.mlp.text_fused_moe.gate_correction_bias:ernie.layers.1.mlp.moe_statics.e_score_correction_bias ernie.layers.1.mlp.text_fused_moe.up_gate_proj_weight:['ernie.layers.1.mlp.experts.0.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.1.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.2.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.3.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.4.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.5.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.6.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.7.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.8.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.9.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.10.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.11.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.12.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.13.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.14.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.15.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.16.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.17.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.18.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.19.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.20.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.21.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.22.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.23.up_gate_proj.weight', 
'ernie.layers.1.mlp.experts.24.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.25.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.26.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.27.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.28.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.29.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.30.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.31.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.64.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.65.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.66.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.67.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.68.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.69.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.70.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.71.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.72.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.73.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.74.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.75.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.76.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.77.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.78.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.79.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.80.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.81.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.82.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.83.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.84.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.85.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.86.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.87.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.88.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.89.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.90.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.91.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.92.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.93.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.94.up_gate_proj.weight', 'ernie.layers.1.mlp.experts.95.up_gate_proj.weight']
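
Note on the change (illustrative sketch, not part of the patch): the substance of this PR is that the expert grouping inside get_name_mappings_to_training can now follow the trainer's tensor-parallel degree (trainer_degree) instead of always using the rollout model's tensor_parallel_size; when trainer_degree is omitted, the old behaviour is preserved. The new baseline row lm_head.linear.weight:lm_head.weight appears to follow from dropping the tie_word_embeddings pop in BaseRLModel. The snippet below is a minimal, self-contained sketch of that grouping logic, assuming _generate_ranges yields `chunk` consecutive expert indices every `step` indices; text_expert_keys and its arguments are hypothetical helpers for illustration, not the library API.

from collections import defaultdict
from typing import Iterator, List, Optional


def _generate_ranges(start: int, end: int, step: int, chunk: int) -> Iterator[int]:
    # Hypothetical stand-in for the helper used in rollout_model.py:
    # yield `chunk` consecutive expert indices every `step` indices.
    for block_start in range(start, end, step):
        yield from range(block_start, block_start + chunk)


def text_expert_keys(moe_num_experts: List[int], trainer_degree: Optional[int],
                     rollout_tp_size: int, layer_idx: int = 1) -> dict:
    # Fall back to the rollout model's tensor-parallel size, as the patched code does
    # when trainer_degree is not passed.
    if not trainer_degree:
        trainer_degree = rollout_tp_size
    total_moe_num = sum(moe_num_experts)
    expert_num_per_rank = moe_num_experts[0] // trainer_degree

    mappings = defaultdict(list)
    fused_key = f"ernie.layers.{layer_idx}.mlp.text_fused_moe.up_gate_proj_weight"
    # Text experts occupy the first `expert_num_per_rank` slots of every
    # 2 * expert_num_per_rank block; all of them map onto the same fused rollout key.
    for expert_idx in _generate_ranges(0, total_moe_num, expert_num_per_rank * 2, expert_num_per_rank):
        mappings[fused_key].append(
            f"ernie.layers.{layer_idx}.mlp.experts.{expert_idx}.up_gate_proj.weight"
        )
    return dict(mappings)


# With moe_num_experts=[64, 64] and trainer_degree=2, the fused key groups experts
# 0-31 and 64-95, which matches the rows recorded in test/ci_use/EB_VL_Lite/baseline.txt.
print(text_expert_keys([64, 64], trainer_degree=2, rollout_tp_size=8))

Passing trainer_degree explicitly lets the trainer's expert sharding differ from the rollout engine's tensor-parallel size, which is why the parameter threads through RolloutModel.get_name_mappings_to_training down to each RL model class.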