diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index 67fff22c3..af9c8a346 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -86,6 +86,7 @@ class BaseRLModel(nn.Layer): super(BaseRLModel, self).__init__() self.infer_to_train_mapping = {} self.fd_config = None + self._mappings_built = False @classmethod def name(cls) -> str: @@ -142,6 +143,12 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM, BaseRLModel): def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True + # Prepare placeholders place_holders = ["weight"] @@ -215,6 +222,11 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True # Prepare placeholders place_holders = ["weight"] @@ -316,6 +328,11 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM, BaseRLModel): def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True # Prepare placeholders place_holders = ["weight"] @@ -360,6 +377,11 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM, BaseRLModel): def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: """Generate mapping between inference and training parameter for RL(donot delete!).""" + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True # Prepare placeholders place_holders = ["weight"] @@ -429,6 +451,11 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM, BaseRLModel): return "Qwen3ForCausalLMRL" def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: + if self._mappings_built: + return self.infer_to_train_mapping + + self.infer_to_train_mapping = {} + self._mappings_built = True # Prepare placeholders place_holders = ["weight"]