diff --git a/fastdeploy/model_executor/model_loader/default_loader.py b/fastdeploy/model_executor/model_loader/default_loader.py
index e1ee0ce1f..7be3dca6a 100644
--- a/fastdeploy/model_executor/model_loader/default_loader.py
+++ b/fastdeploy/model_executor/model_loader/default_loader.py
@@ -71,6 +71,11 @@ class DefaultModelLoader(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa
 
+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
             context = paddle.LazyGuard()
         else:
diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py
index 51e80e7b0..9164e61af 100644
--- a/fastdeploy/model_executor/model_loader/default_loader_v1.py
+++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py
@@ -59,6 +59,11 @@ class DefaultModelLoaderV1(BaseModelLoader):
             # register rl model
             import fastdeploy.rl  # noqa
 
+            if fd_config.speculative_config.model_type != "mtp":
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+            else:
+                architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
             architectures = architectures + "RL"
 
         with context:
diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py
index 7d8a10521..d43ca3017 100644
--- a/fastdeploy/rl/dynamic_weight_manager.py
+++ b/fastdeploy/rl/dynamic_weight_manager.py
@@ -17,11 +17,10 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 import numpy as np
 import paddle
-from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
@@ -30,7 +29,7 @@ from fastdeploy.config import FDConfig
 class DynamicWeightManager:
     """Manages model weights loading, updating and shared state across processes."""
 
-    def __init__(self, fd_config: FDConfig, model: nn.Layer):
+    def __init__(self, fd_config: FDConfig, models):
         """Initialize with config and model instances."""
         self.fd_config = fd_config
         self.load_config = fd_config.load_config
@@ -41,7 +40,10 @@ class DynamicWeightManager:
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.model: nn.Layer = model
+        if not isinstance(models, List):
+            self.model_list = [models]
+        else:
+            self.model_list = models
         self._capture_model_state()
         self.update_parameters()
         self.finalize_update()
@@ -54,9 +56,10 @@ class DynamicWeightManager:
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for name, param in self.model.state_dict().items():
-            logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
-            self.state_dict[name] = param
+        for model in self.model_list:
+            for name, param in model.state_dict().items():
+                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
+                self.state_dict[name] = param
 
     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
@@ -133,8 +136,9 @@ class DynamicWeightManager:
         paddle.device.cuda.empty_cache()
 
         # step2: release model weight
-        for param in self.model.state_dict().values():
-            param._clear_data()
+        for model in self.model_list:
+            for param in model.state_dict().values():
+                param._clear_data()
 
         self._verify_parameters("clearance")
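
Note on the loader changes: both `DefaultModelLoader` and `DefaultModelLoaderV1` now rewrite the architecture name before appending the `RL` suffix, picking the MoE variant for the main model and the MTP variant when `speculative_config.model_type == "mtp"`. A minimal standalone sketch of that mapping (the `rewrite_architecture` helper is hypothetical, not part of FastDeploy):

```python
def rewrite_architecture(architectures: str, speculative_model_type: str) -> str:
    """Mirror the loader logic: map the generic Ernie5 name to the MoE or
    MTP variant, then append the "RL" suffix used to look up the
    registered RL model classes."""
    if speculative_model_type != "mtp":
        architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
    else:
        architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
    return architectures + "RL"


assert rewrite_architecture("Ernie5ForCausalLM", "default") == "Ernie5MoeForCausalLMRL"
assert rewrite_architecture("Ernie5ForCausalLM", "mtp") == "Ernie5MTPForCausalLMRL"
```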
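The `DynamicWeightManager` change keeps existing single-model call sites working while letting MTP-style setups pass several models at once: a bare model is wrapped into a one-element list, and every loop (state capture, weight clearing) iterates that list. A minimal sketch of the pattern, with a hypothetical `_StubModel` standing in for `paddle.nn.Layer`:

```python
from typing import Dict, List, Union


class _StubModel:
    """Hypothetical stand-in for a paddle.nn.Layer exposing state_dict()."""

    def __init__(self, params: Dict[str, float]):
        self._params = params

    def state_dict(self) -> Dict[str, float]:
        return self._params


def capture_state(models: Union[_StubModel, List[_StubModel]]) -> Dict[str, float]:
    """Normalize a single model into a one-element list, then merge all
    state_dicts into one flat mapping, as _capture_model_state does.
    Later models overwrite earlier entries on name collisions."""
    model_list = models if isinstance(models, list) else [models]
    state: Dict[str, float] = {}
    for model in model_list:
        state.update(model.state_dict())
    return state


# Existing single-model callers are unaffected...
assert capture_state(_StubModel({"w": 1.0})) == {"w": 1.0}
# ...while the MTP path can pass [target_model, draft_model].
assert capture_state([_StubModel({"w": 1.0}), _StubModel({"mtp.w": 2.0})]) == {
    "w": 1.0,
    "mtp.w": 2.0,
}
```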