mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
@@ -91,10 +91,10 @@ class RolloutModel(nn.Layer):
|
||||
with context:
|
||||
model_cls = ModelRegistry.get_class(architectures)
|
||||
model = model_cls(self.fd_config)
|
||||
model.eval()
|
||||
model.load_weights(weights_iterator)
|
||||
if self.fd_config.speculative_config.model_type != "mtp":
|
||||
process_final_after_loading(model, self.fd_config)
|
||||
model.eval()
|
||||
model.load_weights(weights_iterator)
|
||||
if self.fd_config.speculative_config.model_type != "mtp":
|
||||
process_final_after_loading(model, self.fd_config)
|
||||
self.rollout_model = model
|
||||
|
||||
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
|
||||
|
||||
Reference in New Issue
Block a user