From 646d1a0aa25670926471e3422b60416a82e80087 Mon Sep 17 00:00:00 2001
From: bukejiyu <52310069+bukejiyu@users.noreply.github.com>
Date: Thu, 18 Dec 2025 18:28:53 +0800
Subject: [PATCH] [Cherry-Pick][RL]Support loading weights via the
 load_weights function for RL #5549 (#5602)

* RL support load_weights

* fix
---
 fastdeploy/model_executor/utils.py |  2 +-
 fastdeploy/rl/rollout_config.py    |  2 ++
 fastdeploy/rl/rollout_model.py     | 33 ++++++++++++++++++++++++++----
 3 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index 971ee58ae..4cbdf53d3 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -140,7 +140,7 @@ def process_weight_transpose(layer, weight_name):
         default_initializer=paddle.nn.initializer.Constant(0),
         is_bias=False,
     )
-    if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
+    if layer.fd_config.load_config.dynamic_load_weight or getattr(layer.fd_config.model_config, "enable_cache", False):
         free_tensor(weight)
         setattr(layer, weight_name, weight_tmp)
         return
diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py
index f7ff748fe..47db59a1c 100644
--- a/fastdeploy/rl/rollout_config.py
+++ b/fastdeploy/rl/rollout_config.py
@@ -66,6 +66,7 @@ class RolloutModelConfig:
         num_nextn_predict_layers: int = 0,
         eplb_config: str = {},
         routing_replay_config: str = None,
+        load_choices: str = "default_v1",
     ):
         # Required parameters
         self.model = model_name_or_path
@@ -115,6 +116,7 @@ class RolloutModelConfig:
         self.num_nextn_predict_layers = num_nextn_predict_layers
         self.eplb_config = eplb_config
         self.routing_replay_config = routing_replay_config
+        self.load_choices = load_choices

     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
diff --git a/fastdeploy/rl/rollout_model.py b/fastdeploy/rl/rollout_model.py
index e9410d972..1ca45171f 100644
--- a/fastdeploy/rl/rollout_model.py
+++ b/fastdeploy/rl/rollout_model.py
@@ -21,6 +21,7 @@ import paddle
 from paddle import nn
 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.model_executor.models.ernie4_5_moe import (
     Ernie4_5_MoeForCausalLM,
     Ernie4_5_MoePretrainedModel,
 )
@@ -50,6 +51,10 @@ from fastdeploy.model_executor.models.qwen3moe import (
     Qwen3MoeForCausalLM,
     Qwen3MoePretrainedModel,
 )
+from fastdeploy.model_executor.utils import (
+    multi_switch_config_context,
+    process_final_after_loading,
+)
 from fastdeploy.rl.rollout_config import RolloutModelConfig


@@ -64,13 +69,33 @@ class RolloutModel(nn.Layer):

     def _init_model(self) -> nn.Layer:
         """Load model from loader based on config."""
+        model_loader = get_model_loader(load_config=self.fd_config.load_config)
+        return model_loader.load_model(fd_config=self.fd_config)
+
+    def load_weights(self, weights_iterator):
+        """Load weights_iterator."""
+
         context = paddle.LazyGuard()
         architectures = f"{self.fd_config.model_config.architectures[0]}RL"
-        with context:
-            model_cls = ModelRegistry.get_class(architectures)
-            model = model_cls(self.fd_config)
+        if self.fd_config.quant_config is not None:
+            quantization_context = multi_switch_config_context(
+                (self.fd_config.quant_config, "is_checkpoint_bf16", True),
+                (self.fd_config.load_config, "dynamic_load_weight", False),
+            )
+        else:
+            # bf16
+            quantization_context = multi_switch_config_context(
+                (self.fd_config.load_config, "dynamic_load_weight", False)
+            )
+        with quantization_context:
+            with context:
+                model_cls = ModelRegistry.get_class(architectures)
+                model = model_cls(self.fd_config)
         model.eval()
-        return model
+        model.load_weights(weights_iterator)
+        if self.fd_config.speculative_config.model_type != "mtp":
+            process_final_after_loading(model, self.fd_config)
+        self.rollout_model = model

     def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
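Usage sketch (reviewer note, not part of the patch): one way the new
RolloutModel.load_weights() entry point might be driven from a training
process. The (name, tensor) iterator format, the RolloutModel constructor
signature, and the RolloutModelConfig arguments shown here are assumptions
inferred from this diff, not verified against the rest of the repository.

from fastdeploy.rl.rollout_config import RolloutModelConfig
from fastdeploy.rl.rollout_model import RolloutModel

def training_weights_iterator(state_dict):
    """Yield (param_name, tensor) pairs from a training-side state dict."""
    for name, tensor in state_dict.items():
        yield name, tensor

# Hypothetical config values; load_choices is the new flag added by this patch.
config = RolloutModelConfig(
    model_name_or_path="/path/to/checkpoint",
    load_choices="default_v1",
)
rollout = RolloutModel(config)  # _init_model() now builds the model via get_model_loader()

# Re-materialize the rollout weights from the trainer's parameters.
trainer_state = {}  # placeholder: mapping of parameter name -> tensor
rollout.load_weights(training_weights_iterator(trainer_state))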