[MTP][RL]support rl reshard wenxin-tools-145 (#4173)
* support mtp reshard in rl mode
* fix function
@@ -71,6 +71,11 @@ class DefaultModelLoader(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa

+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
             context = paddle.LazyGuard()
         else:
@@ -59,6 +59,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa

+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"

         with context:
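For illustration only, a minimal standalone sketch of the architecture rewrite that both DefaultModelLoader and DefaultModelLoaderV1 now apply before appending the "RL" suffix. The helper name resolve_rl_architecture is hypothetical and not part of FastDeploy; the branching on speculative_config.model_type == "mtp" is taken directly from the hunks above.

def resolve_rl_architecture(architecture: str, speculative_model_type: str) -> str:
    # Hypothetical helper mirroring the loader logic above (not FastDeploy API).
    if speculative_model_type != "mtp":
        # Main model: map the base class to its MoE variant.
        architecture = architecture.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
    else:
        # MTP draft model: map the base class to its MTP variant.
        architecture = architecture.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
    # Both paths then resolve to the RL-registered model class name.
    return architecture + "RL"

# resolve_rl_architecture("Ernie5ForCausalLM", "mtp")            -> "Ernie5MTPForCausalLMRL"
# resolve_rl_architecture("Ernie5ForCausalLM", "anything else")  -> "Ernie5MoeForCausalLMRL"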
@@ -17,11 +17,10 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict
+from typing import Any, Dict, List

 import numpy as np
 import paddle
-from paddle import nn
 from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
@@ -30,7 +29,7 @@ from fastdeploy.config import FDConfig
 class DynamicWeightManager:
     """Manages model weights loading, updating and shared state across processes."""

-    def __init__(self, fd_config: FDConfig, model: nn.Layer):
+    def __init__(self, fd_config: FDConfig, models):
         """Initialize with config and model instances."""
         self.fd_config = fd_config
         self.load_config = fd_config.load_config
@@ -41,7 +40,10 @@ class DynamicWeightManager:
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.model: nn.Layer = model
+        if not isinstance(models, List):
+            self.model_list = [models]
+        else:
+            self.model_list = models
         self._capture_model_state()
         self.update_parameters()
         self.finalize_update()
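A minimal usage sketch of the new constructor behaviour, assuming target_model and mtp_draft_model are already-built paddle nn.Layer instances and fd_config is a populated FDConfig (the variable names are placeholders):

# Single model: normalised into a one-element model_list, preserving the old behaviour.
manager = DynamicWeightManager(fd_config, target_model)

# MTP in RL mode: pass the target model and the MTP draft model together so their
# parameters are captured, updated, and cleared as one group.
manager = DynamicWeightManager(fd_config, [target_model, mtp_draft_model])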
@@ -54,9 +56,10 @@ class DynamicWeightManager:
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for name, param in self.model.state_dict().items():
-            logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
-            self.state_dict[name] = param
+        for model in self.model_list:
+            for name, param in model.state_dict().items():
+                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
+                self.state_dict[name] = param

     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
@@ -133,8 +136,9 @@ class DynamicWeightManager:

         paddle.device.cuda.empty_cache()
         # step2: release model weight
-        for param in self.model.state_dict().values():
-            param._clear_data()
+        for model in self.model_list:
+            for param in model.state_dict().values():
+                param._clear_data()

         self._verify_parameters("clearance")

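As a rough sketch of how a reshard round might drive the manager during RL training: clearing is assumed to precede reloading, only clear_parameters, update_parameters, and finalize_update from the hunks above are used, the wrapper function itself is hypothetical, and the pid parameter on clear_parameters is assumed to mirror update_parameters.

def reshard(manager, pid: int = 0):
    # Release the old weights of every managed model (target and MTP draft).
    manager.clear_parameters(pid)
    # Load the refreshed weights via the configured load strategy and re-bind them.
    manager.update_parameters(pid)
    # Mark the update as finished, as the constructor does after the first load.
    manager.finalize_update()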