Fix rollout_model init (#2881)

Yuanle Liu authored 2025-07-17 13:36:21 +08:00, committed by GitHub
parent 1f15ca21e4
commit dbb9e2506b
9 changed files with 76 additions and 312 deletions


@@ -39,17 +39,17 @@ class RolloutModel(nn.Layer):
         """Initialize with FastDeploy configuration."""
         super(RolloutModel, self).__init__()
         self.fd_config = rollout_model_config.initialize()
-        self._init_model()
+        self.rollout_model = self._init_model()

-    def _init_model(self):
+    def _init_model(self) -> nn.Layer:
         """Load model from loader based on config."""
         context = paddle.LazyGuard()
         architectures = f"{self.fd_config.model_config.architectures[0]}RL"
         with context:
             model_cls = ModelRegistry.get_class(architectures)
             model = model_cls(self.fd_config)
-        self.rollout_model = model.eval()
+        model.eval()
+        return model

     def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
@@ -74,15 +74,14 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
         super(Ernie4_5_MoeForCausalLMRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Ernie4_5_MoeForCausalLMRL"

-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
-        have_bias = self.fd_config.model_config.get("have_norm_bias", False)
         # Prepare placeholders
-        place_holders = ["weight"] + (["bias"] if have_bias else [])
+        place_holders = ["weight"]

         # Initialize mapping dictionary
         infer_to_train = {}
@@ -94,7 +93,7 @@ class Ernie4_5_MoeForCausalLMRL(Ernie4_5_MoeForCausalLM):
             f"{base_name}.embed_tokens.weight",
             "lm_head.linear.weight": "lm_head.weight"
         }
-        if self.fd_config.model_config.get("tie_word_embeddings", False):
+        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
             # Support tie_word_embeddings
             logger.debug("enable tie_word_embeddings")
             static_mappings.pop("lm_head.linear.weight")
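The switch from .get(...) to getattr(...) here reflects that model_config is accessed as an attribute-style object rather than a dict, so dict-style lookups with a default would fail. A small sketch of the pattern; the ModelConfig class below is a hypothetical stand-in, not the project's real config type.

# Sketch of the access pattern used above (illustrative only).
class ModelConfig:
    tie_word_embeddings = True   # plain attribute, not a dict key

cfg = ModelConfig()
# cfg.get("tie_word_embeddings", False) would raise AttributeError on an object;
# getattr() gives the same "default if missing" behaviour for attributes.
tie = getattr(cfg, "tie_word_embeddings", False)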
@@ -153,15 +152,14 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
         super(Ernie4_5_VLMoeForConditionalGenerationRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Ernie4_5_VLMoeForConditionalGenerationRL"

-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
-        have_bias = self.fd_config.model_config.get("have_norm_bias", False)
         # Prepare placeholders
-        place_holders = ["weight"] + (["bias"] if have_bias else [])
+        place_holders = ["weight"]

         # Initialize mapping dictionary
         infer_to_train = {}
@@ -173,7 +171,7 @@ class Ernie4_5_VLMoeForConditionalGenerationRL(Ernie4_5_VLMoeForConditionalGener
             f"{base_name}.embed_tokens.weight",
             "lm_head.linear.weight": "lm_head.weight"
         }
-        if self.fd_config.model_config.get("tie_word_embeddings", False):
+        if getattr(self.fd_config.model_config, "tie_word_embeddings", False):
             # Support tie_word_embeddings
             logger.debug("enable tie_word_embeddings")
             static_mappings.pop("lm_head.linear.weight")
@@ -257,11 +255,11 @@ class Qwen2ForCausalLMRL(Qwen2ForCausalLM):
         super(Qwen2ForCausalLMRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Qwen2ForCausalLMRL"

-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -307,11 +305,11 @@ class Qwen3MoeForCausalLMRL(Qwen3MoeForCausalLM):
         super(Qwen3MoeForCausalLMRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Qwen3MoeForCausalLMRL"

-    def get_name_mappings_to_training(self):
+    def get_name_mappings_to_training(self) -> Dict[str, str]:
         """Generate mapping between inference and training parameter for RL(donot delete!)."""
         # Prepare placeholders
         place_holders = ["weight"]
@@ -379,6 +377,6 @@ class Qwen3ForCausalLMRL(Qwen3ForCausalLM):
         super(Qwen3ForCausalLMRL, self).__init__(fd_config)

     @classmethod
-    def name(self):
+    def name(self) -> str:
         """name"""
         return "Qwen3ForCausalLMRL"