mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Update dynamic_weight_manager.py
This commit is contained in:
@@ -51,8 +51,7 @@ class DynamicWeightManager:
|
|||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
|
f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
|
||||||
f" rank={self.rank}, ranks={self.nranks}, "
|
f" rank={self.rank}, ranks={self.nranks}")
|
||||||
f" load ipc weight from {self.ipc_path}.")
|
|
||||||
|
|
||||||
@paddle.no_grad()
|
@paddle.no_grad()
|
||||||
def _capture_model_state(self):
|
def _capture_model_state(self):
|
||||||
@@ -114,21 +113,25 @@ class DynamicWeightManager:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"load model from no_reshard weight, maybe need more GPU memory"
|
"load model from no_reshard weight, maybe need more GPU memory"
|
||||||
)
|
)
|
||||||
logger.info("IPC snapshot update parameters completed")
|
logger.info(
|
||||||
|
f"IPC snapshot update parameters completed from {model_path}")
|
||||||
|
|
||||||
def _update_ipc(self):
|
def _update_ipc(self):
|
||||||
"""Update using standard IPC strategy (requires Training Worker)."""
|
"""Update using standard IPC strategy (requires Training Worker)."""
|
||||||
ipc_meta = paddle.load(self.ipc_path)
|
ipc_meta = paddle.load(self.ipc_path)
|
||||||
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
|
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
|
||||||
self._update_model_from_state(state_dict, "raw")
|
self._update_model_from_state(state_dict, "raw")
|
||||||
logger.info("IPC update parameters completed")
|
logger.info(
|
||||||
|
f"IPC update parameters completed from file: {self.ipc_path}")
|
||||||
|
|
||||||
def _update_ipc_no_reshard(self):
|
def _update_ipc_no_reshard(self):
|
||||||
"""Update using no-reshard IPC strategy (faster but uses more memory)."""
|
"""Update using no-reshard IPC strategy (faster but uses more memory)."""
|
||||||
ipc_meta = paddle.load(self.ipc_path)
|
ipc_meta = paddle.load(self.ipc_path)
|
||||||
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
|
state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
|
||||||
self.models[0].set_state_dict(state_dict)
|
self.models[0].set_state_dict(state_dict)
|
||||||
logger.info("IPC no-reshard update parameters completed")
|
logger.info(
|
||||||
|
f"IPC no-reshard update parameters completed from file: {self.ipc_path}"
|
||||||
|
)
|
||||||
|
|
||||||
def load_model(self) -> nn.Layer:
|
def load_model(self) -> nn.Layer:
|
||||||
"""Standard model loading without IPC."""
|
"""Standard model loading without IPC."""
|
||||||
@@ -159,6 +162,8 @@ class DynamicWeightManager:
|
|||||||
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor],
|
def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor],
|
||||||
src_type: str):
|
src_type: str):
|
||||||
"""Update model parameters from given state dictionary."""
|
"""Update model parameters from given state dictionary."""
|
||||||
|
if len(state_dict) == 0:
|
||||||
|
raise ValueError(f"No parameter found in state dict {state_dict}")
|
||||||
update_count = 0
|
update_count = 0
|
||||||
for name, new_param in state_dict.items():
|
for name, new_param in state_dict.items():
|
||||||
if name not in self.state_dict:
|
if name not in self.state_dict:
|
||||||
|
Reference in New Issue
Block a user