mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Update dynamic_weight_manager.py
This commit is contained in:
@@ -51,8 +51,7 @@ class DynamicWeightManager:
|
||||
|
||||
logger.info(
|
||||
f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
|
||||
f" rank={self.rank}, ranks={self.nranks}, "
|
||||
f" load ipc weight from {self.ipc_path}.")
|
||||
f" rank={self.rank}, ranks={self.nranks}")
|
||||
|
||||
@paddle.no_grad()
|
||||
def _capture_model_state(self):
|
||||
@@ -114,21 +113,25 @@ class DynamicWeightManager:
|
||||
logger.warning(
|
||||
"load model from no_reshard weight, maybe need more GPU memory"
|
||||
)
|
||||
logger.info("IPC snapshot update parameters completed")
|
||||
logger.info(
|
||||
f"IPC snapshot update parameters completed from {model_path}")
|
||||
|
||||
def _update_ipc(self):
|
||||
"""Update using standard IPC strategy (requires Training Worker)."""
|
||||
ipc_meta = paddle.load(self.ipc_path)
|
||||
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
|
||||
self._update_model_from_state(state_dict, "raw")
|
||||
logger.info("IPC update parameters completed")
|
||||
logger.info(
|
||||
f"IPC update parameters completed from file: {self.ipc_path}")
|
||||
|
||||
def _update_ipc_no_reshard(self):
|
||||
"""Update using no-reshard IPC strategy (faster but uses more memory)."""
|
||||
ipc_meta = paddle.load(self.ipc_path)
|
||||
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
|
||||
self.models[0].set_state_dict(state_dict)
|
||||
logger.info("IPC no-reshard update parameters completed")
|
||||
logger.info(
|
||||
f"IPC no-reshard update parameters completed from file: {self.ipc_path}"
|
||||
)
|
||||
|
||||
def load_model(self) -> nn.Layer:
|
||||
"""Standard model loading without IPC."""
|
||||
@@ -159,6 +162,8 @@ class DynamicWeightManager:
|
||||
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor],
|
||||
src_type: str):
|
||||
"""Update model parameters from given state dictionary."""
|
||||
if len(state_dict) == 0:
|
||||
raise ValueError(f"No parameter found in state dict {state_dict}")
|
||||
update_count = 0
|
||||
for name, new_param in state_dict.items():
|
||||
if name not in self.state_dict:
|
||||
|
Reference in New Issue
Block a user