FastDeploy/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import paddle
import paddle.device.cuda.graphs as graphs

from fastdeploy.config import FDConfig
from fastdeploy.utils import get_logger

logger = get_logger("cudagraph_piecewise_backend",
                    "cudagraph_piecewise_backend.log")

@dataclass
class ConcreteSizeEntry:
    """ Record the concrete information corresponding to the current batch size. """

    # Concrete batch size
    runtime_bs: int
    # Whether this size is in cudagraph_capture_sizes
    use_cudagraph: bool = True
    # Whether runtime_bs has been captured before
    captured: bool = False
    # Callable object to be captured: the dynamic graph or static graph backend
    runnable: Callable = None  # type: ignore
    # Number of completed warmups
    num_finished_warmup: int = 0
    # Captured CUDA graph object corresponding to the current batch size
    cuda_graph: Optional[graphs.CUDAGraph] = None
    # Output buffer of the CUDA graph
    output_buffer: Optional[paddle.Tensor] = None
    # For CUDA graph debugging: track the input addresses during capture
    # and check that they are the same during replay
    input_addresses: Optional[list[int]] = None
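
# Lifecycle of an entry, as implemented in __call__ below: an entry starts
# with captured=False and cuda_graph=None; calls before capture bump
# num_finished_warmup and run eagerly; once warmup completes, the next call
# captures a CUDAGraph into cuda_graph and aliases output_buffer onto its
# output; every call after that only replays the captured graph.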

class CudaGraphPiecewiseBackend:
    """ Manage piecewise CUDA graph capture and replay: one graph is lazily
    captured per padded batch size and replayed on subsequent calls. """

    def __init__(
        self,
        fd_config: FDConfig,
        runnable: Callable,
    ):
        self.fd_config = fd_config
        self.runnable = runnable
        self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
        self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
        self.batch_size_to_captured_size = fd_config.graph_opt_config.batch_size_to_captured_size

        # runtime_bs -> ConcreteSizeEntry
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
        for shape in self.cudagraph_capture_sizes:
            self.concrete_size_entries[shape] = ConcreteSizeEntry(
                runtime_bs=shape)

        logger.info("[CUDA GRAPH] Created all batch size entries")

    def __call__(self, **kwargs):
        # Get batch size
        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
        batch_size = ids_remove_padding.shape[0]
        padding_batch_size = self.batch_size_to_captured_size[batch_size]
        # logger.debug(
        #     f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is {batch_size}, "
        #     f"the padded batch size is {padding_batch_size}")

        entry = self.concrete_size_entries.get(padding_batch_size)
        assert entry is not None, f"Batch size: {padding_batch_size} is not in cuda graph capture list."
        if entry.runnable is None:
            entry.runnable = self.runnable
            # logger.debug(
            #     f"[CUDA GRAPH] New entry lazily initialized with batch size {padding_batch_size}")

        # Fall back to eager execution for sizes we do not capture
        if not entry.use_cudagraph:
            return entry.runnable(**kwargs)

        # Capture a new cuda graph
        if entry.cuda_graph is None:
            # Warm up the model before capture
            for n in range(entry.num_finished_warmup, self.warm_up_size):
                entry.num_finished_warmup += 1
                entry.runnable(**kwargs)
                # logger.debug(
                #     f"[CUDA GRAPH] Warm up for batch size {padding_batch_size}, "
                #     f"finished ({n + 1}/{entry.num_finished_warmup}) times")

            # Store input addresses for debugging: replay is only valid if the
            # inputs live at the same device addresses as they did at capture
            entry.input_addresses = [
                x.data_ptr() for x in kwargs.values()
                if isinstance(x, paddle.Tensor)
            ]

            new_graph = graphs.CUDAGraph()
            paddle.device.synchronize()

            # Capture
            new_graph.capture_begin()
            output = entry.runnable(**kwargs)
            new_graph.capture_end()

            # Store the output buffer: output_buffer is aliased onto the
            # tensor written by the captured graph, so each replay refreshes
            # its contents in place; the temporary handle is then released
            entry.cuda_graph = new_graph
            entry.output_buffer = paddle.zeros_like(output)
            output._share_buffer_to(entry.output_buffer)
            output._clear()
            paddle.device.synchronize()
            # logger.debug(
            #     f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}")

        # Replay
        entry.cuda_graph.replay()
        # logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}")
        return entry.output_buffer
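

# A minimal usage sketch (illustrative only; `fd_config` and `model_forward`
# are assumed to be supplied by the surrounding FastDeploy model executor,
# and the keyword arguments must match what the wrapped runnable expects):
#
#     backend = CudaGraphPiecewiseBackend(fd_config=fd_config,
#                                         runnable=model_forward)
#     # The first call at a given padded batch size warms up, then captures;
#     # later calls with the same padded size replay the captured graph.
#     out = backend(ids_remove_padding=ids_remove_padding)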