[MTP][RL]support rl reshard wenxin-tools-145 (#4173)

* support mtp reshard in rl mode

* fix function
Author: freeliuzc
Date: 2025-09-23 20:40:26 +08:00
Committed by: GitHub
Parent: 389c5dd3a2
Commit: 94b6e7a341

3 changed files with 23 additions and 9 deletions


@@ -71,6 +71,11 @@ class DefaultModelLoader(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa
 
+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
             context = paddle.LazyGuard()
         else:


@@ -59,6 +59,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa
 
+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
             with context:


@@ -17,11 +17,10 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import numpy as np
 import paddle
-from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
@@ -30,7 +29,7 @@ from fastdeploy.config import FDConfig
 class DynamicWeightManager:
     """Manages model weights loading, updating and shared state across processes."""
 
-    def __init__(self, fd_config: FDConfig, model: nn.Layer):
+    def __init__(self, fd_config: FDConfig, models):
         """Initialize with config and model instances."""
         self.fd_config = fd_config
         self.load_config = fd_config.load_config
@@ -41,7 +40,10 @@ class DynamicWeightManager:
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.model: nn.Layer = model
+        if not isinstance(models, List):
+            self.model_list = [models]
+        else:
+            self.model_list = models
         self._capture_model_state()
         self.update_parameters()
         self.finalize_update()
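The constructor now accepts either a single model or a list of models and normalizes both to self.model_list, so the MTP draft model can be resharded alongside the target model. A hypothetical call site (the variable names are assumed, not from this diff):

    # Both forms are accepted after this change.
    manager = DynamicWeightManager(fd_config, model)                     # single model, as before
    manager = DynamicWeightManager(fd_config, [model, mtp_draft_model])  # target + MTP draft

Note that isinstance(models, List) against typing.List does work at runtime, though the builtin list is the more idiomatic spelling.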
@@ -54,9 +56,10 @@ class DynamicWeightManager:
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for name, param in self.model.state_dict().items():
-            logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
-            self.state_dict[name] = param
+        for model in self.model_list:
+            for name, param in model.state_dict().items():
+                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
+                self.state_dict[name] = param
 
     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
@@ -133,8 +136,9 @@ class DynamicWeightManager:
         paddle.device.cuda.empty_cache()
 
         # step2: release model weight
-        for param in self.model.state_dict().values():
-            param._clear_data()
+        for model in self.model_list:
+            for param in model.state_dict().values():
+                param._clear_data()
 
         self._verify_parameters("clearance")
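With clearance also looping over model_list, every managed model's weights are released before fresh ones are pulled in. The surrounding reshard cycle, pieced together from the methods visible in this diff (only update_parameters and finalize_update appear verbatim; the clearing method's name is an assumption):

    # Assumed RL reshard flow: release old shards, load the newly
    # published weights for every model, then verify and signal done.
    manager.clear_parameters()    # clears weights of every model in model_list
    manager.update_parameters()   # reloads every captured parameter
    manager.finalize_update()     # verification / completion handshake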