[CUDAGraph] Support multiple output buffers and merge fixes from feature/exp_0908 (#4062)

* refine cudagraph

* refine cudagraph

* typo

* fix

* fix plugins

* fix

* update

* update

* update
Author: Yuanle Liu
Date: 2025-09-15 16:21:30 +08:00
Committed by: GitHub
Parent: 9409665713
Commit: b1b33211e8
8 changed files with 70 additions and 45 deletions

@@ -14,14 +14,16 @@
# limitations under the License.
"""
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, Optional
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional
import paddle.jit.dy2static.utils as jit_utils
import paddle.nn.layer
from paddle.device.cuda import graphs
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import capture_custom_allreduce
from fastdeploy.utils import get_logger
@@ -46,8 +48,8 @@ class ConcreteSizeEntry:
num_finished_warmup: int = 0
# Captured cuda graph object corresponding to the current real shape
cuda_graph: Optional[graphs.CUDAGraph] = None
# Output buffer of cudagraph
output_buffer: Optional[paddle.Tensor] = None
# Output buffers of cudagraph
output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)
class Dy2StCudaGraphManager:
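
The import changes above (adding field and List) exist to support the new output_buffers list on ConcreteSizeEntry: dataclasses reject a bare mutable default such as = [], so the list field needs field(default_factory=list). A minimal standalone sketch of that rule (the class name is made up, not FastDeploy's ConcreteSizeEntry):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class _EntrySketch:  # illustrative only
    real_shape: int = 0
    # `output_buffers: list = []` would raise ValueError at class creation;
    # default_factory builds a fresh list per instance instead.
    output_buffers: List[Optional[object]] = field(default_factory=list)

a, b = _EntrySketch(), _EntrySketch()
a.output_buffers.append("buffer-for-a")
assert b.output_buffers == []  # no sharing between entries
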
@@ -130,9 +132,9 @@ class CudaGraphPiecewiseBackend:
with self.cuda_graph_manager.run_impl_guard():
return entry.runnable(**kwargs)
def __call__(self, **kwargs):
def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
# Get the real shape (total number of tokens)
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
real_shape = ids_remove_padding.shape[0]
padding_real_shape = self.real_shape_to_captured_size[real_shape]
logger.debug(
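
real_shape_to_captured_size maps the incoming token count to the padded shape whose graph was captured. The diff does not show how that table is built, so the construction below is an assumption for illustration only: pad each possible token count up to the nearest captured size.

# Hypothetical construction of a real-shape -> captured-size table; the dict
# name matches the diff, but how FastDeploy actually fills it is not shown here.
def build_real_shape_to_captured_size(captured_sizes, max_num_tokens):
    captured_sizes = sorted(captured_sizes)
    mapping = {}
    idx = 0
    for real_shape in range(max_num_tokens + 1):
        while idx < len(captured_sizes) and captured_sizes[idx] < real_shape:
            idx += 1
        # Pad up to the nearest captured size; anything larger has no graph.
        mapping[real_shape] = captured_sizes[idx] if idx < len(captured_sizes) else None
    return mapping

table = build_real_shape_to_captured_size([1, 2, 4, 8, 16], max_num_tokens=16)
assert table[3] == 4 and table[8] == 8
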
@@ -173,14 +175,22 @@ class CudaGraphPiecewiseBackend:
# Capture
with capture_custom_allreduce():
new_grpah.capture_begin()
output = entry.runnable(**kwargs)
outputs = entry.runnable(**kwargs)
if isinstance(outputs, paddle.Tensor):
assert outputs is not None
outputs = [outputs]
new_grpah.capture_end()
# Store output buffers
entry.cuda_graph = new_grpah
entry.output_buffer = paddle.zeros_like(output)
output._share_buffer_to(entry.output_buffer)
output._clear
for output in outputs:
if output is not None:
output_buffer = paddle.zeros_like(output)
output._share_buffer_to(output_buffer)
output._clear
entry.output_buffers.append(output_buffer)
else:
entry.output_buffers.append(None)
paddle.device.synchronize()
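
Distilled from the hunk above into a standalone sketch: outputs produced while the graph is being captured live in graph-owned memory, so the backend shares their storage into persistent buffers (via the private _share_buffer_to) that stay valid across later replay() calls. run_model stands in for entry.runnable, and the sketch assumes a CUDA build of Paddle.

import paddle
from paddle.device.cuda import graphs

def capture_with_output_buffers(run_model, **kwargs):
    """Sketch of the capture step: returns (graph, persistent output buffers)."""
    graph = graphs.CUDAGraph()
    graph.capture_begin()
    outputs = run_model(**kwargs)
    if isinstance(outputs, paddle.Tensor):
        outputs = [outputs]
    graph.capture_end()
    buffers = []
    for output in outputs:
        if output is None:
            buffers.append(None)
            continue
        # Alias the graph's output storage so replay() writes into memory
        # the caller can still read afterwards.
        buffer = paddle.zeros_like(output)
        output._share_buffer_to(buffer)
        buffers.append(buffer)
    paddle.device.synchronize()
    return graph, buffers
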
@@ -191,7 +201,9 @@ class CudaGraphPiecewiseBackend:
# Replay
entry.cuda_graph.replay()
logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
return entry.output_buffer
if len(entry.output_buffers) == 1:
return entry.output_buffers[0]
return entry.output_buffers
def _create_entry_dict(self):
""" """
@@ -221,8 +233,11 @@ class CudaGraphPiecewiseBackend:
def _save_cudagrpah_dot_files(self, entry):
"""Print CUDAGrpah to dot files"""
log_dir = envs.FD_LOG_DIR
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
if entry.cuda_graph:
entry.cuda_graph.print_to_dot_files(
f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
f"{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
1 << 0,
)
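
The final hunk swaps the hard-coded ./log prefix for FD_LOG_DIR and creates the directory before writing the dot files. A minimal standalone sketch of the same pattern; the environment-variable fallback and directory layout here are assumptions, not FastDeploy's actual defaults:

import os

def resolve_dot_file_prefix(backend_id: int, real_shape: int) -> str:
    # FD_LOG_DIR mirrors fastdeploy.envs.FD_LOG_DIR; the "log" fallback is
    # only an assumption for this sketch.
    log_dir = os.getenv("FD_LOG_DIR", "log")
    os.makedirs(os.path.join(log_dir, "GraphDotFiles"), exist_ok=True)
    return f"{log_dir}/GraphDotFiles/backend{backend_id}_shape{real_shape}"
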