mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
@@ -15,12 +15,13 @@
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import paddle.nn.layer
|
||||
from paddle.device.cuda import graphs
|
||||
|
||||
from fastdeploy import envs
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.distributed.communication import capture_custom_allreduce
|
||||
from fastdeploy.utils import get_logger
|
||||
@@ -45,8 +46,8 @@ class ConcreteSizeEntry:
|
||||
num_finished_warmup: int = 0
|
||||
# Captured cuda graph object corresponding to the current real shape
|
||||
cuda_graph: Optional[graphs.CUDAGraph] = None
|
||||
# Output buffer of cudagraph
|
||||
output_buffer: Optional[paddle.Tensor] = None
|
||||
# Output buffers of cudagraph
|
||||
output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)
|
||||
|
||||
|
||||
class Dy2StCudaGraphManager:
|
||||
@@ -135,7 +136,7 @@ class CudaGraphPiecewiseBackend:
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
# Get real shape(all num tokens)
|
||||
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
|
||||
ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
|
||||
real_shape = ids_remove_padding.shape[0]
|
||||
padding_real_shape = self.real_shape_to_captured_size[real_shape]
|
||||
logger.debug(
|
||||
@@ -176,14 +177,22 @@ class CudaGraphPiecewiseBackend:
|
||||
# Capture
|
||||
with capture_custom_allreduce():
|
||||
new_grpah.capture_begin()
|
||||
output = entry.runnable(**kwargs)
|
||||
outputs = entry.runnable(**kwargs)
|
||||
if isinstance(outputs, paddle.Tensor):
|
||||
assert outputs is not None
|
||||
outputs = [outputs]
|
||||
new_grpah.capture_end()
|
||||
|
||||
# Store output buffer
|
||||
entry.cuda_graph = new_grpah
|
||||
entry.output_buffer = paddle.zeros_like(output)
|
||||
output._share_buffer_to(entry.output_buffer)
|
||||
output._clear
|
||||
for output in outputs:
|
||||
if output is not None:
|
||||
output_buffer = paddle.zeros_like(output)
|
||||
output._share_buffer_to(output_buffer)
|
||||
output._clear
|
||||
entry.output_buffers.append(output_buffer)
|
||||
else:
|
||||
entry.output_buffers.append(None)
|
||||
|
||||
paddle.device.synchronize()
|
||||
|
||||
@@ -194,7 +203,9 @@ class CudaGraphPiecewiseBackend:
|
||||
# Replay
|
||||
entry.cuda_graph.replay()
|
||||
logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
|
||||
return entry.output_buffer
|
||||
if len(entry.output_buffers) == 1:
|
||||
return entry.output_buffers[0]
|
||||
return entry.output_buffers
|
||||
|
||||
def _create_entry_dict(self):
|
||||
""" """
|
||||
@@ -224,8 +235,9 @@ class CudaGraphPiecewiseBackend:
|
||||
|
||||
def _save_cudagrpah_dot_files(self, entry):
|
||||
"""Print CUDAGrpah to dot files"""
|
||||
log_dir = envs.FD_LOG_DIR
|
||||
if entry.cuda_graph:
|
||||
entry.cuda_graph.print_to_dot_files(
|
||||
f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
|
||||
f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
|
||||
1 << 0,
|
||||
)
|
||||
|
Reference in New Issue
Block a user