mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[CudaGraph] [SOT] Support splitting static graph into piecewise graph with cuda_graph (#3478)
* support splitting static graph into piecewise graph with cuda_graph * Update fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix merge conflict * fix bug --------- Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -14,11 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Dict, Optional
|
from typing import Callable, Dict, Optional
|
||||||
|
|
||||||
import paddle.nn.layer
|
import paddle.nn.layer
|
||||||
from paddle.device.cuda import graphs
|
from paddle.device.cuda import graphs
|
||||||
|
from paddle.jit.dy2static.utils import CUDAGraphState
|
||||||
|
|
||||||
from fastdeploy.config import FDConfig
|
from fastdeploy.config import FDConfig
|
||||||
from fastdeploy.distributed.communication import capture_custom_allreduce
|
from fastdeploy.distributed.communication import capture_custom_allreduce
|
||||||
@@ -48,6 +50,35 @@ class ConcreteSizeEntry:
|
|||||||
output_buffer: Optional[paddle.Tensor] = None
|
output_buffer: Optional[paddle.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Dy2StCudaGraphManager:
    """Injects CUDA graph state into the run implementation of a static program.

    The manager keeps a current ``state`` and ``batch_size`` (both set by the
    caller before running the model) and remembers every batch size for which
    a run was issued in CAPTURE state.
    """

    def __init__(self):
        # No batch size selected yet; -1 is the initial placeholder value.
        self.batch_size = -1
        # Batch sizes for which a CAPTURE run has been issued.
        self.captured_batch_size = set()
        self.state = CUDAGraphState.DISABLE

    def run_impl(self, original_run_impl, inputs, parameters, attrs):
        """Call ``original_run_impl`` with CUDA graph attributes filled in.

        Args:
            original_run_impl: Callable taking ``(inputs, parameters, attrs)``.
            inputs: Forwarded unchanged.
            parameters: Forwarded unchanged.
            attrs: A pair ``(prog_attrs, cuda_graph_attrs)``. The second
                element is updated in place with the state and dispatch key.

        Returns:
            The return value of ``original_run_impl``.
        """
        prog_attrs, cuda_graph_attrs = attrs
        effective_state = self.state
        current_batch = self.batch_size

        if effective_state == CUDAGraphState.CAPTURE:
            # Remember this batch size so later REPLAY requests are honoured.
            self.captured_batch_size.add(current_batch)
        elif effective_state == CUDAGraphState.REPLAY:
            # Replay was requested for a batch size that was never captured:
            # fall back to running without a CUDA graph.
            if current_batch not in self.captured_batch_size:
                effective_state = CUDAGraphState.DISABLE

        # The dispatch key is the batch size, or 0 when graphs are disabled.
        if effective_state != CUDAGraphState.DISABLE:
            dispatch_key = current_batch
        else:
            dispatch_key = 0

        graph_overrides = {
            "cuda_graph_state": effective_state,
            "cuda_graph_dispatch_key": dispatch_key,
        }
        # In-place merge: the caller's mapping object is mutated, not copied.
        cuda_graph_attrs |= graph_overrides

        forwarded_attrs = (prog_attrs, cuda_graph_attrs)
        return original_run_impl(inputs, parameters, forwarded_attrs)

    @contextmanager
    def run_impl_guard(self):
        """Context manager that installs ``run_impl`` via Paddle's
        ``replace_run_impl_guard`` for the duration of the ``with`` body."""
        guard = paddle.jit.dy2static.pir_partial_program.replace_run_impl_guard(
            self.run_impl,
        )
        with guard:
            yield
|
||||||
|
|
||||||
|
|
||||||
class CudaGraphPiecewiseBackend:
|
class CudaGraphPiecewiseBackend:
|
||||||
"""Manage the capture and replay of CUDA graphs at the subgraph level."""
|
"""Manage the capture and replay of CUDA graphs at the subgraph level."""
|
||||||
|
|
||||||
@@ -64,6 +95,38 @@ class CudaGraphPiecewiseBackend:
|
|||||||
|
|
||||||
self._create_entry_dict()
|
self._create_entry_dict()
|
||||||
|
|
||||||
|
self.cuda_graph_manager = None
|
||||||
|
if self.fd_config.graph_opt_config.graph_opt_level > 0:
|
||||||
|
self.cuda_graph_manager = Dy2StCudaGraphManager()
|
||||||
|
|
||||||
|
def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
    """Run a static-graph model under the Dy2St CUDA graph manager.

    The first time ``entry`` is seen (``entry.captured`` is false) the
    remaining warm-up runs are executed, followed by one run with the
    manager in CAPTURE state. Every call, including that first one, then
    performs a run in REPLAY state and returns its result.

    Args:
        entry: Per-shape bookkeeping object. This method reads
            ``runnable``, ``real_shape``, ``captured`` and
            ``num_finished_warmup`` and writes ``captured``,
            ``num_finished_warmup`` and ``input_addresses``.
        **kwargs: Forwarded unchanged to ``entry.runnable``.

    Returns:
        The return value of ``entry.runnable(**kwargs)`` from the REPLAY run.
    """
    manager = self.cuda_graph_manager

    if not entry.captured:
        # Warmup the model
        for n in range(entry.num_finished_warmup, self.warm_up_size):
            entry.num_finished_warmup += 1
            entry.runnable(**kwargs)
            # Report progress against the configured total. The counter was
            # just incremented, so using it as the denominator would always
            # print "k/k".
            logger.debug(
                f"[CUDA GRAPH] Warm up for batch size {entry.real_shape}, "
                f"finished ({n + 1}/{self.warm_up_size}) times"
            )

        # Store input addresses for debug
        entry.input_addresses = [x.data_ptr() for x in kwargs.values() if isinstance(x, paddle.Tensor)]

        # Capture
        manager.state = CUDAGraphState.CAPTURE
        manager.batch_size = entry.real_shape
        # Marked before the capture run, so a capture that raises is not
        # retried on the next call. NOTE(review): confirm this is intended.
        entry.captured = True
        with manager.run_impl_guard():
            entry.runnable(**kwargs)

    # Replay
    manager.state = CUDAGraphState.REPLAY
    manager.batch_size = entry.real_shape
    with manager.run_impl_guard():
        return entry.runnable(**kwargs)
|
||||||
|
|
||||||
def __call__(self, **kwargs):
|
def __call__(self, **kwargs):
|
||||||
# Get real shape(all num tokens)
|
# Get real shape(all num tokens)
|
||||||
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
|
ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
|
||||||
@@ -83,6 +146,9 @@ class CudaGraphPiecewiseBackend:
|
|||||||
if not entry.use_cudagraph:
|
if not entry.use_cudagraph:
|
||||||
return entry.runnable(**kwargs)
|
return entry.runnable(**kwargs)
|
||||||
|
|
||||||
|
if self.fd_config.graph_opt_config.graph_opt_level > 0:
|
||||||
|
return self.run_static_model(entry, **kwargs)
|
||||||
|
|
||||||
# Capture a new cuda graph
|
# Capture a new cuda graph
|
||||||
if entry.cuda_graph is None:
|
if entry.cuda_graph is None:
|
||||||
# Warmup the model
|
# Warmup the model
|
||||||
|
Reference in New Issue
Block a user