[Cherry-Pick][RL]Fix RL load_weights #5642 (#5643)

This commit is contained in:
bukejiyu
2025-12-19 11:17:09 +08:00
committed by GitHub
parent 9c55bc31cd
commit 2aa88d3621

View File

@@ -91,10 +91,10 @@ class RolloutModel(nn.Layer):
with context:
model_cls = ModelRegistry.get_class(architectures)
model = model_cls(self.fd_config)
model.eval()
model.load_weights(weights_iterator)
if self.fd_config.speculative_config.model_type != "mtp":
process_final_after_loading(model, self.fd_config)
model.eval()
model.load_weights(weights_iterator)
if self.fd_config.speculative_config.model_type != "mtp":
process_final_after_loading(model, self.fd_config)
self.rollout_model = model
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: