From f677c032c0127c6650b00feae25ebb56a4cf5579 Mon Sep 17 00:00:00 2001
From: zyfncg <1370305206@qq.com>
Date: Fri, 29 Aug 2025 16:28:01 +0800
Subject: [PATCH] [CudaGraph] [SOT] Support spliting static graph into piecewise graph with cuda_graph (#3478)

* support spliting static graph into piecewise graph with cuda_graph

* Update fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix merge conflict

* fix bug

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../cudagraph_piecewise_backend.py            | 114 +++++++++++++++++++
 1 file changed, 114 insertions(+)

diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
index 59f50eb4c..30a28d293 100644
--- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
+++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
@@ -14,11 +14,13 @@
 # limitations under the License.
 """
 
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Callable, Dict, Optional
 
 import paddle.nn.layer
 from paddle.device.cuda import graphs
+from paddle.jit.dy2static.utils import CUDAGraphState
 
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
@@ -48,6 +50,66 @@ class ConcreteSizeEntry:
     output_buffer: Optional[paddle.Tensor] = None
 
 
+class Dy2StCudaGraphManager:
+    """Pass CUDA Graph state and dispatch key to a wrapped run implementation.
+
+    The owner sets ``state`` and ``batch_size`` before a run; ``run_impl``
+    writes them into the CUDA Graph attributes it forwards.
+    """
+
+    def __init__(self):
+        # State requested for the next run; DISABLE until the owner changes it.
+        self.state = CUDAGraphState.DISABLE
+        # Batch sizes for which a CAPTURE run has been issued.
+        self.captured_batch_size = set()
+        # Batch size of the next run; doubles as the dispatch key.
+        self.batch_size = -1
+
+    def run_impl(self, original_run_impl, inputs, parameters, attrs):
+        """Call ``original_run_impl`` with the CUDA Graph attributes filled in.
+
+        Args:
+            original_run_impl: Callable invoked as
+                ``original_run_impl(inputs, parameters, attrs)``.
+            inputs: Forwarded unchanged.
+            parameters: Forwarded unchanged.
+            attrs: Pair ``(prog_attrs, cuda_graph_attrs)``. ``cuda_graph_attrs``
+                is updated via ``|=`` before being forwarded; assuming it is a
+                dict, the caller's object is modified in place.
+
+        Returns:
+            Whatever ``original_run_impl`` returns.
+        """
+        run_state = self.state
+        prog_attrs, cuda_graph_attrs = attrs
+        if run_state == CUDAGraphState.REPLAY:
+            # Fall back to DISABLE for a batch size with no recorded capture,
+            # so the run goes ahead with CUDA Graph disabled instead.
+            if self.batch_size not in self.captured_batch_size:
+                run_state = CUDAGraphState.DISABLE
+        elif run_state == CUDAGraphState.CAPTURE:
+            self.captured_batch_size.add(self.batch_size)
+
+        cuda_graph_attrs |= {
+            "cuda_graph_state": run_state,
+            # Disabled runs always use dispatch key 0.
+            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
+        }
+        return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
+
+    @contextmanager
+    def run_impl_guard(self):
+        """Run the ``with`` body under ``replace_run_impl_guard(self.run_impl)``.
+
+        NOTE(review): presumably Paddle routes partial-program execution
+        through ``run_impl`` while the guard is active -- confirm.
+        """
+        with paddle.jit.dy2static.pir_partial_program.replace_run_impl_guard(
+            self.run_impl,
+        ):
+            yield
+
+
 class CudaGraphPiecewiseBackend:
     """Manage the capture and replay of CUDA graphs at the subgraph level."""
 
@@ -64,6 +126,55 @@ class CudaGraphPiecewiseBackend:
 
         self._create_entry_dict()
 
+        self.cuda_graph_manager = None
+        if self.fd_config.graph_opt_config.graph_opt_level > 0:
+            self.cuda_graph_manager = Dy2StCudaGraphManager()
+
+    def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
+        """Warm up and capture ``entry`` on first use, then replay it.
+
+        Uses ``self.cuda_graph_manager``, which is only created when
+        ``fd_config.graph_opt_config.graph_opt_level > 0``.
+
+        Args:
+            entry: Entry whose ``runnable`` is called; ``entry.real_shape``
+                is used as the batch size.
+            **kwargs: Forwarded unchanged to ``entry.runnable``.
+
+        Returns:
+            The result of the REPLAY-state call to ``entry.runnable``.
+        """
+        if not entry.captured:
+            # Warmup the model
+            for n in range(entry.num_finished_warmup, self.warm_up_size):
+                entry.num_finished_warmup += 1
+                entry.runnable(**kwargs)
+                # Log progress against the configured total; the running
+                # counter always equals n + 1 here.
+                logger.debug(
+                    f"[CUDA GRAPH] Warm up for batch size {entry.real_shape}, "
+                    f"finished ({n + 1}/{self.warm_up_size}) times"
+                )
+
+            # Store input addresses for debug
+            input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
+            entry.input_addresses = input_addresses
+
+            # Capture
+            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
+            self.cuda_graph_manager.batch_size = entry.real_shape
+            with self.cuda_graph_manager.run_impl_guard():
+                entry.runnable(**kwargs)
+            # Flag the entry only after the capture run has returned, so an
+            # exception above does not leave it marked as captured.
+            entry.captured = True
+
+        # Replay
+        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
+        self.cuda_graph_manager.batch_size = entry.real_shape
+        with self.cuda_graph_manager.run_impl_guard():
+            return entry.runnable(**kwargs)
+
     def __call__(self, **kwargs):
         # Get real shape(all num tokens)
         ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
@@ -83,6 +194,9 @@ class CudaGraphPiecewiseBackend:
         if not entry.use_cudagraph:
             return entry.runnable(**kwargs)
 
+        if self.fd_config.graph_opt_config.graph_opt_level > 0:
+            return self.run_static_model(entry, **kwargs)
+
         # Capture a new cuda graph
         if entry.cuda_graph is None:
             # Warmup the model