[RL] Support loading weights via the load_weights function for RL (#5549)

* RL support load_weights

* fix
Author: bukejiyu
Date: 2025-12-18 18:27:05 +08:00
Committer: GitHub
Parent: ac013803f3
Commit: 4aa2c6871b

3 changed files with 32 additions and 5 deletions


@@ -140,7 +140,7 @@ def process_weight_transpose(layer, weight_name):
         default_initializer=paddle.nn.initializer.Constant(0),
         is_bias=False,
     )
-    if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
+    if layer.fd_config.load_config.dynamic_load_weight or getattr(layer.fd_config.model_config, "enable_cache", False):
         free_tensor(weight)
         setattr(layer, weight_name, weight_tmp)
         return
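The change guards against model configs that do not define enable_cache at all. A minimal, self-contained sketch of the pattern (the config classes here are hypothetical; only the getattr fallback is the point):

```python
# Hypothetical config objects used only to illustrate the getattr fallback.
class ConfigWithCache:
    enable_cache = True

class ConfigWithoutCache:
    pass  # no enable_cache attribute at all

for cfg in (ConfigWithCache(), ConfigWithoutCache()):
    # cfg.enable_cache would raise AttributeError for the second object;
    # getattr with a default of False keeps the check safe.
    print(getattr(cfg, "enable_cache", False))
```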


@@ -66,6 +66,7 @@ class RolloutModelConfig:
         num_nextn_predict_layers: int = 0,
         eplb_config: str = {},
         routing_replay_config: str = None,
+        load_choices: str = "default_v1",
     ):
         # Required parameters
         self.model = model_name_or_path
@@ -115,6 +116,7 @@ class RolloutModelConfig:
         self.num_nextn_predict_layers = num_nextn_predict_layers
         self.eplb_config = eplb_config
         self.routing_replay_config = routing_replay_config
+        self.load_choices = load_choices
 
     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())


@@ -21,6 +21,7 @@ import paddle
 from paddle import nn
 
 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.model_executor.models.ernie4_5_moe import (
     Ernie4_5_MoeForCausalLM,
     Ernie4_5_MoePretrainedModel,
@@ -50,6 +51,10 @@ from fastdeploy.model_executor.models.qwen3moe import (
     Qwen3MoeForCausalLM,
     Qwen3MoePretrainedModel,
 )
+from fastdeploy.model_executor.utils import (
+    multi_switch_config_context,
+    process_final_after_loading,
+)
 from fastdeploy.rl.rollout_config import RolloutModelConfig
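The load_weights hunk below wraps model construction in multi_switch_config_context so that, while the RL model class is instantiated, the quantization config reports a bf16 checkpoint and dynamic weight loading is switched off, with the original values restored afterwards. The real helper lives in fastdeploy.model_executor.utils; the sketch below only illustrates that set-then-restore pattern and is not the FastDeploy implementation:

```python
# Illustrative sketch of a multi-attribute config switch; not FastDeploy code.
from contextlib import contextmanager

@contextmanager
def switch_config_attrs(*switches):
    """Each switch is (obj, attr_name, temporary_value)."""
    originals = [(obj, name, getattr(obj, name)) for obj, name, _ in switches]
    try:
        for obj, name, value in switches:
            setattr(obj, name, value)
        yield
    finally:
        # Restore the original values even if model construction raises.
        for obj, name, value in originals:
            setattr(obj, name, value)
```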
@@ -64,13 +69,33 @@ class RolloutModel(nn.Layer):
     def _init_model(self) -> nn.Layer:
         """Load model from loader based on config."""
+        model_loader = get_model_loader(load_config=self.fd_config.load_config)
+        return model_loader.load_model(fd_config=self.fd_config)
+
+    def load_weights(self, weights_iterator):
+        """Load weights_iterator."""
         context = paddle.LazyGuard()
         architectures = f"{self.fd_config.model_config.architectures[0]}RL"
-        with context:
-            model_cls = ModelRegistry.get_class(architectures)
-            model = model_cls(self.fd_config)
+        if self.fd_config.quant_config is not None:
+            quantization_context = multi_switch_config_context(
+                (self.fd_config.quant_config, "is_checkpoint_bf16", True),
+                (self.fd_config.load_config, "dynamic_load_weight", False),
+            )
+        else:
+            # bf16
+            quantization_context = multi_switch_config_context(
+                (self.fd_config.load_config, "dynamic_load_weight", False)
+            )
+        with quantization_context:
+            with context:
+                model_cls = ModelRegistry.get_class(architectures)
+                model = model_cls(self.fd_config)
         model.eval()
-        return model
+        model.load_weights(weights_iterator)
+        if self.fd_config.speculative_config.model_type != "mtp":
+            process_final_after_loading(model, self.fd_config)
+        self.rollout_model = model
 
     def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""