fix mapping (#3320)

This commit is contained in:
gaoziyuan
2025-08-12 16:15:59 +08:00
committed by GitHub
parent 283da92bfa
commit ccc7f1beb3

View File

@@ -89,6 +89,7 @@ class BaseRLModel(nn.Layer):
super(BaseRLModel, self).__init__()
self.infer_to_train_mapping = {}
self.fd_config = None
self._mappings_built = False
@classmethod
def name(cls) -> str:
@@ -145,6 +146,12 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel):
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
if self._mappings_built:
return self.infer_to_train_mapping
self.infer_to_train_mapping = {}
self._mappings_built = True
# Prepare placeholders
place_holders = ["weight"]
@@ -218,6 +225,11 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
if self._mappings_built:
return self.infer_to_train_mapping
self.infer_to_train_mapping = {}
self._mappings_built = True
# Prepare placeholders
place_holders = ["weight"]
@@ -319,6 +331,11 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM, BaseRLModel):
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
if self._mappings_built:
return self.infer_to_train_mapping
self.infer_to_train_mapping = {}
self._mappings_built = True
# Prepare placeholders
place_holders = ["weight"]
@@ -363,6 +380,11 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel):
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
if self._mappings_built:
return self.infer_to_train_mapping
self.infer_to_train_mapping = {}
self._mappings_built = True
# Prepare placeholders
place_holders = ["weight"]
@@ -432,6 +454,11 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel):
return "Qwen3ForCausalLMRL"
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
if self._mappings_built:
return self.infer_to_train_mapping
self.infer_to_train_mapping = {}
self._mappings_built = True
# Prepare placeholders
place_holders = ["weight"]