commit dda4a9f848
parent a83a3eea5f
Author: Yuanle Liu (committed by GitHub)
Date:   2025-07-16 15:33:10 +08:00

10 changed files with 26 additions and 131 deletions


@@ -16,7 +16,7 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict, List
+from typing import Any, Dict
 
 import numpy as np
 import paddle
@@ -24,9 +24,6 @@ from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.load_weight_utils import \
-    load_composite_checkpoint
-from fastdeploy.model_executor.model_loader import MODEL_CLASSES
 
 
 class DynamicWeightManager:
@@ -43,11 +40,9 @@ class DynamicWeightManager:
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.models: List[nn.Layer] = [model]
+        self.model: nn.Layer = model
         self._capture_model_state()
-
-        if self.load_config.load_strategy != "meta":
-            self.update_parameters()
+        self.update_parameters()
 
         logger.info(
             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
@@ -56,17 +51,11 @@ class DynamicWeightManager:
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for model in self.models:
-            for name, param in model.state_dict().items():
-                logger.debug(
-                    f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
-                )
-                self.state_dict[name] = param
-
-    def add_model(self, model: nn.Layer):
-        """"add model"""
-        self.models.append(model)
-        self._capture_model_state()
+        for name, param in self.model.state_dict().items():
+            logger.debug(
+                f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
+            )
+            self.state_dict[name] = param
 
     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
@@ -79,8 +68,6 @@ class DynamicWeightManager:
         strategy_handlers = {
             "ipc_snapshot": self._update_ipc_snapshot,
             "ipc": self._update_ipc,
-            "ipc_no_reshard": self._update_ipc_no_reshard,
-            "normal": self.load_model,
         }
         if handler := strategy_handlers.get(self.load_config.load_strategy):
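With `ipc_no_reshard` and `normal` gone, strategy selection is a two-entry dict lookup guarded by the walrus operator. A self-contained sketch of that dispatch pattern; the handler bodies and the error path are placeholders, since the hunk does not show what happens on an unknown strategy:

```python
def run_strategy(load_strategy: str) -> str:
    strategy_handlers = {
        "ipc_snapshot": lambda: "updated from IPC snapshot",
        "ipc": lambda: "updated via IPC",
    }
    if handler := strategy_handlers.get(load_strategy):
        return handler()
    # Hypothetical fallback for strategies this commit removed.
    raise ValueError(f"Unsupported load strategy: {load_strategy}")


print(run_strategy("ipc"))           # updated via IPC
print(run_strategy("ipc_snapshot"))  # updated from IPC snapshot
```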
@@ -106,13 +93,7 @@ class DynamicWeightManager:
         fallback_path = f"/shared_ipc_meta/model_state.tp0{self.meta_src_id}.pdparams"
         ipc_state_dict = paddle.load(fallback_path)
 
-        try:
-            self._update_model_from_state(ipc_state_dict, "snapshot")
-        except Exception:
-            self.models[0].set_state_dict(ipc_state_dict)
-            logger.warning(
-                "load model from no_reshard weight, maybe need more GPU memory"
-            )
+        self._update_model_from_state(ipc_state_dict, "snapshot")
         logger.info(
             f"IPC snapshot update parameters completed from {model_path}")
 
@@ -124,34 +105,12 @@ class DynamicWeightManager:
         logger.info(
             f"IPC update parameters completed from file: {self.ipc_path}")
 
-    def _update_ipc_no_reshard(self):
-        """Update using no-reshard IPC strategy (faster but uses more memory)."""
-        ipc_meta = paddle.load(self.ipc_path)
-        state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
-        self.models[0].set_state_dict(state_dict)
-        logger.info(
-            f"IPC no-reshard update parameters completed from file: {self.ipc_path}"
-        )
-
-    def load_model(self) -> nn.Layer:
-        """Standard model loading without IPC."""
-        architectures = self.fd_config.model_config.architectures[0]
-        model_class = MODEL_CLASSES[architectures]
-        state_dict = load_composite_checkpoint(
-            self.fd_config.parallel_config.model_name_or_path,
-            model_class,
-            self.fd_config.model_config,
-            return_numpy=True)
-        self.models[0].set_state_dict(state_dict)
-        logger.info("normal load update parameters completed")
-
     def clear_parameters(self, pid: int = 0) -> None:
         """Clear all model parameters and free memory."""
         logger.info("start clear paramaters")
         paddle.device.cuda.empty_cache()
-        for model in self.models:
-            for param in model.state_dict().values():
-                param._clear_data()
+        for param in self.model.state_dict().values():
+            param._clear_data()
         self._verify_parameters("clearance")
 
         if self.nranks > 1:
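`clear_parameters` now frees the single model's tensors directly, and a later `update_parameters` call is expected to repopulate them through one of the two remaining IPC strategies. A rough sketch of the clear step, assuming a CUDA build of Paddle; `_clear_data()` and `empty_cache()` are the same calls the diff keeps, while `_is_initialized()` is my assumption for observing the result:

```python
import paddle
from paddle import nn

model = nn.Linear(4, 4)

paddle.device.cuda.empty_cache()     # as in clear_parameters(), free cached blocks

# Release every parameter's storage on the one model (no more self.models loop).
for param in model.state_dict().values():
    param._clear_data()              # internal Paddle API used by the diff

# _verify_parameters("clearance") presumably confirms the tensors are empty;
# _is_initialized() is an assumed way to check that here.
print(model.weight._is_initialized())  # False once the data is cleared
```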