Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[RL]Support loading weights via the load_weights function for RL (#5549)
* RL support load_weights
* fix
@@ -140,7 +140,7 @@ def process_weight_transpose(layer, weight_name):
         default_initializer=paddle.nn.initializer.Constant(0),
         is_bias=False,
     )
-    if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
+    if layer.fd_config.load_config.dynamic_load_weight or getattr(layer.fd_config.model_config, "enable_cache", False):
        free_tensor(weight)
        setattr(layer, weight_name, weight_tmp)
        return
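A note on the guard change above: enable_cache is now read with getattr and a False default, so a model config that never defines the attribute (presumably what the RL rollout path can produce) no longer raises AttributeError. A minimal sketch of the difference, using a stand-in object rather than FastDeploy's real ModelConfig:

    from types import SimpleNamespace

    # Stand-in for a model config that never defines enable_cache
    # (illustrative only, not FastDeploy's ModelConfig).
    model_config = SimpleNamespace(max_model_len=4096)

    # Direct attribute access would raise AttributeError:
    #     model_config.enable_cache
    # getattr with a default degrades gracefully instead:
    enable_cache = getattr(model_config, "enable_cache", False)
    print(enable_cache)  # False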
@@ -66,6 +66,7 @@ class RolloutModelConfig:
         num_nextn_predict_layers: int = 0,
         eplb_config: str = {},
         routing_replay_config: str = None,
+        load_choices: str = "default_v1",
     ):
         # Required parameters
         self.model = model_name_or_path

@@ -115,6 +116,7 @@ class RolloutModelConfig:
         self.num_nextn_predict_layers = num_nextn_predict_layers
         self.eplb_config = eplb_config
         self.routing_replay_config = routing_replay_config
+        self.load_choices = load_choices
 
     def __str__(self):
         return "\n".join(f"{k}: {v}" for k, v in self.__dict__.items())
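For orientation, a hedged sketch of how these options are passed in; only the constructor arguments visible in the hunks above are used, and the assumption is that every other parameter has a usable default:

    from fastdeploy.rl.rollout_config import RolloutModelConfig

    # Hypothetical usage; the model path is a placeholder.
    rollout_config = RolloutModelConfig(
        model_name_or_path="/path/to/rollout/model",
        load_choices="default_v1",  # loader selection forwarded to the rollout FDConfig (assumed)
    )
    # __str__ dumps every attribute as "name: value", one per line.
    print(rollout_config)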
@@ -21,6 +21,7 @@ import paddle
 from paddle import nn
 
 from fastdeploy.config import FDConfig
+from fastdeploy.model_executor.model_loader import get_model_loader
 from fastdeploy.model_executor.models.ernie4_5_moe import (
     Ernie4_5_MoeForCausalLM,
     Ernie4_5_MoePretrainedModel,
@@ -50,6 +51,10 @@ from fastdeploy.model_executor.models.qwen3moe import (
     Qwen3MoeForCausalLM,
     Qwen3MoePretrainedModel,
 )
+from fastdeploy.model_executor.utils import (
+    multi_switch_config_context,
+    process_final_after_loading,
+)
 from fastdeploy.rl.rollout_config import RolloutModelConfig
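multi_switch_config_context and process_final_after_loading are FastDeploy helpers whose implementations are not shown in this diff. Judging only from the call sites in load_weights below, the former behaves like a context manager that temporarily sets several (object, attribute, value) triples and restores the originals on exit. A rough stand-in, purely illustrative and not the library's actual code:

    from contextlib import contextmanager

    @contextmanager
    def multi_switch_config_context_sketch(*switches):
        # switches: (obj, attr_name, temporary_value) triples
        originals = [(obj, name, getattr(obj, name)) for obj, name, _ in switches]
        try:
            for obj, name, value in switches:
                setattr(obj, name, value)
            yield
        finally:
            # Restore whatever values the config objects held before entering.
            for obj, name, value in originals:
                setattr(obj, name, value)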
@@ -64,13 +69,33 @@ class RolloutModel(nn.Layer):
 
     def _init_model(self) -> nn.Layer:
         """Load model from loader based on config."""
+        model_loader = get_model_loader(load_config=self.fd_config.load_config)
+        return model_loader.load_model(fd_config=self.fd_config)
+
+    def load_weights(self, weights_iterator):
+        """Load weights_iterator."""
+
         context = paddle.LazyGuard()
         architectures = f"{self.fd_config.model_config.architectures[0]}RL"
-        with context:
-            model_cls = ModelRegistry.get_class(architectures)
-            model = model_cls(self.fd_config)
+        if self.fd_config.quant_config is not None:
+            quantization_context = multi_switch_config_context(
+                (self.fd_config.quant_config, "is_checkpoint_bf16", True),
+                (self.fd_config.load_config, "dynamic_load_weight", False),
+            )
+        else:
+            # bf16
+            quantization_context = multi_switch_config_context(
+                (self.fd_config.load_config, "dynamic_load_weight", False)
+            )
+        with quantization_context:
+            with context:
+                model_cls = ModelRegistry.get_class(architectures)
+                model = model_cls(self.fd_config)
         model.eval()
-        return model
+        model.load_weights(weights_iterator)
+        if self.fd_config.speculative_config.model_type != "mtp":
+            process_final_after_loading(model, self.fd_config)
+        self.rollout_model = model
 
     def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
         """Get parameter name mappings between rollout and training models."""
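To make the new entry point concrete, a hedged sketch of a call site; the trainer-side objects and the exact shape of the iterator are assumptions (named, tensor-valued pairs such as a paddle state_dict yields), not something this diff specifies:

    # Hypothetical RL trainer step: push updated training weights into the rollout model.
    # rollout_model is a constructed fastdeploy.rl.RolloutModel; train_model is the
    # training-side paddle.nn.Layer (both assumed to exist in the surrounding code).
    weights_iterator = iter(train_model.state_dict().items())
    rollout_model.load_weights(weights_iterator)
    # After the call, the freshly built model (with the quantization/dynamic-load
    # switches applied during construction) is stored on rollout_model.rollout_model.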