Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-09-26 20:41:53 +08:00)
[MTP][RL] Support RL reshard wenxin-tools-145 (#4173)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* Support MTP reshard in RL mode
* Fix function
This commit is contained in:
@@ -71,6 +71,11 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
# register rl model
|
# register rl model
|
||||||
import fastdeploy.rl # noqa
|
import fastdeploy.rl # noqa
|
||||||
|
|
||||||
|
if fd_config.speculative_config.model_type != "mtp":
|
||||||
|
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
|
||||||
|
else:
|
||||||
|
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
|
||||||
|
|
||||||
architectures = architectures + "RL"
|
architectures = architectures + "RL"
|
||||||
context = paddle.LazyGuard()
|
context = paddle.LazyGuard()
|
||||||
else:
|
else:
|
||||||
|
@@ -59,6 +59,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
|
|||||||
# register rl model
|
# register rl model
|
||||||
import fastdeploy.rl # noqa
|
import fastdeploy.rl # noqa
|
||||||
|
|
||||||
|
if fd_config.speculative_config.model_type != "mtp":
|
||||||
|
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
|
||||||
|
else:
|
||||||
|
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
|
||||||
|
|
||||||
architectures = architectures + "RL"
|
architectures = architectures + "RL"
|
||||||
|
|
||||||
with context:
|
with context:
|
||||||
|
@@ -17,11 +17,10 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from multiprocessing.shared_memory import SharedMemory
|
from multiprocessing.shared_memory import SharedMemory
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
from paddle import nn
|
|
||||||
from paddleformers.utils.log import logger
|
from paddleformers.utils.log import logger
|
||||||
|
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
@@ -30,7 +29,7 @@ from fastdeploy.config import FDConfig
|
|||||||
class DynamicWeightManager:
|
class DynamicWeightManager:
|
||||||
"""Manages model weights loading, updating and shared state across processes."""
|
"""Manages model weights loading, updating and shared state across processes."""
|
||||||
|
|
||||||
def __init__(self, fd_config: FDConfig, model: nn.Layer):
|
def __init__(self, fd_config: FDConfig, models):
|
||||||
"""Initialize with config and model instances."""
|
"""Initialize with config and model instances."""
|
||||||
self.fd_config = fd_config
|
self.fd_config = fd_config
|
||||||
self.load_config = fd_config.load_config
|
self.load_config = fd_config.load_config
|
||||||
@@ -41,7 +40,10 @@ class DynamicWeightManager:
|
|||||||
self.meta_src_id = self._get_gpu_id()
|
self.meta_src_id = self._get_gpu_id()
|
||||||
self.first_load = True
|
self.first_load = True
|
||||||
self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
|
self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
|
||||||
self.model: nn.Layer = model
|
if not isinstance(models, List):
|
||||||
|
self.model_list = [models]
|
||||||
|
else:
|
||||||
|
self.model_list = models
|
||||||
self._capture_model_state()
|
self._capture_model_state()
|
||||||
self.update_parameters()
|
self.update_parameters()
|
||||||
self.finalize_update()
|
self.finalize_update()
|
||||||
@@ -54,9 +56,10 @@ class DynamicWeightManager:
|
|||||||
@paddle.no_grad()
|
@paddle.no_grad()
|
||||||
def _capture_model_state(self):
|
def _capture_model_state(self):
|
||||||
"""Capture and store initial model parameters state."""
|
"""Capture and store initial model parameters state."""
|
||||||
for name, param in self.model.state_dict().items():
|
for model in self.model_list:
|
||||||
logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
|
for name, param in model.state_dict().items():
|
||||||
self.state_dict[name] = param
|
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
|
||||||
|
self.state_dict[name] = param
|
||||||
|
|
||||||
def update_parameters(self, pid: int = 0) -> None:
|
def update_parameters(self, pid: int = 0) -> None:
|
||||||
"""Core method to update model parameters based on strategy."""
|
"""Core method to update model parameters based on strategy."""
|
||||||
@@ -133,8 +136,9 @@ class DynamicWeightManager:
|
|||||||
|
|
||||||
paddle.device.cuda.empty_cache()
|
paddle.device.cuda.empty_cache()
|
||||||
# step2: release model weight
|
# step2: release model weight
|
||||||
for param in self.model.state_dict().values():
|
for model in self.model_list:
|
||||||
param._clear_data()
|
for param in model.state_dict().values():
|
||||||
|
param._clear_data()
|
||||||
|
|
||||||
self._verify_parameters("clearance")
|
self._verify_parameters("clearance")
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user