diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py index 1ca45171f..279d58db3 100644 --- a/fastdeploy/rl/rollout_model.py +++ b/fastdeploy/rl/rollout_model.py @@ -91,10 +91,10 @@ class RolloutModel(nn.Layer): with context: model_cls = ModelRegistry.get_class(architectures) model = model_cls(self.fd_config) - model.eval() - model.load_weights(weights_iterator) - if self.fd_config.speculative_config.model_type != "mtp": - process_final_after_loading(model, self.fd_config) + model.eval() + model.load_weights(weights_iterator) + if self.fd_config.speculative_config.model_type != "mtp": + process_final_after_loading(model, self.fd_config) self.rollout_model = model def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]: