Files
FastDeploy/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
lizhenyun01 bab779011c
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
[CudaGraph] support cudagraph use shared pool (#4199)
* support cudagraph use shared pool

* add envs

* change CUDAGRAPH_POOL_ID to int

* change CUDAGRAPH_POOL_ID to use_memory_pool

* unify use_unique_memory_pool

* fix use_unique_memory_pool
2025-09-24 21:32:04 +08:00

243 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional
import paddle.jit.dy2static.utils as jit_utils
import paddle.nn.layer
from paddle.base.core import CUDAGraph
from paddle.device.cuda import graphs
from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import capture_custom_allreduce
from fastdeploy.utils import get_logger
# Module logger; its name now matches the module and the log file it writes to
# (it was previously misspelled as "cudagrpah_piecewise_backend").
logger = get_logger("cudagraph_piecewise_backend", "cudagraph_piecewise_backend.log")
@dataclass
class ConcreteSizeEntry:
    """Record the concrete information corresponding to the current shape (num_tokens).

    One entry exists per size in the CUDA graph capture list. It tracks warmup
    progress, capture status, the captured graph and the buffers its outputs live in.
    """

    # Concrete shape (the capture size this entry is keyed by)
    real_shape: int
    # Whether this size should go through a CUDA graph (the size is in cudagraph_capture_sizes)
    use_cudagraph: bool = True
    # Whether this runtime size has been captured before
    captured: bool = False
    # Callable object to be captured: a dynamic-graph or static-graph backend.
    # None until it is lazily filled in on first use.
    runnable: Optional[Callable] = None
    # Number of completed warmups
    num_finished_warmup: int = 0
    # Captured cuda graph object corresponding to the current real shape
    cuda_graph: Optional[graphs.CUDAGraph] = None
    # Output buffers of cudagraph; None marks an output position that was None at capture time
    output_buffers: List[Optional[paddle.Tensor]] = field(default_factory=list)
    # Data pointers of the tensor inputs recorded just before capture (debug only).
    # Declared here so the attribute always exists instead of being added ad hoc.
    input_addresses: List[int] = field(default_factory=list)
class Dy2StCudaGraphManager:
    """Carry CUDA graph state for dy2static programs into Paddle's run_impl.

    The caller sets ``state`` and ``batch_size`` before each run. ``run_impl`` forwards
    them to Paddle as the ``cuda_graph_state`` / ``cuda_graph_dispatch_key``
    attributes. ``captured_batch_size`` remembers every batch size that was run in
    CAPTURE state.
    """

    def __init__(self):
        self.state = jit_utils.CUDAGraphState.DISABLE
        self.captured_batch_size = set()
        self.batch_size = -1

    def run_impl(self, original_run_impl, inputs, parameters, attrs):
        """Invoke ``original_run_impl`` with the CUDA graph attributes filled in.

        ``attrs`` is unpacked as ``(program_attrs, cuda_graph_attrs)``.
        ``cuda_graph_attrs`` is updated with ``|=`` (in place, assuming it is a dict)
        with the effective state and dispatch key.

        A REPLAY request for a batch size that was never captured is downgraded to
        DISABLE. DISABLE always uses dispatch key 0.
        """
        graph_state = jit_utils.CUDAGraphState
        effective_state = self.state
        program_attrs, graph_attrs = attrs

        if effective_state == graph_state.CAPTURE:
            self.captured_batch_size.add(self.batch_size)
        elif effective_state == graph_state.REPLAY and self.batch_size not in self.captured_batch_size:
            # Nothing was captured for this batch size, so do not ask Paddle to replay it.
            effective_state = graph_state.DISABLE

        dispatch_key = 0 if effective_state == graph_state.DISABLE else self.batch_size
        graph_attrs |= {
            "cuda_graph_state": effective_state,
            "cuda_graph_dispatch_key": dispatch_key,
        }
        return original_run_impl(inputs, parameters, (program_attrs, graph_attrs))

    @contextmanager
    def run_impl_guard(self):
        """Install ``self.run_impl`` through Paddle's ``replace_run_impl_guard`` for the block's duration."""
        guard = paddle.jit.dy2static.pir_partial_program.replace_run_impl_guard(self.run_impl)
        with guard:
            yield
class CudaGraphPiecewiseBackend:
    """Manage the capture and replay of CUDA graphs at the subgraph level.

    One ``ConcreteSizeEntry`` is kept per size in
    ``graph_opt_config.cudagraph_capture_sizes``. On each call, the real token count is
    mapped to a capture size through ``real_shape_to_captured_size``. The matching
    entry is then warmed up, captured once, and replayed on later calls.

    Two execution paths exist:

    * ``graph_opt_level > 0``: capture/replay is delegated to Paddle through
      ``Dy2StCudaGraphManager`` (see ``run_static_model``).
    * otherwise: the runnable is captured into a ``graphs.CUDAGraph``, and its outputs
      are kept in per-entry output buffers that are returned after each replay.
    """

    def __init__(self, fd_config: FDConfig, runnable: Callable):
        self.fd_config = fd_config
        self.runnable = runnable
        self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
        self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
        self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size

        # When requested, all graphs of this backend are created with one pool id
        # generated here. Otherwise graphs are created without an explicit pool_id.
        # The attribute always exists so it never has to be probed for.
        self.unique_memory_pool_id = None
        if self.fd_config.graph_opt_config.use_unique_memory_pool:
            self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id()

        self._create_entry_dict()

        self.cuda_graph_manager = None
        if self.fd_config.graph_opt_config.graph_opt_level > 0:
            self.cuda_graph_manager = Dy2StCudaGraphManager()

    def _warmup(self, entry: ConcreteSizeEntry, shape_desc: str, inputs: Dict) -> None:
        """Run the remaining warmup iterations for ``entry`` and record its input addresses.

        Args:
            entry: Entry to warm up; ``num_finished_warmup`` is advanced in place, so an
                interrupted warmup resumes where it stopped.
            shape_desc: Wording used in the debug log ("batch size" or "real shape").
            inputs: Keyword arguments forwarded to ``entry.runnable``. They are passed as
                a dict so they can never collide with this helper's own parameters.
        """
        for n in range(entry.num_finished_warmup, self.warm_up_size):
            entry.num_finished_warmup += 1
            entry.runnable(**inputs)
            logger.debug(
                f"[CUDA GRAPH] Warm up for {shape_desc} {entry.real_shape}, "
                f"finished ({n + 1}/{self.warm_up_size}) times"
            )
        # Store input addresses for debug
        entry.input_addresses = [x.data_ptr() for x in inputs.values() if isinstance(x, paddle.Tensor)]

    def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
        """Run ``entry`` with CUDA graph capture/replay delegated to ``Dy2StCudaGraphManager``.

        On the first call for ``entry``, this warms the runnable up and runs it once in
        CAPTURE state. Every call then runs it in REPLAY state and returns that result.
        """
        if not entry.captured:
            self._warmup(entry, "batch size", kwargs)

            # Capture
            self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
            self.cuda_graph_manager.batch_size = entry.real_shape
            with self.cuda_graph_manager.run_impl_guard():
                entry.runnable(**kwargs)
            # Mark as captured only after the capture run returned. If it raised, the
            # next call attempts the capture again instead of assuming it succeeded.
            entry.captured = True

        # Replay
        self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
        self.cuda_graph_manager.batch_size = entry.real_shape
        with self.cuda_graph_manager.run_impl_guard():
            return entry.runnable(**kwargs)

    def __call__(self, **kwargs):
        """Run the wrapped runnable for the current batch, capturing or replaying a CUDA graph.

        The real token count is taken from the first dimension of
        ``kwargs["forward_meta"].ids_remove_padding``. It is then padded to a capture size
        through ``real_shape_to_captured_size``.

        Returns:
            The runnable's result. On the eager CUDA graph path, this is the single output
            buffer when the capture produced exactly one output. Otherwise it is the list
            of output buffers, whose items may be ``None``.
        """
        # Get real shape(all num tokens)
        ids_remove_padding: paddle.Tensor = kwargs["forward_meta"].ids_remove_padding
        real_shape = ids_remove_padding.shape[0]
        padding_real_shape = self.real_shape_to_captured_size[real_shape]
        logger.debug(
            f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, "
            f"The padded shape is :{padding_real_shape}"
        )

        entry = self.concrete_size_entries.get(padding_real_shape)
        assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list."
        if entry.runnable is None:
            entry.runnable = self.runnable
            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}")

        if not entry.use_cudagraph:
            return entry.runnable(**kwargs)

        if self.fd_config.graph_opt_config.graph_opt_level > 0:
            return self.run_static_model(entry, **kwargs)

        # Capture a new cuda graph
        if entry.cuda_graph is None:
            self._capture(entry, kwargs)

        # Replay
        entry.cuda_graph.replay()
        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
        if len(entry.output_buffers) == 1:
            return entry.output_buffers[0]
        return entry.output_buffers

    def _capture(self, entry: ConcreteSizeEntry, inputs: Dict) -> None:
        """Warm up ``entry``, capture its runnable into a new CUDA graph and keep the outputs.

        After this returns, ``entry.cuda_graph`` holds the captured graph.
        ``entry.output_buffers`` holds one buffer, or ``None``, per runnable output.
        """
        self._warmup(entry, "real shape", inputs)

        new_graph = (
            graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
            if self.fd_config.graph_opt_config.use_unique_memory_pool
            else graphs.CUDAGraph()
        )

        paddle.device.synchronize()

        # Capture
        with capture_custom_allreduce():
            new_graph.capture_begin()
            outputs = entry.runnable(**inputs)
            if isinstance(outputs, paddle.Tensor):
                outputs = [outputs]
            new_graph.capture_end()

        # Store output buffer
        entry.cuda_graph = new_graph
        for output in outputs:
            if output is None:
                entry.output_buffers.append(None)
                continue
            # `_share_buffer_to` presumably makes `output_buffer` alias the memory the
            # captured kernels write to, so the buffer reflects each replay's results.
            output_buffer = paddle.zeros_like(output)
            output._share_buffer_to(output_buffer)
            # NOTE(review): the previous code had `output._clear` (attribute access, not a
            # call), which had no effect and was removed. If releasing `output`'s handle
            # on the shared buffer is intended, call `output._clear()` after confirming
            # Paddle's semantics.
            entry.output_buffers.append(output_buffer)

        paddle.device.synchronize()

        # For CUDAGraph debug
        # self._save_cudagrpah_dot_files(entry)
        logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {entry.real_shape}")

    def _create_entry_dict(self):
        """Create one fresh ``ConcreteSizeEntry`` per size in ``cudagraph_capture_sizes``."""
        # Runtime real shape -> ConcreteSizeEntry
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {
            shape: ConcreteSizeEntry(real_shape=shape) for shape in self.cudagraph_capture_sizes
        }
        logger.info(
            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
        )

    def clear_graph(self):
        """Drop all captured graphs, call ``paddle.device.cuda.empty_cache()`` and recreate empty entries."""
        # Clear graphs
        for shape, entry in self.concrete_size_entries.items():
            if entry.cuda_graph is not None:
                del entry.cuda_graph
                logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {shape} has been cleared.")
        del self.concrete_size_entries
        paddle.device.cuda.empty_cache()

        # Create new entries
        self._create_entry_dict()

    def _save_cudagrpah_dot_files(self, entry):
        """Print the CUDA graph of ``entry`` to dot files under ``FD_LOG_DIR/GraphDotFiles``.

        The path is built with ``os.path.join``, so an absolute ``FD_LOG_DIR`` is
        honoured. The previous ``"./" + log_dir`` form turned it into a relative path.
        """
        import os

        if entry.cuda_graph is None:
            return
        log_dir = envs.FD_LOG_DIR
        entry.cuda_graph.print_to_dot_files(
            os.path.join(log_dir, "GraphDotFiles", f"backend{id(self)}_shape{entry.real_shape}"),
            1 << 0,
        )