diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index 19709f5db..8bc73d701 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -18,6 +18,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field from typing import Callable, Dict, List, Optional +import paddle.jit.dy2static.utils as jit_utils import paddle.nn.layer from paddle.device.cuda import graphs @@ -52,27 +53,24 @@ class ConcreteSizeEntry: class Dy2StCudaGraphManager: def __init__(self): - # NOTE(gongshaotian): Use local import to avoid RLHF version problems - from paddle.jit.dy2static.utils import CUDAGraphState - self.state = CUDAGraphState.DISABLE + self.state = jit_utils.CUDAGraphState.DISABLE self.captured_batch_size = set() self.batch_size = -1 def run_impl(self, original_run_impl, inputs, parameters, attrs): - from paddle.jit.dy2static.utils import CUDAGraphState run_state = self.state prog_attrs, cuda_graph_attrs = attrs - if run_state == CUDAGraphState.REPLAY: + if run_state == jit_utils.CUDAGraphState.REPLAY: if self.batch_size not in self.captured_batch_size: - run_state = CUDAGraphState.DISABLE - elif run_state == CUDAGraphState.CAPTURE: + run_state = jit_utils.CUDAGraphState.DISABLE + elif run_state == jit_utils.CUDAGraphState.CAPTURE: self.captured_batch_size.add(self.batch_size) cuda_graph_attrs |= { "cuda_graph_state": run_state, - "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0, + "cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0, } return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs)) @@ -105,7 +103,6 @@ class CudaGraphPiecewiseBackend: self.cuda_graph_manager = Dy2StCudaGraphManager() def run_static_model(self, entry: ConcreteSizeEntry, **kwargs): - from paddle.jit.dy2static.utils import CUDAGraphState if not entry.captured: # Warmup the model @@ -122,14 +119,14 @@ class CudaGraphPiecewiseBackend: entry.input_addresses = input_addresses # Capture - self.cuda_graph_manager.state = CUDAGraphState.CAPTURE + self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE self.cuda_graph_manager.batch_size = entry.real_shape entry.captured = True with self.cuda_graph_manager.run_impl_guard(): entry.runnable(**kwargs) # Replay - self.cuda_graph_manager.state = CUDAGraphState.REPLAY + self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY self.cuda_graph_manager.batch_size = entry.real_shape with self.cuda_graph_manager.run_impl_guard(): return entry.runnable(**kwargs) diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py index 32459d0a4..ce50c1b4d 100644 --- a/fastdeploy/rl/dynamic_weight_manager.py +++ b/fastdeploy/rl/dynamic_weight_manager.py @@ -44,6 +44,7 @@ class DynamicWeightManager: self.model: nn.Layer = model self._capture_model_state() self.update_parameters() + self.finalize_update() logger.info( f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, " @@ -79,8 +80,6 @@ class DynamicWeightManager: logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s") - self._finalize_update(pid) - def _update_ipc_snapshot(self): """Update using IPC snapshot strategy for elastic recovery.""" model_path = os.path.join( @@ -144,7 +143,7 @@ class DynamicWeightManager: if src.shape != dst.shape: raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}") - def _finalize_update(self, pid: int): + def finalize_update(self, pid: int = 0): """Finalize update process with verification.""" self._verify_parameters("update") if self.parallel_config.tensor_parallel_size > 1: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 385d0c5ab..6da0fb5b9 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1705,25 +1705,27 @@ class GPUModelRunner(ModelRunnerBase): self.forward_meta.clear_caches() def clear_parameters(self, pid): - """ " Dynamic model loader use to clear parameters use for RL""" + """Dynamic model loader use to clear parameters use for RL""" + # Clear CUDAGraph + if self.use_cudagraph: + self.model.clear_grpah_opt_backend() + # Clear parameters and Send single self.dynamic_weight_manager.clear_parameters(pid) self.clear_cache() paddle.device.cuda.empty_cache() - # Clear CudaGraph - if self.use_cudagraph: - self.model.clear_grpah_opt_backend() - self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory") def update_parameters(self, pid): - """ " Dynamic model loader use to update parameters use for RL""" + """Dynamic model loader use to update parameters use for RL""" + # Update parameters self.dynamic_weight_manager.update_parameters(pid) self.initialize_kv_cache() - - # Recapture CudaGraph + # Recapture CUDAGraph if self.use_cudagraph: self.capture_model() + # Send single + self.dynamic_weight_manager.finalize_update(pid) self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")