[Executor] Adjust signal sending order in RL training (#3773)

* Adjust processing order

* fix bug

* fix update_parameters bug

* refine code
This commit is contained in:
RAM
2025-09-10 13:24:20 +08:00
committed by GitHub
parent 453487d5b0
commit d3e4ae3d49
3 changed files with 20 additions and 22 deletions

View File

@@ -18,6 +18,7 @@ from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, Optional
import paddle.jit.dy2static.utils as jit_utils
import paddle.nn.layer
from paddle.device.cuda import graphs
@@ -51,27 +52,24 @@ class ConcreteSizeEntry:
class Dy2StCudaGraphManager:
def __init__(self):
# NOTE(gongshaotian): Use local import to avoid RLHF version problems
from paddle.jit.dy2static.utils import CUDAGraphState
self.state = CUDAGraphState.DISABLE
self.state = jit_utils.CUDAGraphState.DISABLE
self.captured_batch_size = set()
self.batch_size = -1
def run_impl(self, original_run_impl, inputs, parameters, attrs):
from paddle.jit.dy2static.utils import CUDAGraphState
run_state = self.state
prog_attrs, cuda_graph_attrs = attrs
if run_state == CUDAGraphState.REPLAY:
if run_state == jit_utils.CUDAGraphState.REPLAY:
if self.batch_size not in self.captured_batch_size:
run_state = CUDAGraphState.DISABLE
elif run_state == CUDAGraphState.CAPTURE:
run_state = jit_utils.CUDAGraphState.DISABLE
elif run_state == jit_utils.CUDAGraphState.CAPTURE:
self.captured_batch_size.add(self.batch_size)
cuda_graph_attrs |= {
"cuda_graph_state": run_state,
"cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
"cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0,
}
return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
@@ -104,7 +102,6 @@ class CudaGraphPiecewiseBackend:
self.cuda_graph_manager = Dy2StCudaGraphManager()
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
from paddle.jit.dy2static.utils import CUDAGraphState
if not entry.captured:
# Warmup the model
@@ -121,14 +118,14 @@ class CudaGraphPiecewiseBackend:
entry.input_addresses = input_addresses
# Capture
self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
self.cuda_graph_manager.batch_size = entry.real_shape
entry.captured = True
with self.cuda_graph_manager.run_impl_guard():
entry.runnable(**kwargs)
# Replay
self.cuda_graph_manager.state = CUDAGraphState.REPLAY
self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
self.cuda_graph_manager.batch_size = entry.real_shape
with self.cuda_graph_manager.run_impl_guard():
return entry.runnable(**kwargs)