"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import paddle.nn.layer
from paddle.device.cuda import graphs
from paddle.jit.dy2static.utils import CUDAGraphState

from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import capture_custom_allreduce
from fastdeploy.utils import get_logger

logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")


@dataclass
class ConcreteSizeEntry:
    """Record the concrete information corresponding to the current shape(num_tokens)"""

    # Concrete shape
    real_shape: int
    # The size is in cudagraph_capture_sizes
    use_cudagraph: bool = True
    # Has runtime-bs been captured before
    captured: bool = False

    # Need to be captured callable object(dynamic graph or static graph backend)
    runnable: Callable = None  # type: ignore
    # Number of completed warmups
    num_finished_warmup: int = 0
    # Captured cuda graph object corresponding to the current real shape
    cuda_graph: Optional[graphs.CUDAGraph] = None
    # Output buffer of cudagraph
    output_buffer: Optional[paddle.Tensor] = None
    # `data_ptr()` of every paddle.Tensor keyword argument seen right before
    # capture; stored for debugging only. Declared here (with a default) so the
    # backend no longer attaches an undeclared attribute to the entry.
    input_addresses: Optional[list] = None


class Dy2StCudaGraphManager:
    """Inject CUDA graph state and dispatch key into the dy2static run attrs.

    The backend sets `state` and `batch_size` and then runs the model inside
    `run_impl_guard()`, which routes program execution through `run_impl`.
    """

    def __init__(self):
        self.state = CUDAGraphState.DISABLE
        # Batch sizes for which `run_impl` has been invoked in CAPTURE state.
        self.captured_batch_size = set()
        self.batch_size = -1

    def run_impl(self, original_run_impl, inputs, parameters, attrs):
        """Call `original_run_impl` with CUDA graph attrs derived from `self.state`.

        Args:
            original_run_impl: The callable being wrapped; it receives
                `(inputs, parameters, (prog_attrs, cuda_graph_attrs))`.
            inputs: Forwarded unchanged.
            parameters: Forwarded unchanged.
            attrs: A `(prog_attrs, cuda_graph_attrs)` pair. `cuda_graph_attrs`
                is updated in place via `|=`.

        Returns:
            Whatever `original_run_impl` returns.
        """
        run_state = self.state
        prog_attrs, cuda_graph_attrs = attrs
        if run_state == CUDAGraphState.REPLAY:
            # A batch size that was never captured cannot be replayed; fall
            # back to running with CUDA graph disabled for this call.
            if self.batch_size not in self.captured_batch_size:
                run_state = CUDAGraphState.DISABLE
        elif run_state == CUDAGraphState.CAPTURE:
            self.captured_batch_size.add(self.batch_size)

        cuda_graph_attrs |= {
            "cuda_graph_state": run_state,
            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
        }
        return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))

    @contextmanager
    def run_impl_guard(self):
        """Context manager that installs `self.run_impl` via paddle's `replace_run_impl_guard`."""
        with paddle.jit.dy2static.pir_partial_program.replace_run_impl_guard(
            self.run_impl,
        ):
            yield


class CudaGraphPiecewiseBackend:
    """Manage the capture and replay of CUDA graphs at the subgraph level."""

    def __init__(
        self,
        fd_config: FDConfig,
        runnable: Callable,
    ):
        """Read the CUDA graph settings from `fd_config` and create one entry per capture size.

        Args:
            fd_config: Configuration object; only `graph_opt_config` is read here.
            runnable: The callable that is warmed up, captured and replayed.
        """
        self.fd_config = fd_config
        self.runnable = runnable
        self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
        self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
        self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size

        self._create_entry_dict()

        # The dy2static manager is only needed on the static-graph path
        # (graph_opt_level > 0); see `run_static_model`.
        self.cuda_graph_manager = None
        if self.fd_config.graph_opt_config.graph_opt_level > 0:
            self.cuda_graph_manager = Dy2StCudaGraphManager()

    def _warm_up(self, entry: ConcreteSizeEntry, shape_desc: str, kwargs: dict) -> None:
        """Run `entry.runnable` until `entry.num_finished_warmup` reaches `self.warm_up_size`.

        Args:
            entry: Entry whose runnable is warmed up; its warm-up counter is updated.
            shape_desc: Wording used in the log message ("batch size" / "real shape").
            kwargs: Keyword arguments forwarded to `entry.runnable`.
        """
        for n in range(entry.num_finished_warmup, self.warm_up_size):
            # Count the attempt before running, as the original code did.
            entry.num_finished_warmup += 1
            entry.runnable(**kwargs)
            # Report progress against the configured total; the previous message
            # used the just-incremented counter as the denominator, so it always
            # read "k/k".
            logger.debug(
                f"[CUDA GRAPH] Warm up for {shape_desc} {entry.real_shape}, "
                f"finished ({n + 1}/{self.warm_up_size}) times"
            )

    @staticmethod
    def _record_input_addresses(entry: ConcreteSizeEntry, kwargs: dict) -> None:
        """Store the `data_ptr()` of every paddle.Tensor in `kwargs` on the entry (debug aid)."""
        entry.input_addresses = [x.data_ptr() for x in kwargs.values() if isinstance(x, paddle.Tensor)]

    def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
        """Warm up and capture (first call only), then replay through the dy2static manager.

        Args:
            entry: The size entry selected for the current call.
            **kwargs: Forwarded to `entry.runnable`.

        Returns:
            The result of `entry.runnable(**kwargs)` executed in REPLAY state.
        """
        if not entry.captured:
            # Warmup the model
            self._warm_up(entry, "batch size", kwargs)
            # Store input addresses for debug
            self._record_input_addresses(entry, kwargs)

            # Capture
            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
            self.cuda_graph_manager.batch_size = entry.real_shape
            entry.captured = True
            with self.cuda_graph_manager.run_impl_guard():
                entry.runnable(**kwargs)

        # Replay
        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
        self.cuda_graph_manager.batch_size = entry.real_shape
        with self.cuda_graph_manager.run_impl_guard():
            return entry.runnable(**kwargs)

    def __call__(self, **kwargs):
        """Run the model for the current input, using a CUDA graph when enabled.

        The size of the first dimension of `kwargs["ids_remove_padding"]` is
        mapped through `real_shape_to_captured_size` to pick an entry. That
        entry is then run directly, handed to `run_static_model`, or captured
        on first use and replayed.

        Args:
            **kwargs: Must contain "ids_remove_padding"; all of them are
                forwarded to the runnable.

        Returns:
            The runnable's result on the non-cudagraph and static-graph paths,
            otherwise `entry.output_buffer` after replay.

        Raises:
            AssertionError: If the padded shape has no entry in the capture list.
        """
        # Get real shape(all num tokens)
        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
        real_shape = ids_remove_padding.shape[0]
        padding_real_shape = self.real_shape_to_captured_size[real_shape]
        logger.debug(
            f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, "
            f"The padded shape is :{padding_real_shape}"
        )

        entry = self.concrete_size_entries.get(padding_real_shape)
        assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list."
        if entry.runnable is None:
            entry.runnable = self.runnable
            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}")

        if not entry.use_cudagraph:
            return entry.runnable(**kwargs)

        if self.fd_config.graph_opt_config.graph_opt_level > 0:
            return self.run_static_model(entry, **kwargs)

        # Capture a new cuda graph
        if entry.cuda_graph is None:
            # Warmup the model
            self._warm_up(entry, "real shape", kwargs)
            # Store input addresses for debug
            self._record_input_addresses(entry, kwargs)

            new_graph = graphs.CUDAGraph()
            paddle.device.synchronize()

            # Capture
            with capture_custom_allreduce():
                new_graph.capture_begin()
                output = entry.runnable(**kwargs)
                new_graph.capture_end()

            # Store output buffer
            entry.cuda_graph = new_graph
            entry.output_buffer = paddle.zeros_like(output)
            output._share_buffer_to(entry.output_buffer)
            # NOTE(review): this is a bare attribute access, not a call, so it
            # has no effect. `output._clear()` may have been intended; left
            # unchanged because calling it would alter runtime behavior that
            # cannot be verified here — confirm with the Paddle owners.
            output._clear

            paddle.device.synchronize()

            # For CUDAGraph debug
            # self._save_cudagrpah_dot_files(entry)
            logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}")

        # Replay
        entry.cuda_graph.replay()
        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
        return entry.output_buffer

    def _create_entry_dict(self):
        """Create a fresh `ConcreteSizeEntry` for every size in `cudagraph_capture_sizes`."""
        # Runtime real shape -> ConcreteSizeEntry
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        for shape in self.cudagraph_capture_sizes:
            self.concrete_size_entries[shape] = ConcreteSizeEntry(real_shape=shape)

        logger.info(
            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
        )

    def clear_graph(self):
        """Drop every captured CUDA graph, empty the CUDA cache and recreate the entries."""
        # Clear graphs
        for shape, entry in self.concrete_size_entries.items():
            if entry.cuda_graph:
                del entry.cuda_graph
                logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {shape} has been cleared.")

        del self.concrete_size_entries
        paddle.device.cuda.empty_cache()

        # Create new entries
        self._create_entry_dict()

    def _save_cudagrpah_dot_files(self, entry):
        """Print the entry's CUDAGraph to dot files (debug helper; no-op without a captured graph)."""
        if entry.cuda_graph:
            entry.cuda_graph.print_to_dot_files(
                f"./log/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
                1 << 0,
            )