mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-28 13:22:23 +08:00
140 lines
4.9 KiB
Python
140 lines
4.9 KiB
Python
"""
|
||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
"""
|
||
|
||
from dataclasses import dataclass
|
||
from typing import Callable, Dict, Optional
|
||
|
||
import paddle.device.cuda.graphs as graphs
|
||
import paddle.nn.layer
|
||
|
||
from fastdeploy.config import FDConfig
|
||
from fastdeploy.utils import get_logger
|
||
|
||
# Module-level logger for the piecewise CUDA graph backend.
# The logger name previously read "cudagrpah_piecewise_backend" (transposed
# letters); it now matches the log file name passed as the second argument.
# NOTE(review): the second argument is presumably the log file name — confirm
# against fastdeploy.utils.get_logger.
logger = get_logger("cudagraph_piecewise_backend",
                    "cudagraph_piecewise_backend.log")
|
||
|
||
|
||
@dataclass
class ConcreteSizeEntry:
    """Bookkeeping record for one concrete batch size.

    The CUDA graph backend in this file keeps one entry per capture size and
    stores on it everything needed to warm up, capture and replay the graph
    for that size.
    """

    # The concrete batch size this entry belongs to.
    runtime_bs: int
    # True when this size is one of the cudagraph_capture_sizes.
    use_cudagraph: bool = True
    # Whether a graph has already been captured for runtime_bs.
    captured: bool = False

    # The callable to capture (dynamic graph or static graph backend).
    # Left as None until it is assigned lazily by the backend.
    runnable: Optional[Callable] = None
    # How many warmup runs have been completed so far.
    num_finished_warmup: int = 0
    # The captured CUDA graph object for this batch size, if any.
    cuda_graph: Optional[graphs.CUDAGraph] = None
    # Tensor holding the output of the captured graph.
    output_buffer: Optional[paddle.Tensor] = None

    # Debugging aid for CUDA graphs: the input addresses recorded at capture
    # time, so they can be compared with the addresses seen during replay.
    input_addresses: Optional[list[int]] = None
|
||
|
||
|
||
class CudaGraphPiecewiseBackend:
    """Execute a callable through one CUDA graph per capture batch size.

    On construction a ``ConcreteSizeEntry`` is created for every size listed
    in ``fd_config.graph_opt_config.cudagraph_capture_sizes``. Each call maps
    the actual batch size to a capture size, then either runs the callable
    directly (when the entry has ``use_cudagraph`` disabled) or warms up,
    captures and replays a CUDA graph for that size.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        runnable: Callable,
    ):
        """Create one entry per configured capture size.

        Args:
            fd_config: Configuration object; only ``graph_opt_config`` is read
                here (``cudagraph_capture_sizes``, ``cudagraph_num_of_warmups``
                and ``batch_size_to_captured_size``).
            runnable: The callable that is warmed up and captured. It is
                invoked with the same keyword arguments given to ``__call__``.
        """
        self.fd_config = fd_config
        self.runnable = runnable

        graph_opt_config = fd_config.graph_opt_config
        self.cudagraph_capture_sizes = graph_opt_config.cudagraph_capture_sizes
        self.warm_up_size = graph_opt_config.cudagraph_num_of_warmups
        self.batch_size_to_captured_size = graph_opt_config.batch_size_to_captured_size

        # runtime_bs -> ConcreteSizeEntry
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {
            shape: ConcreteSizeEntry(runtime_bs=shape)
            for shape in self.cudagraph_capture_sizes
        }

        # Use the module logger instead of writing to stdout.
        logger.info("[CUDA GRAPH] Created all batch size entry ")

    def __call__(self, **kwargs):
        """Run the callable for the current batch, via a CUDA graph if enabled.

        Args:
            **kwargs: Forwarded unchanged to the runnable. Must contain
                ``ids_remove_padding``; its ``shape[0]`` is used as the
                batch size.

        Returns:
            The runnable's return value when the entry does not use CUDA
            graphs; otherwise ``entry.output_buffer``, i.e. the same buffer
            object stored on the entry is returned on every call for that
            capture size.

        Raises:
            KeyError: If ``ids_remove_padding`` is missing from ``kwargs``
                (and, assuming ``batch_size_to_captured_size`` is a dict,
                if the batch size has no mapping — TODO confirm its type).
            AssertionError: If no entry exists for the mapped capture size.
        """
        # The batch size is the leading dimension of ids_remove_padding.
        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
        batch_size = ids_remove_padding.shape[0]

        # Map the actual batch size to the capture size that serves it.
        padding_batch_size = self.batch_size_to_captured_size[batch_size]

        entry = self.concrete_size_entries.get(padding_batch_size)
        assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."

        # Lazily attach the runnable the first time this entry is used.
        if entry.runnable is None:
            entry.runnable = self.runnable

        if not entry.use_cudagraph:
            return entry.runnable(**kwargs)

        # Capture a new cuda graph on first use of this capture size.
        if entry.cuda_graph is None:
            # Warm up the model; the counter lives on the entry so warmups
            # already completed are not repeated.
            for _ in range(entry.num_finished_warmup, self.warm_up_size):
                entry.num_finished_warmup += 1
                entry.runnable(**kwargs)

            # Store input addresses for debugging.
            entry.input_addresses = [
                value.data_ptr() for value in kwargs.values()
                if isinstance(value, paddle.Tensor)
            ]

            new_graph = graphs.CUDAGraph()
            paddle.device.synchronize()

            # Capture
            new_graph.capture_begin()
            output = entry.runnable(**kwargs)
            new_graph.capture_end()

            # Store the graph and make output_buffer share the output's storage.
            entry.cuda_graph = new_graph
            entry.output_buffer = paddle.zeros_like(output)
            output._share_buffer_to(entry.output_buffer)
            # The original wrote `output._clear` without parentheses, which
            # only looks the method up and never runs it. Call it so the
            # temporary tensor is actually cleared once its buffer is shared.
            output._clear()
            # Record that this size has been captured, as the field documents.
            entry.captured = True

            paddle.device.synchronize()

        # Replay
        entry.cuda_graph.replay()
        return entry.output_buffer
|