mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-28 21:32:29 +08:00
[Executor] Adjust signal sending order in RL training (#3773) (#4066) (#4178)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* Adjust processing order
* fix bug
* fix update_parameters bug
* refine code
This commit is contained in:
@@ -45,6 +45,7 @@ class DynamicWeightManager:
|
||||
self.model: nn.Layer = model
|
||||
self._capture_model_state()
|
||||
self.update_parameters()
|
||||
self.finalize_update()
|
||||
|
||||
logger.info(
|
||||
f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
|
||||
@@ -81,8 +82,6 @@ class DynamicWeightManager:
|
||||
|
||||
logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s")
|
||||
|
||||
self._finalize_update(pid)
|
||||
|
||||
def _update_ipc_snapshot(self):
|
||||
"""Update using IPC snapshot strategy for elastic recovery."""
|
||||
model_path = os.path.join(
|
||||
@@ -146,7 +145,7 @@ class DynamicWeightManager:
|
||||
if src.shape != dst.shape:
|
||||
raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}")
|
||||
|
||||
def _finalize_update(self, pid: int):
|
||||
def finalize_update(self, pid: int = 0):
|
||||
"""Finalize update process with verification."""
|
||||
self._verify_parameters("update")
|
||||
if self.parallel_config.tensor_parallel_size > 1:
|
||||
|
Reference in New Issue
Block a user