support cuda graph (#4056)

* support cuda graph

* update
Author: Yuanle Liu
Date: 2025-09-11 11:38:32 +08:00
Committed by: GitHub
Parent: 749f074e44
Commit: 48f2ab3fb3


@@ -15,12 +15,13 @@
 """
 from contextlib import contextmanager
-from dataclasses import dataclass
-from typing import Callable, Dict, Optional
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional

 import paddle.nn.layer
 from paddle.device.cuda import graphs

 from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
 from fastdeploy.utils import get_logger
@@ -45,8 +46,8 @@ class ConcreteSizeEntry:
     num_finished_warmup: int = 0
     # Captured cuda graph object corresponding to the current real shape
     cuda_graph: Optional[graphs.CUDAGraph] = None
-    # Output buffer of cudagraph
-    output_buffer: Optional[paddle.Tensor] = None
+    # Output buffers of cudagraph
+    output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)


 class Dy2StCudaGraphManager:
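Note: the entry now holds a list of output buffers rather than a single tensor, which is why the import hunk pulls in field and List. Dataclasses reject bare mutable defaults, so field(default_factory=list) gives each entry its own list. A minimal sketch of that semantics (names mirror the patch; this is not FastDeploy code):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class EntrySketch:
    # `output_buffers: List = []` would raise ValueError at class definition
    # time (dataclasses forbid mutable defaults); default_factory builds a
    # fresh list per instance, so entries cannot leak buffers to each other.
    output_buffers: List[Optional[object]] = field(default_factory=list)


a, b = EntrySketch(), EntrySketch()
a.output_buffers.append("buffer")
assert b.output_buffers == []  # each entry owns an independent list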
@@ -135,7 +136,7 @@ class CudaGraphPiecewiseBackend:
     def __call__(self, **kwargs):
         # Get real shape(all num tokens)
-        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
+        ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
         real_shape = ids_remove_padding.shape[0]
         padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(
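real_shape_to_captured_size pads the runtime token count up to the nearest captured bucket, so a handful of captured graphs can serve every shape. A hypothetical sketch of building such a table (build_table and the bucket sizes here are illustrative, not the backend's actual code):

from typing import Dict, List


def build_table(capture_sizes: List[int], max_num_tokens: int) -> Dict[int, int]:
    # Map every possible real shape to the smallest captured size that fits,
    # making the runtime lookup a single dict access.
    sizes = sorted(capture_sizes)
    return {n: next(s for s in sizes if s >= n) for n in range(max_num_tokens + 1)}


table = build_table([1, 2, 4, 8, 16], max_num_tokens=16)
assert table[3] == 4 and table[9] == 16  # pad up to the nearest bucket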
@@ -176,14 +177,22 @@
             # Capture
             with capture_custom_allreduce():
                 new_grpah.capture_begin()
-                output = entry.runnable(**kwargs)
+                outputs = entry.runnable(**kwargs)
+                if isinstance(outputs, paddle.Tensor):
+                    assert outputs is not None
+                    outputs = [outputs]
                 new_grpah.capture_end()

             # Store output buffer
             entry.cuda_graph = new_grpah
-            entry.output_buffer = paddle.zeros_like(output)
-            output._share_buffer_to(entry.output_buffer)
-            output._clear
+            for output in outputs:
+                if output is not None:
+                    output_buffer = paddle.zeros_like(output)
+                    output._share_buffer_to(output_buffer)
+                    output._clear
+                    entry.output_buffers.append(output_buffer)
+                else:
+                    entry.output_buffers.append(None)

             paddle.device.synchronize()
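The capture block records the runnable into a CUDA graph, then aliases each output into a persistent buffer so later replays deposit results where the caller reads them. A standalone sketch of the same pattern, assuming a CUDA build of Paddle (_share_buffer_to and _clear are private Tensor APIs):

import paddle
from paddle.device.cuda import graphs

static_x = paddle.ones([8], dtype="float32")  # persistent input buffer

_ = static_x * 2.0  # warm up kernels before capture
paddle.device.synchronize()

g = graphs.CUDAGraph()
g.capture_begin()
out = static_x * 2.0  # recorded into the graph, not executed eagerly
g.capture_end()

# Alias the graph's output memory into a stable handle: each replay()
# overwrites output_buffer in place without re-running any Python.
output_buffer = paddle.zeros_like(out)
out._share_buffer_to(output_buffer)

paddle.assign(paddle.full([8], 3.0, dtype="float32"), static_x)  # new inputs, same buffer
g.replay()
paddle.device.synchronize()
print(output_buffer)  # reflects the replayed static_x * 2.0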
@@ -194,7 +203,9 @@
         # Replay
         entry.cuda_graph.replay()
         logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
-        return entry.output_buffer
+        if len(entry.output_buffers) == 1:
+            return entry.output_buffers[0]
+        return entry.output_buffers

     def _create_entry_dict(self):
         """ """
@@ -224,8 +235,9 @@
     def _save_cudagrpah_dot_files(self, entry):
         """Print CUDAGrpah to dot files"""
+        log_dir = envs.FD_LOG_DIR
         if entry.cuda_graph:
             entry.cuda_graph.print_to_dot_files(
-                f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
                 1 << 0,
             )
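The dump path now honors the configured FD_LOG_DIR instead of a hard-coded ./log, and the 1 << 0 bitmask is forwarded to CUDA's debug-dot flags (it appears to correspond to cudaGraphDebugDotFlagsVerbose). A hypothetical follow-up that renders the dumped graphs with Graphviz (assumes the dot binary is installed; the env-var default is a guess):

import glob
import os
import subprocess

log_dir = os.getenv("FD_LOG_DIR", "log")  # default value is an assumption
for dot_file in glob.glob(f"./{log_dir}/GraphDotFiles/*"):
    # Render each dumped CUDA graph to SVG for inspection.
    subprocess.run(["dot", "-Tsvg", dot_file, "-o", dot_file + ".svg"], check=True)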