[CudaGraph] [SOT] Support spliting static graph into piecewise graph with cuda_graph (#3478)

* support spliting static graph into piecewise graph with cuda_graph

* Update fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix merge conflict

* fix bug

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
zyfncg
2025-08-29 16:28:01 +08:00
committed by GitHub
parent 48d760539b
commit f677c032c0

View File

@@ -14,11 +14,13 @@
# limitations under the License.
"""
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, Optional
import paddle.nn.layer
from paddle.device.cuda import graphs
from paddle.jit.dy2static.utils import CUDAGraphState
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import capture_custom_allreduce
@@ -48,6 +50,35 @@ class ConcreteSizeEntry:
output_buffer: Optional[paddle.Tensor] = None
class Dy2StCudaGraphManager:
def __init__(self):
self.state = CUDAGraphState.DISABLE
self.captured_batch_size = set()
self.batch_size = -1
def run_impl(self, original_run_impl, inputs, parameters, attrs):
run_state = self.state
prog_attrs, cuda_graph_attrs = attrs
if run_state == CUDAGraphState.REPLAY:
if self.batch_size not in self.captured_batch_size:
run_state = CUDAGraphState.DISABLE
elif run_state == CUDAGraphState.CAPTURE:
self.captured_batch_size.add(self.batch_size)
cuda_graph_attrs |= {
"cuda_graph_state": run_state,
"cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
}
return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))
@contextmanager
def run_impl_guard(self):
with paddle.jit.dy2static.pir_partial_program.replace_run_impl_guard(
self.run_impl,
):
yield
class CudaGraphPiecewiseBackend:
"""Manage the capture and replay of CUDA graphs at the subgraph level."""
@@ -64,6 +95,38 @@ class CudaGraphPiecewiseBackend:
self._create_entry_dict()
self.cuda_graph_manager = None
if self.fd_config.graph_opt_config.graph_opt_level > 0:
self.cuda_graph_manager = Dy2StCudaGraphManager()
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
if not entry.captured:
# Warmup the model
for n in range(entry.num_finished_warmup, self.warm_up_size):
entry.num_finished_warmup += 1
entry.runnable(**kwargs)
logger.debug(
f"[CUDA GRAPH] Warm up for batch size {entry.real_shape}, "
f"finished ({n + 1}/{entry.num_finished_warmup}) times"
)
# Store input addresses for debug
input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
entry.input_addresses = input_addresses
# Capture
self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
self.cuda_graph_manager.batch_size = entry.real_shape
entry.captured = True
with self.cuda_graph_manager.run_impl_guard():
entry.runnable(**kwargs)
# Replay
self.cuda_graph_manager.state = CUDAGraphState.REPLAY
self.cuda_graph_manager.batch_size = entry.real_shape
with self.cuda_graph_manager.run_impl_guard():
return entry.runnable(**kwargs)
def __call__(self, **kwargs):
# Get real shape(all num tokens)
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
@@ -83,6 +146,9 @@ class CudaGraphPiecewiseBackend:
if not entry.use_cudagraph:
return entry.runnable(**kwargs)
if self.fd_config.graph_opt_config.graph_opt_level > 0:
return self.run_static_model(entry, **kwargs)
# Capture a new cuda graph
if entry.cuda_graph is None:
# Warmup the model