Update dynamic_weight_manager.py

This commit is contained in:
Jiang-Jia-Jun
2025-07-03 15:55:22 +08:00
committed by GitHub
parent 05c670e593
commit 9fd74f75bd

View File

@@ -51,8 +51,7 @@ class DynamicWeightManager:
logger.info(
f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
f" rank={self.rank}, ranks={self.nranks}, "
f" load ipc weight from {self.ipc_path}.")
f" rank={self.rank}, ranks={self.nranks}")
@paddle.no_grad()
def _capture_model_state(self):
@@ -114,21 +113,25 @@ class DynamicWeightManager:
logger.warning(
"load model from no_reshard weight, maybe need more GPU memory"
)
logger.info("IPC snapshot update parameters completed")
logger.info(
f"IPC snapshot update parameters completed from {model_path}")
def _update_ipc(self):
    """Update using standard IPC strategy (requires Training Worker).

    Loads the IPC metadata stored at ``self.ipc_path`` with ``paddle.load``,
    converts it through ``self._convert_ipc_meta_to_tensor`` and passes the
    resulting state dict to ``self._update_model_from_state`` with the
    source type ``"raw"``.
    """
    ipc_meta = paddle.load(self.ipc_path)
    state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
    self._update_model_from_state(state_dict, "raw")
    # One completion record only; it names the file the weights came from
    # so the update can be traced back from the log.
    logger.info(
        f"IPC update parameters completed from file: {self.ipc_path}")
def _update_ipc_no_reshard(self):
    """Update using no-reshard IPC strategy (faster but uses more memory).

    Loads the IPC metadata stored at ``self.ipc_path`` with ``paddle.load``,
    converts it through ``self._convert_ipc_meta_to_tensor`` and applies the
    resulting state dict directly to ``self.models[0]`` via
    ``set_state_dict``, without going through
    ``self._update_model_from_state``.
    """
    ipc_meta = paddle.load(self.ipc_path)
    state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
    self.models[0].set_state_dict(state_dict)
    # One completion record only; it names the file the weights came from
    # so the update can be traced back from the log.
    logger.info(
        f"IPC no-reshard update parameters completed from file: {self.ipc_path}"
    )
def load_model(self) -> nn.Layer:
"""Standard model loading without IPC."""
@@ -159,6 +162,8 @@ class DynamicWeightManager:
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor],
src_type: str):
"""Update model parameters from given state dictionary."""
if len(state_dict) == 0:
raise ValueError(f"No parameter found in state dict {state_dict}")
update_count = 0
for name, new_param in state_dict.items():
if name not in self.state_dict: