[Executor] Adjust signal sending order in RL training (#3773)

* Adjust processing order

* fix bug

* fix update_parameters bug

* refine code
RAM
2025-09-10 13:24:20 +08:00
committed by GitHub
parent 453487d5b0
commit d3e4ae3d49
3 changed files with 20 additions and 22 deletions

View File

@@ -18,6 +18,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Callable, Dict, Optional
+import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs
@@ -51,27 +52,24 @@ class ConcreteSizeEntry:
 class Dy2StCudaGraphManager:
     def __init__(self):
-        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
-        from paddle.jit.dy2static.utils import CUDAGraphState
-        self.state = CUDAGraphState.DISABLE
+        self.state = jit_utils.CUDAGraphState.DISABLE
         self.captured_batch_size = set()
         self.batch_size = -1

     def run_impl(self, original_run_impl, inputs, parameters, attrs):
-        from paddle.jit.dy2static.utils import CUDAGraphState
         run_state = self.state
         prog_attrs, cuda_graph_attrs = attrs
-        if run_state == CUDAGraphState.REPLAY:
+        if run_state == jit_utils.CUDAGraphState.REPLAY:
             if self.batch_size not in self.captured_batch_size:
-                run_state = CUDAGraphState.DISABLE
-        elif run_state == CUDAGraphState.CAPTURE:
+                run_state = jit_utils.CUDAGraphState.DISABLE
+        elif run_state == jit_utils.CUDAGraphState.CAPTURE:
             self.captured_batch_size.add(self.batch_size)

         cuda_graph_attrs |= {
             "cuda_graph_state": run_state,
-            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
+            "cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0,
         }
         return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
@@ -104,7 +102,6 @@ class CudaGraphPiecewiseBackend:
         self.cuda_graph_manager = Dy2StCudaGraphManager()

     def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
-        from paddle.jit.dy2static.utils import CUDAGraphState
         if not entry.captured:
             # Warmup the model
@@ -121,14 +118,14 @@ class CudaGraphPiecewiseBackend:
             entry.input_addresses = input_addresses

             # Capture
-            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
+            self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
             self.cuda_graph_manager.batch_size = entry.real_shape
             entry.captured = True
             with self.cuda_graph_manager.run_impl_guard():
                 entry.runnable(**kwargs)

         # Replay
-        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
+        self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
         self.cuda_graph_manager.batch_size = entry.real_shape
         with self.cuda_graph_manager.run_impl_guard():
             return entry.runnable(**kwargs)
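The three hunks above replace the function-local `from paddle.jit.dy2static.utils import CUDAGraphState` imports (originally a workaround for RLHF version mismatches) with a single module-level `import paddle.jit.dy2static.utils as jit_utils`, so every state reference goes through `jit_utils.CUDAGraphState`. The capture/replay bookkeeping itself is unchanged. A minimal, self-contained sketch of that bookkeeping follows; the `CUDAGraphState` enum and `ToyGraphManager` below are stand-ins for illustration only, not the real paddle API.

# Standalone sketch of the dispatch logic shown above (illustrative stand-ins, not paddle).
from enum import Enum

class CUDAGraphState(Enum):  # stand-in for paddle.jit.dy2static.utils.CUDAGraphState
    DISABLE = 0
    CAPTURE = 1
    REPLAY = 2

class ToyGraphManager:
    def __init__(self):
        self.state = CUDAGraphState.DISABLE
        self.captured_batch_size = set()
        self.batch_size = -1

    def dispatch(self):
        """Mirror run_impl: fall back to DISABLE when replaying an uncaptured size."""
        run_state = self.state
        if run_state == CUDAGraphState.REPLAY:
            if self.batch_size not in self.captured_batch_size:
                run_state = CUDAGraphState.DISABLE
        elif run_state == CUDAGraphState.CAPTURE:
            self.captured_batch_size.add(self.batch_size)
        # The dispatch key is the batch size only when a graph is actually in play.
        key = self.batch_size if run_state != CUDAGraphState.DISABLE else 0
        return run_state, key

mgr = ToyGraphManager()
mgr.state, mgr.batch_size = CUDAGraphState.CAPTURE, 8
print(mgr.dispatch())  # CAPTURE with key 8: size 8 is recorded as captured
mgr.state = CUDAGraphState.REPLAY
print(mgr.dispatch())  # REPLAY with key 8: replays the captured graph
mgr.batch_size = 16
print(mgr.dispatch())  # DISABLE with key 0: size 16 was never captured, so fall back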

View File

@@ -44,6 +44,7 @@ class DynamicWeightManager:
         self.model: nn.Layer = model
         self._capture_model_state()
         self.update_parameters()
+        self.finalize_update()

         logger.info(
             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
@@ -79,8 +80,6 @@ class DynamicWeightManager:
         logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s")

-        self._finalize_update(pid)
-
     def _update_ipc_snapshot(self):
         """Update using IPC snapshot strategy for elastic recovery."""
         model_path = os.path.join(
@@ -143,7 +142,7 @@ class DynamicWeightManager:
         if src.shape != dst.shape:
             raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}")

-    def _finalize_update(self, pid: int):
+    def finalize_update(self, pid: int = 0):
         """Finalize update process with verification."""
         self._verify_parameters("update")
         if self.parallel_config.tensor_parallel_size > 1:
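In `DynamicWeightManager`, the private `_finalize_update` becomes the public `finalize_update(pid: int = 0)` and is no longer called automatically at the end of `update_parameters`; the constructor keeps the old behavior by calling it right after `update_parameters()`, while external callers now decide when the completion signal is sent. Below is a sketch of the resulting caller-side contract; `WeightManager` and `rebuild_runtime_state` are hypothetical placeholders for illustration only.

# Sketch of the new caller-side ordering implied by this change (hypothetical names).
class WeightManager:
    def update_parameters(self, pid: int) -> None:
        print(f"[{pid}] load new weights")

    def finalize_update(self, pid: int = 0) -> None:
        # Verification and the readiness signal happen here, not inside update_parameters.
        print(f"[{pid}] verify parameters and signal ready")

def rebuild_runtime_state() -> None:
    print("rebuild KV cache / recapture graphs")

manager = WeightManager()
manager.update_parameters(pid=0)
rebuild_runtime_state()         # expensive work that must finish first
manager.finalize_update(pid=0)  # only now announce that the update is complete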

View File

@@ -1885,25 +1885,27 @@ class GPUModelRunner(ModelRunnerBase):
         self.forward_meta.clear_caches()

     def clear_parameters(self, pid):
-        """ " Dynamic model loader use to clear parameters use for RL"""
+        """Dynamic model loader use to clear parameters use for RL"""
+        # Clear CUDAGraph
+        if self.use_cudagraph:
+            self.model.clear_grpah_opt_backend()
+        # Clear parameters and send signal
         self.dynamic_weight_manager.clear_parameters(pid)
         self.clear_cache()
         paddle.device.cuda.empty_cache()
-        # Clear CudaGraph
-        if self.use_cudagraph:
-            self.model.clear_grpah_opt_backend()
         self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")

     def update_parameters(self, pid):
-        """ " Dynamic model loader use to update parameters use for RL"""
+        """Dynamic model loader use to update parameters use for RL"""
+        # Update parameters
         self.dynamic_weight_manager.update_parameters(pid)
         self.initialize_kv_cache()
-        # Recapture CudaGraph
+        # Recapture CUDAGraph
         if self.use_cudagraph:
             self.capture_model()
+        # Send signal
+        self.dynamic_weight_manager.finalize_update(pid)
         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
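On the runner side, CUDAGraph teardown now happens before the weights are cleared, and the readiness signal (`finalize_update`) is sent only after the KV cache is rebuilt and CUDAGraph is recaptured, so the trainer is not notified until the runner can actually serve the new weights. Below is a condensed control-flow sketch, assuming a `runner` object exposing exactly the attributes named in the diff above; it is a sketch of the ordering, not the actual FastDeploy implementation.

# Condensed view of the reordered RL weight-update cycle (sketch under the assumptions above).
def rl_weight_update_cycle(runner, pid: int) -> None:
    # --- clear phase: tear down graphs first, then drop weights and caches ---
    if runner.use_cudagraph:
        runner.model.clear_grpah_opt_backend()   # CUDAGraph released before weights
    runner.dynamic_weight_manager.clear_parameters(pid)
    runner.clear_cache()

    # --- update phase: load weights, rebuild caches, recapture graphs ---
    runner.dynamic_weight_manager.update_parameters(pid)
    runner.initialize_kv_cache()
    if runner.use_cudagraph:
        runner.capture_model()                   # recapture against the new weights

    # --- only now signal that the runner is ready to serve ---
    runner.dynamic_weight_manager.finalize_update(pid)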