diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index b7f192051..fd4165174 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -89,6 +89,7 @@ class BaseRLModel(nn.Layer): super(BaseRLModel, self).__init__() self.infer_to_train_mapping = {} self.fd_config = None + self._mappings_built = False @classmethod def name(cls) -> str: @@ -145,6 +146,12 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel): def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True + # Prepare placeholders place_holders = ["weight"] @@ -218,6 +225,11 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True # Prepare placeholders place_holders = ["weight"] @@ -319,6 +331,11 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM, BaseRLModel): def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True # Prepare placeholders place_holders = ["weight"] @@ -363,6 +380,11 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel): def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True # Prepare placeholders place_holders = ["weight"] @@ -432,6 +454,11 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel): return "Qwen3ForCausalLMRL" def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True # Prepare placeholders place_holders = ["weight"]