[MTP][RL] Support RL reshard wenxin-tools-145 (#4173)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

* Support MTP reshard in RL mode

* fix function
This commit is contained in:
freeliuzc
2025-09-23 20:40:26 +08:00
committed by GitHub
parent 389c5dd3a2
commit 94b6e7a341
3 changed files with 23 additions and 9 deletions

View File

@@ -71,6 +71,11 @@ class DefaultModelLoader(BaseModelLoader):
# register rl model
import fastdeploy.rl # noqa
if fd_config.speculative_config.model_type != "mtp":
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
else:
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
architectures = architectures + "RL"
context = paddle.LazyGuard()
else:

View File

@@ -59,6 +59,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
# register rl model
import fastdeploy.rl # noqa
if fd_config.speculative_config.model_type != "mtp":
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
else:
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
architectures = architectures + "RL"
with context:

View File

@@ -17,11 +17,10 @@
import os
import time
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Dict
from typing import Any, Dict, List
import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
@@ -30,7 +29,7 @@ from fastdeploy.config import FDConfig
class DynamicWeightManager:
"""Manages model weights loading, updating and shared state across processes."""
def __init__(self, fd_config: FDConfig, model: nn.Layer):
def __init__(self, fd_config: FDConfig, models):
"""Initialize with config and model instances."""
self.fd_config = fd_config
self.load_config = fd_config.load_config
@@ -41,7 +40,10 @@ class DynamicWeightManager:
self.meta_src_id = self._get_gpu_id()
self.first_load = True
self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
self.model: nn.Layer = model
if not isinstance(models, List):
self.model_list = [models]
else:
self.model_list = models
self._capture_model_state()
self.update_parameters()
self.finalize_update()
@@ -54,9 +56,10 @@ class DynamicWeightManager:
@paddle.no_grad()
def _capture_model_state(self):
"""Capture and store initial model parameters state."""
for name, param in self.model.state_dict().items():
logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
self.state_dict[name] = param
for model in self.model_list:
for name, param in model.state_dict().items():
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
self.state_dict[name] = param
def update_parameters(self, pid: int = 0) -> None:
"""Core method to update model parameters based on strategy."""
@@ -133,8 +136,9 @@ class DynamicWeightManager:
paddle.device.cuda.empty_cache()
# step2: release model weight
for param in self.model.state_dict().values():
param._clear_data()
for model in self.model_list:
for param in model.state_dict().values():
param._clear_data()
self._verify_parameters("clearance")