Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 08:16:42 +08:00)
[Executor] Adjust signal sending order in RL training (#3773) (#4066) (#4178)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* Adjust processing order
* fix bug
* fix update_parameters bug
* refine code
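In effect, the commit reorders the RL weight-update and weight-clear paths so that CUDA graph work happens before any signal is sent: graphs are recaptured before the trainer is told the update finished, and torn down before the old weights are freed. A condensed paraphrase of the new flow (the `runner` handle is illustrative; the real methods live on GPUModelRunner, as the last diff hunk shows):

import paddle

def rl_update(runner, pid):
    runner.dynamic_weight_manager.update_parameters(pid)  # swap in the new weights
    runner.initialize_kv_cache()                          # rebuild the KV cache
    if runner.use_cudagraph:
        runner.capture_model()                            # recapture CUDA graphs first...
    runner.dynamic_weight_manager.finalize_update(pid)    # ...then send the "done" signal

def rl_clear(runner, pid):
    if runner.use_cudagraph:
        runner.model.clear_grpah_opt_backend()            # drop graphs before freeing weights
    runner.dynamic_weight_manager.clear_parameters(pid)   # free weights and signal the trainer
    runner.clear_cache()
    paddle.device.cuda.empty_cache()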
@@ -18,6 +18,7 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Callable, Dict, Optional
 
+import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs
 
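The hunk above swaps repeated function-local `from paddle.jit.dy2static.utils import CUDAGraphState` imports for a single module-level alias; call sites then read the enum off the module object at use time:

import paddle.jit.dy2static.utils as jit_utils

state = jit_utils.CUDAGraphState.DISABLE  # attribute lookup happens at use time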
@@ -51,27 +52,24 @@ class ConcreteSizeEntry:
 
 class Dy2StCudaGraphManager:
     def __init__(self):
-        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
-        from paddle.jit.dy2static.utils import CUDAGraphState
 
-        self.state = CUDAGraphState.DISABLE
+        self.state = jit_utils.CUDAGraphState.DISABLE
         self.captured_batch_size = set()
         self.batch_size = -1
 
     def run_impl(self, original_run_impl, inputs, parameters, attrs):
-        from paddle.jit.dy2static.utils import CUDAGraphState
 
         run_state = self.state
         prog_attrs, cuda_graph_attrs = attrs
-        if run_state == CUDAGraphState.REPLAY:
+        if run_state == jit_utils.CUDAGraphState.REPLAY:
             if self.batch_size not in self.captured_batch_size:
-                run_state = CUDAGraphState.DISABLE
-        elif run_state == CUDAGraphState.CAPTURE:
+                run_state = jit_utils.CUDAGraphState.DISABLE
+        elif run_state == jit_utils.CUDAGraphState.CAPTURE:
             self.captured_batch_size.add(self.batch_size)
 
         cuda_graph_attrs |= {
             "cuda_graph_state": run_state,
-            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
+            "cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0,
         }
         return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
 
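For context, `run_impl` is a small three-state dispatcher: a REPLAY request falls back to DISABLE when the current batch size was never captured, and a CAPTURE request records the batch size so later replays are legal. A minimal standalone sketch of that rule (with a stand-in enum, since only the control flow is being illustrated):

from enum import Enum

class GraphState(Enum):  # stand-in for jit_utils.CUDAGraphState
    DISABLE = 0
    CAPTURE = 1
    REPLAY = 2

def resolve_state(requested: GraphState, batch_size: int, captured: set) -> GraphState:
    if requested is GraphState.REPLAY and batch_size not in captured:
        return GraphState.DISABLE  # this shape was never captured: run eagerly
    if requested is GraphState.CAPTURE:
        captured.add(batch_size)   # remember the shape so REPLAY is legal later
    return requested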
@@ -104,7 +102,6 @@ class CudaGraphPiecewiseBackend:
         self.cuda_graph_manager = Dy2StCudaGraphManager()
 
     def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
-        from paddle.jit.dy2static.utils import CUDAGraphState
 
         if not entry.captured:
             # Warmup the model
@@ -121,14 +118,14 @@ class CudaGraphPiecewiseBackend:
             entry.input_addresses = input_addresses
 
             # Capture
-            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
+            self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
             self.cuda_graph_manager.batch_size = entry.real_shape
             entry.captured = True
             with self.cuda_graph_manager.run_impl_guard():
                 entry.runnable(**kwargs)
 
         # Replay
-        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
+        self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
         self.cuda_graph_manager.batch_size = entry.real_shape
         with self.cuda_graph_manager.run_impl_guard():
             return entry.runnable(**kwargs)
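The first `run_static_model` call for a given real shape warms up and captures; subsequent calls replay the recorded graph. A hypothetical call sequence (the keyword arguments are illustrative and depend on the wrapped model):

backend.run_static_model(entry, input_ids=input_ids)  # warmup, then CAPTURE
backend.run_static_model(entry, input_ids=input_ids)  # same shape again: REPLAY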
@@ -45,6 +45,7 @@ class DynamicWeightManager:
         self.model: nn.Layer = model
         self._capture_model_state()
         self.update_parameters()
+        self.finalize_update()
 
         logger.info(
             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
@@ -81,8 +82,6 @@ class DynamicWeightManager:
 
         logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s")
 
-        self._finalize_update(pid)
-
     def _update_ipc_snapshot(self):
         """Update using IPC snapshot strategy for elastic recovery."""
         model_path = os.path.join(
@@ -146,7 +145,7 @@ class DynamicWeightManager:
         if src.shape != dst.shape:
             raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}")
 
-    def _finalize_update(self, pid: int):
+    def finalize_update(self, pid: int = 0):
         """Finalize update process with verification."""
         self._verify_parameters("update")
         if self.parallel_config.tensor_parallel_size > 1:
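`_finalize_update` becomes the public `finalize_update` with a default `pid=0`: the DynamicWeightManager constructor can call it bare (first hunk in this file), while GPUModelRunner now calls it explicitly after recapturing CUDA graphs, making the completion signal the last step. From a caller's perspective (hypothetical `manager` handle):

manager.update_parameters(pid)  # swaps weights; no longer sends the signal itself
# ... caller rebuilds the KV cache and recaptures CUDA graphs here ...
manager.finalize_update(pid)    # verify parameters, then signal completion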
@@ -1705,25 +1705,27 @@ class GPUModelRunner(ModelRunnerBase):
         paddle.device.cuda.empty_cache()
 
     def clear_parameters(self, pid):
-        """ " Dynamic model loader use to clear parameters use for RL"""
+        """Dynamic model loader use to clear parameters use for RL"""
+        # Clear CUDAGraph
+        if self.use_cudagraph:
+            self.model.clear_grpah_opt_backend()
+        # Clear parameters and send signal
         self.dynamic_weight_manager.clear_parameters(pid)
         self.clear_cache()
         paddle.device.cuda.empty_cache()
-
-        # Clear CudaGraph
-        if self.use_cudagraph:
-            self.model.clear_grpah_opt_backend()
-
         self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")
 
     def update_parameters(self, pid):
-        """ " Dynamic model loader use to update parameters use for RL"""
+        """Dynamic model loader use to update parameters use for RL"""
+        # Update parameters
        self.dynamic_weight_manager.update_parameters(pid)
         self.initialize_kv_cache()
 
-        # Recapture CudaGraph
+        # Recapture CUDAGraph
         if self.use_cudagraph:
             self.capture_model()
+        # Send signal
+        self.dynamic_weight_manager.finalize_update(pid)
 
         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
 