mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
rl update (#2861)
@@ -16,7 +16,7 @@
 import os
 import time
 from multiprocessing.shared_memory import SharedMemory
-from typing import Any, Dict, List
+from typing import Any, Dict

 import numpy as np
 import paddle
@@ -24,9 +24,6 @@ from paddle import nn
 from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.load_weight_utils import \
-    load_composite_checkpoint
-from fastdeploy.model_executor.model_loader import MODEL_CLASSES


 class DynamicWeightManager:
@@ -43,11 +40,9 @@ class DynamicWeightManager:
         self.meta_src_id = self._get_gpu_id()
         self.first_load = True
         self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
-        self.models: List[nn.Layer] = [model]
+        self.model: nn.Layer = model
         self._capture_model_state()
-
-        if self.load_config.load_strategy != "meta":
-            self.update_parameters()
+        self.update_parameters()

         logger.info(
             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
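Note on the hunk above: the manager now tracks a single model instead of a short-lived List[nn.Layer], and the "meta"-strategy guard is gone, so parameters are always loaded at construction time. A minimal sketch of the resulting initialization flow, with the config objects simplified and the handler body stubbed out (attribute names follow the diff; everything else is illustrative):

import paddle
from paddle import nn

class WeightManagerSketch:
    """Illustrative skeleton only, not FastDeploy's actual class."""

    def __init__(self, load_strategy: str, model: nn.Layer):
        self.load_strategy = load_strategy
        self.state_dict: dict = {}
        self.model: nn.Layer = model  # single model, no List[nn.Layer]
        self._capture_model_state()
        self.update_parameters()      # no longer gated on strategy != "meta"

    @paddle.no_grad()
    def _capture_model_state(self):
        # Hold references to the live parameter tensors for later updates.
        for name, param in self.model.state_dict().items():
            self.state_dict[name] = param

    def update_parameters(self, pid: int = 0) -> None:
        ...  # dispatch on self.load_strategy, as in the later hunk

manager = WeightManagerSketch("ipc", paddle.nn.Linear(4, 4))
print(list(manager.state_dict))  # ['weight', 'bias']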
@@ -56,17 +51,11 @@ class DynamicWeightManager:
     @paddle.no_grad()
     def _capture_model_state(self):
         """Capture and store initial model parameters state."""
-        for model in self.models:
-            for name, param in model.state_dict().items():
-                logger.debug(
-                    f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
-                )
-                self.state_dict[name] = param
-
-    def add_model(self, model: nn.Layer):
-        """"add model"""
-        self.models.append(model)
-        self._capture_model_state()
+        for name, param in self.model.state_dict().items():
+            logger.debug(
+                f"Model param: {name}, shape={param.shape}, dtype={param.dtype}"
+            )
+            self.state_dict[name] = param

     def update_parameters(self, pid: int = 0) -> None:
         """Core method to update model parameters based on strategy."""
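The simplified _capture_model_state stores the tensors returned by state_dict() directly, without copying, which is what lets the update paths below write into the running model. A small check of that behavior, under the assumption that Paddle's state_dict() returns live parameter references:

import paddle

layer = paddle.nn.Linear(2, 2)
captured = dict(layer.state_dict())
# Zeroing the captured tensor in place is visible through the layer itself,
# confirming these are references rather than copies.
with paddle.no_grad():
    captured["weight"].set_value(paddle.zeros_like(captured["weight"]))
print(float(layer.weight.abs().sum()))  # 0.0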
@@ -79,8 +68,6 @@ class DynamicWeightManager:
         strategy_handlers = {
             "ipc_snapshot": self._update_ipc_snapshot,
             "ipc": self._update_ipc,
-            "ipc_no_reshard": self._update_ipc_no_reshard,
-            "normal": self.load_model,
         }

         if handler := strategy_handlers.get(self.load_config.load_strategy):
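With "ipc_no_reshard" and "normal" removed, only the two IPC strategies remain in the dispatch table. The walrus-operator lookup used here calls the handler only when the strategy is known; a self-contained sketch of the pattern (the else branch is an assumption, since the hunk cuts off after the if):

class DispatchSketch:
    def _update_ipc_snapshot(self):
        print("snapshot path")

    def _update_ipc(self):
        print("ipc path")

    def update_parameters(self, load_strategy: str) -> None:
        strategy_handlers = {
            "ipc_snapshot": self._update_ipc_snapshot,
            "ipc": self._update_ipc,
        }
        # Look up and call in one step; unknown strategies fail loudly.
        if handler := strategy_handlers.get(load_strategy):
            handler()
        else:
            raise ValueError(f"Unsupported load strategy: {load_strategy}")

DispatchSketch().update_parameters("ipc")  # prints "ipc path"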
@@ -106,13 +93,7 @@ class DynamicWeightManager:
             fallback_path = f"/shared_ipc_meta/model_state.tp0{self.meta_src_id}.pdparams"
             ipc_state_dict = paddle.load(fallback_path)

-        try:
-            self._update_model_from_state(ipc_state_dict, "snapshot")
-        except Exception:
-            self.models[0].set_state_dict(ipc_state_dict)
-            logger.warning(
-                "load model from no_reshard weight, maybe need more GPU memory"
-            )
+        self._update_model_from_state(ipc_state_dict, "snapshot")
         logger.info(
             f"IPC snapshot update parameters completed from {model_path}")

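The deleted try/except fell back to set_state_dict with the unresharded weights, trading GPU memory for robustness; the new code trusts _update_model_from_state to succeed. For reference, the basic load-then-apply step looks like this in Paddle (the helper name and path are hypothetical):

import paddle

def apply_snapshot(model: paddle.nn.Layer, snapshot_path: str) -> None:
    # Load a serialized state dict and push it into the live model.
    ipc_state_dict = paddle.load(snapshot_path)
    model.set_state_dict(ipc_state_dict)

layer = paddle.nn.Linear(4, 4)
paddle.save(layer.state_dict(), "/tmp/snapshot_demo.pdparams")
apply_snapshot(layer, "/tmp/snapshot_demo.pdparams")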
@@ -124,34 +105,12 @@ class DynamicWeightManager:
         logger.info(
             f"IPC update parameters completed from file: {self.ipc_path}")

-    def _update_ipc_no_reshard(self):
-        """Update using no-reshard IPC strategy (faster but uses more memory)."""
-        ipc_meta = paddle.load(self.ipc_path)
-        state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
-        self.models[0].set_state_dict(state_dict)
-        logger.info(
-            f"IPC no-reshard update parameters completed from file: {self.ipc_path}"
-        )
-
-    def load_model(self) -> nn.Layer:
-        """Standard model loading without IPC."""
-        architectures = self.fd_config.model_config.architectures[0]
-        model_class = MODEL_CLASSES[architectures]
-        state_dict = load_composite_checkpoint(
-            self.fd_config.parallel_config.model_name_or_path,
-            model_class,
-            self.fd_config.model_config,
-            return_numpy=True)
-        self.models[0].set_state_dict(state_dict)
-        logger.info("normal load update parameters completed")
-
     def clear_parameters(self, pid: int = 0) -> None:
         """Clear all model parameters and free memory."""
         logger.info("start clear paramaters")
         paddle.device.cuda.empty_cache()
-        for model in self.models:
-            for param in model.state_dict().values():
-                param._clear_data()
+        for param in self.model.state_dict().values():
+            param._clear_data()

         self._verify_parameters("clearance")
         if self.nranks > 1:
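The module keeps its multiprocessing.shared_memory import, the stdlib mechanism behind the surviving IPC strategies: one process publishes a weight buffer in a named shared segment and peers map it without copies. A minimal single-process illustration of that mechanism (the segment name and shape are made up):

import numpy as np
from multiprocessing.shared_memory import SharedMemory

weights = np.arange(8, dtype=np.float32)

# Producer side: create a named segment and write the buffer into it.
shm = SharedMemory(create=True, size=weights.nbytes, name="ipc_meta_demo")
np.ndarray(weights.shape, dtype=weights.dtype, buffer=shm.buf)[:] = weights

# Consumer side: attach by name and view the same memory, zero-copy.
peer = SharedMemory(name="ipc_meta_demo")
view = np.ndarray(weights.shape, dtype=np.float32, buffer=peer.buf)
assert float(view.sum()) == float(weights.sum())

peer.close()
shm.close()
shm.unlink()  # release the segment once all readers are done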