support cuda graph (#4056)

* support cuda graph

* update
Author:    Yuanle Liu
Date:      2025-09-11 11:38:32 +08:00
Committed: GitHub
Parent:    749f074e44
Commit:    48f2ab3fb3


@@ -15,12 +15,13 @@
 """
 
 from contextlib import contextmanager
-from dataclasses import dataclass
-from typing import Callable, Dict, Optional
+from dataclasses import dataclass, field
+from typing import Callable, Dict, List, Optional
 
 import paddle.nn.layer
 from paddle.device.cuda import graphs
 
+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.distributed.communication import capture_custom_allreduce
 from fastdeploy.utils import get_logger
@@ -45,8 +46,8 @@ class ConcreteSizeEntry:
     num_finished_warmup: int = 0
     # Captured cuda graph object corresponding to the current real shape
     cuda_graph: Optional[graphs.CUDAGraph] = None
-    # Output buffer of cudagraph
-    output_buffer: Optional[paddle.Tensor] = None
+    # Output buffers of cudagraph
+    output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)
 
 
 class Dy2StCudaGraphManager:
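
The new `output_buffers` field needs `field(default_factory=list)` rather than a plain `= []`: `dataclasses` rejects mutable defaults outright, precisely because a single list would otherwise be shared by every `ConcreteSizeEntry`. A minimal sketch of the behavior (illustrative names, not from this file):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Entry:
    # A plain `buffers: List = []` default raises
    # "ValueError: mutable default ... is not allowed"; the factory
    # builds a fresh, independent list per instance instead.
    buffers: List[Optional[int]] = field(default_factory=list)

a, b = Entry(), Entry()
a.buffers.append(1)
assert b.buffers == []  # b is unaffected by mutations of a.buffers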
@@ -135,7 +136,7 @@ class CudaGraphPiecewiseBackend:
 
     def __call__(self, **kwargs):
         # Get real shape(all num tokens)
-        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
+        ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
         real_shape = ids_remove_padding.shape[0]
         padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(
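
`__call__` now pulls `ids_remove_padding` out of the shared `forward_meta` object instead of a standalone kwarg; the token count after padding removal then selects a pre-captured graph through `real_shape_to_captured_size`. A rough sketch of what such a lookup table does (the capture sizes here are made up; the real list comes from the backend's config):

# Illustrative only: map every possible token count to the next captured
# bucket so any real shape can replay some pre-captured graph.
capture_sizes = [1, 2, 4, 8, 16]          # assumed, not the file's list
real_shape_to_captured_size = {}
it = iter(sorted(capture_sizes))
bucket = next(it)
for n in range(1, max(capture_sizes) + 1):
    while n > bucket:
        bucket = next(it)
    real_shape_to_captured_size[n] = bucket

assert real_shape_to_captured_size[3] == 4    # 3 tokens replay the size-4 graph
assert real_shape_to_captured_size[16] == 16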
@@ -176,14 +177,22 @@ class CudaGraphPiecewiseBackend:
 
             # Capture
             with capture_custom_allreduce():
                 new_grpah.capture_begin()
-                output = entry.runnable(**kwargs)
+                outputs = entry.runnable(**kwargs)
+                if isinstance(outputs, paddle.Tensor):
+                    assert outputs is not None
+                    outputs = [outputs]
                 new_grpah.capture_end()
 
             # Store output buffer
             entry.cuda_graph = new_grpah
-            entry.output_buffer = paddle.zeros_like(output)
-            output._share_buffer_to(entry.output_buffer)
-            output._clear
+            for output in outputs:
+                if output is not None:
+                    output_buffer = paddle.zeros_like(output)
+                    output._share_buffer_to(output_buffer)
+                    output._clear
+                    entry.output_buffers.append(output_buffer)
+                else:
+                    entry.output_buffers.append(None)
 
             paddle.device.synchronize()
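
The loop above follows the standard CUDA-graph output contract: a replayed graph always writes to the same device addresses, so each captured output is aliased into a long-lived buffer (via Paddle's internal `Tensor._share_buffer_to`) that callers can safely read after every replay. Stripped of the FastDeploy plumbing, the capture/replay pattern looks roughly like this (a sketch assuming a CUDA device; warmup and stream details omitted):

import paddle
from paddle.device.cuda import graphs

paddle.set_device("gpu")
layer = paddle.nn.Linear(16, 16)
static_x = paddle.zeros([8, 16])       # fixed input buffer

g = graphs.CUDAGraph()
g.capture_begin()
static_y = layer(static_x)             # output allocated once, inside capture
g.capture_end()

# Replay with fresh data: write into the static input, replay the graph,
# then read the static output -- no Python forward pass on this path.
paddle.assign(paddle.randn([8, 16]), output=static_x)
g.replay()
print(static_y)                        # same storage on every replay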
@@ -194,7 +203,9 @@ class CudaGraphPiecewiseBackend:
             # Replay
             entry.cuda_graph.replay()
             logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
-            return entry.output_buffer
+            if len(entry.output_buffers) == 1:
+                return entry.output_buffers[0]
+            return entry.output_buffers
 
     def _create_entry_dict(self):
         """ """
@@ -224,8 +235,9 @@ class CudaGraphPiecewiseBackend:
 
     def _save_cudagrpah_dot_files(self, entry):
         """Print CUDAGrpah to dot files"""
+        log_dir = envs.FD_LOG_DIR
         if entry.cuda_graph:
             entry.cuda_graph.print_to_dot_files(
-                f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
                 1 << 0,
             )
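
With the hard-coded `./log` replaced by `envs.FD_LOG_DIR`, the debug dumps follow whatever log directory the deployment configures; the `1 << 0` flag corresponds to CUDA's `cudaGraphDebugDotFlagsVerbose`. A hypothetical way to render the dumps, assuming the default `FD_LOG_DIR=log` and Graphviz installed (not part of this commit):

import subprocess
from pathlib import Path

# Convert each dumped dot file under the log directory into an SVG.
for dot_file in Path("log/GraphDotFiles").rglob("*"):
    if dot_file.is_file():
        subprocess.run(
            ["dot", "-Tsvg", str(dot_file), "-o", f"{dot_file}.svg"],
            check=True,
        )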