[Executor] Adjust signal sending order in RL training (#3773) (#4066)

* Adjust processing order

* fix bug

* fix update_parameters bug

* refine code
RAM, 2025-09-11 15:41:32 +08:00, committed by GitHub
parent 48f2ab3fb3
commit 63d24b2210
3 changed files with 20 additions and 22 deletions
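
At a high level, the commit reorders the RL weight-update and clear paths in the GPU model runner: CUDA graphs are now cleared before parameters are released, and the weight manager's completion signal (`finalize_update`) is sent only after the KV cache has been rebuilt and the CUDA graphs recaptured, instead of being fired from inside `update_parameters`. Below is a minimal, self-contained sketch of that ordering; the method names mirror the diffs that follow, but everything is reduced to print-statement stubs and is not the actual FastDeploy implementation.

# Stand-in sketch of the reordered flow; stub classes only, not the FastDeploy code.


class WeightManagerStub:
    def update_parameters(self, pid: int = 0) -> None:
        print(f"load new weights (pid={pid})")

    def clear_parameters(self, pid: int = 0) -> None:
        print(f"release weights (pid={pid})")

    def finalize_update(self, pid: int = 0) -> None:
        print(f"verify parameters and send the update-finished signal (pid={pid})")


class RunnerStub:
    def __init__(self) -> None:
        self.use_cudagraph = True
        self.dynamic_weight_manager = WeightManagerStub()

    def update_parameters(self, pid: int) -> None:
        self.dynamic_weight_manager.update_parameters(pid)  # 1. swap in the new weights
        print("rebuild KV cache")                           # 2. initialize_kv_cache()
        if self.use_cudagraph:
            print("recapture CUDA graphs")                  # 3. capture_model()
        self.dynamic_weight_manager.finalize_update(pid)    # 4. signal last (the new order)

    def clear_parameters(self, pid: int) -> None:
        if self.use_cudagraph:
            print("clear CUDA graph backend")               # 1. graphs first (the new order)
        self.dynamic_weight_manager.clear_parameters(pid)   # 2. then weights ...
        print("clear KV cache and empty the CUDA cache")    # 3. ... and caches


if __name__ == "__main__":
    runner = RunnerStub()
    runner.update_parameters(pid=0)
    runner.clear_parameters(pid=0)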

View File

@@ -18,6 +18,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import Callable, Dict, List, Optional
+import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs
@@ -52,27 +53,24 @@ class ConcreteSizeEntry:
 class Dy2StCudaGraphManager:
     def __init__(self):
-        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
-        from paddle.jit.dy2static.utils import CUDAGraphState
-        self.state = CUDAGraphState.DISABLE
+        self.state = jit_utils.CUDAGraphState.DISABLE
         self.captured_batch_size = set()
         self.batch_size = -1

     def run_impl(self, original_run_impl, inputs, parameters, attrs):
-        from paddle.jit.dy2static.utils import CUDAGraphState
         run_state = self.state
         prog_attrs, cuda_graph_attrs = attrs
-        if run_state == CUDAGraphState.REPLAY:
+        if run_state == jit_utils.CUDAGraphState.REPLAY:
             if self.batch_size not in self.captured_batch_size:
-                run_state = CUDAGraphState.DISABLE
-        elif run_state == CUDAGraphState.CAPTURE:
+                run_state = jit_utils.CUDAGraphState.DISABLE
+        elif run_state == jit_utils.CUDAGraphState.CAPTURE:
             self.captured_batch_size.add(self.batch_size)

         cuda_graph_attrs |= {
             "cuda_graph_state": run_state,
-            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
+            "cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0,
         }
         return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
@@ -105,7 +103,6 @@ class CudaGraphPiecewiseBackend:
         self.cuda_graph_manager = Dy2StCudaGraphManager()

     def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
-        from paddle.jit.dy2static.utils import CUDAGraphState

         if not entry.captured:
             # Warmup the model
@@ -122,14 +119,14 @@ class CudaGraphPiecewiseBackend:
             entry.input_addresses = input_addresses

             # Capture
-            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
+            self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
             self.cuda_graph_manager.batch_size = entry.real_shape
             entry.captured = True
             with self.cuda_graph_manager.run_impl_guard():
                 entry.runnable(**kwargs)

         # Replay
-        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
+        self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
         self.cuda_graph_manager.batch_size = entry.real_shape
         with self.cuda_graph_manager.run_impl_guard():
             return entry.runnable(**kwargs)
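
This first file drops the per-method `from paddle.jit.dy2static.utils import CUDAGraphState` imports (and the NOTE about RLHF version problems attached to them) in favour of a single module-level alias, so every use becomes an attribute lookup like `jit_utils.CUDAGraphState.REPLAY` instead of a repeated local import. The toy snippet below only illustrates that attribute-lookup pattern with a stand-in module; it is not Paddle code, and the diff itself does not explain why the earlier NOTE no longer applies.

import types

# Stand-in for paddle.jit.dy2static.utils: the enum is attached after the alias
# already exists, mimicking an attribute whose availability can vary by version.
jit_utils = types.ModuleType("jit_utils_standin")


def current_state_is_replay(manager_state: int) -> bool:
    # Resolved at call time, exactly like `jit_utils.CUDAGraphState.REPLAY` in the diff.
    return manager_state == jit_utils.CUDAGraphState.REPLAY


class CUDAGraphState:  # toy stand-in for the real enum
    DISABLE, CAPTURE, REPLAY = 0, 1, 2


jit_utils.CUDAGraphState = CUDAGraphState
print(current_state_is_replay(2))  # True: the lookup succeeds despite the late attachment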

View File

@@ -44,6 +44,7 @@ class DynamicWeightManager:
         self.model: nn.Layer = model
         self._capture_model_state()
         self.update_parameters()
+        self.finalize_update()

         logger.info(
             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
@@ -79,8 +80,6 @@ class DynamicWeightManager:
logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s") logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s")
self._finalize_update(pid)
def _update_ipc_snapshot(self): def _update_ipc_snapshot(self):
"""Update using IPC snapshot strategy for elastic recovery.""" """Update using IPC snapshot strategy for elastic recovery."""
model_path = os.path.join( model_path = os.path.join(
@@ -144,7 +143,7 @@ class DynamicWeightManager:
             if src.shape != dst.shape:
                 raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}")

-    def _finalize_update(self, pid: int):
+    def finalize_update(self, pid: int = 0):
         """Finalize update process with verification."""
         self._verify_parameters("update")
         if self.parallel_config.tensor_parallel_size > 1:
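
In the weight manager, private `_finalize_update(self, pid: int)` becomes public `finalize_update(self, pid: int = 0)` and is no longer invoked from `update_parameters`; the default `pid` lets the constructor call it without arguments (as added above) while the model runner passes an explicit pid after recapturing CUDA graphs (next file). A small stub showing just those two call shapes, not the manager's real logic:

class ManagerStub:
    """Call-shape stub for DynamicWeightManager; bodies are placeholders."""

    def __init__(self) -> None:
        # Construction path: update then finalize, both with the default pid of 0.
        self.update_parameters()
        self.finalize_update()

    def update_parameters(self, pid: int = 0) -> None:
        print(f"update parameters (pid={pid})")

    def finalize_update(self, pid: int = 0) -> None:
        print(f"verify parameters and signal completion (pid={pid})")


manager = ManagerStub()
# Runner path: the signal is now a separate, explicit step driven by the caller.
manager.update_parameters(pid=3)
manager.finalize_update(pid=3)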

View File

@@ -1705,25 +1705,27 @@ class GPUModelRunner(ModelRunnerBase):
         self.forward_meta.clear_caches()

     def clear_parameters(self, pid):
-        """ " Dynamic model loader use to clear parameters use for RL"""
+        """Dynamic model loader use to clear parameters use for RL"""
+        # Clear CUDAGraph
+        if self.use_cudagraph:
+            self.model.clear_grpah_opt_backend()
+        # Clear parameters and Send single
         self.dynamic_weight_manager.clear_parameters(pid)
         self.clear_cache()
         paddle.device.cuda.empty_cache()
-        # Clear CudaGraph
-        if self.use_cudagraph:
-            self.model.clear_grpah_opt_backend()
         self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")

     def update_parameters(self, pid):
-        """ " Dynamic model loader use to update parameters use for RL"""
+        """Dynamic model loader use to update parameters use for RL"""
+        # Update parameters
         self.dynamic_weight_manager.update_parameters(pid)
         self.initialize_kv_cache()
-        # Recapture CudaGraph
+        # Recapture CUDAGraph
         if self.use_cudagraph:
             self.capture_model()
+        # Send single
+        self.dynamic_weight_manager.finalize_update(pid)
         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")