FastDeploy/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py
RAM 2fa173e327
[Executor] CUDAGraph support RL training (#3265)
* add clear graph opt backend

* cuda graph support rl

* add branch

* 1.fix dynamic_weight_manager bug 2.add clear api for CasualLM

* open test case

* fix typo

* update mkdocs.yaml

* [Docs]Update mkdocs.yml

* update test case

* use unittest in graph test case
2025-08-25 20:59:30 +08:00
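The commit above adds a clear API so that graphs captured against stale weights can be dropped during RL training. Below is a minimal, hypothetical sketch of that flow; only CudaGraphPiecewiseBackend.clear_graph comes from the file, while refresh_after_policy_update and load_new_weights are placeholder assumptions.

    # Hypothetical sketch, not from the repository: how clear_graph might be used
    # around an RL policy-weight update.
    def refresh_after_policy_update(backend, load_new_weights):
        backend.clear_graph()   # free captured graphs and release pooled GPU memory
        load_new_weights()      # placeholder for the dynamic weight manager's update
        # Nothing else is needed: the next backend(**kwargs) call warms up and
        # re-captures each padded shape lazily.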


"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import paddle.nn.layer
from paddle.device.cuda import graphs

from fastdeploy.config import FDConfig
from fastdeploy.distributed.communication import capture_custom_allreduce
from fastdeploy.utils import get_logger

logger = get_logger("cudagraph_piecewise_backend", "cudagraph_piecewise_backend.log")


@dataclass
class ConcreteSizeEntry:
    """Record the concrete information corresponding to the current shape (num_tokens)."""

    # Concrete shape
    runtime_bs: int
    # Whether this size is in cudagraph_capture_sizes
    use_cudagraph: bool = True
    # Whether this runtime_bs has been captured before
    captured: bool = False
    # Callable object to be captured: dynamic graph or static graph backend
    runnable: Callable = None  # type: ignore
    # Number of completed warmups
    num_finished_warmup: int = 0
    # Captured cuda graph object corresponding to the current real shape
    cuda_graph: Optional[graphs.CUDAGraph] = None
    # Output buffer of cudagraph
    output_buffer: Optional[paddle.Tensor] = None
    # Input tensor addresses recorded at capture time (for debugging)
    input_addresses: Optional[List[int]] = None


class CudaGraphPiecewiseBackend:
    """Manage the capture and replay of CUDA graphs at the subgraph level."""

    def __init__(
        self,
        fd_config: FDConfig,
        runnable: Callable,
    ):
        self.fd_config = fd_config
        self.runnable = runnable
        self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
        self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
        self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size

        self._create_entry_dict()

    def __call__(self, **kwargs):
        # Get real shape (total num tokens)
        ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
        real_shape = ids_remove_padding.shape[0]
        padding_real_shape = self.real_shape_to_captured_size[real_shape]
        logger.debug(
            f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, "
            f"The padded shape is :{padding_real_shape}"
        )

        entry = self.concrete_size_entries.get(padding_real_shape)
        assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list."
        if entry.runnable is None:
            entry.runnable = self.runnable
            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}")

        if not entry.use_cudagraph:
            return entry.runnable(**kwargs)

        # Capture a new cuda graph
        if entry.cuda_graph is None:
            # Warm up the model before capture
            for n in range(entry.num_finished_warmup, self.warm_up_size):
                entry.num_finished_warmup += 1
                entry.runnable(**kwargs)
                logger.debug(
                    f"[CUDA GRAPH] Warm up for real shape {padding_real_shape}, "
                    f"finished ({n + 1}/{self.warm_up_size}) times"
                )

            # Store input addresses for debug
            input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
            entry.input_addresses = input_addresses

            new_graph = graphs.CUDAGraph()
            paddle.device.synchronize()

            # Capture
            with capture_custom_allreduce():
                new_graph.capture_begin()
                output = entry.runnable(**kwargs)
                new_graph.capture_end()

            # Store output buffer: share the captured output's memory so replays
            # write into a stable tensor, then release the temporary handle
            entry.cuda_graph = new_graph
            entry.output_buffer = paddle.zeros_like(output)
            output._share_buffer_to(entry.output_buffer)
            output._clear()

            paddle.device.synchronize()
            logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}")

        # Replay
        entry.cuda_graph.replay()
        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
        return entry.output_buffer

    def _create_entry_dict(self):
        """Create a ConcreteSizeEntry for every shape in the capture list."""
        # Runtime real shape -> ConcreteSizeEntry
        self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}

        for shape in self.cudagraph_capture_sizes:
            self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)

        logger.info(
            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, "
            "Created all real shape entries."
        )

    def clear_graph(self):
        """Release all captured CUDA graphs and rebuild the entry dict."""
        # Clear captured graphs
        for shape, entry in self.concrete_size_entries.items():
            if entry.cuda_graph:
                del entry.cuda_graph
                logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {shape} has been cleared.")

        del self.concrete_size_entries
        paddle.device.cuda.empty_cache()

        # Create new entries
        self._create_entry_dict()
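
For reference, a minimal usage sketch of the backend above. Everything other than CudaGraphPiecewiseBackend itself (the fd_config object, forward_runnable, and the input tensors) is a placeholder assumption, not part of the file.

    # Hypothetical usage sketch (not part of the file above).
    def forward_runnable(**kwargs):
        ...  # the piecewise subgraph to capture, e.g. a model's forward pass

    backend = CudaGraphPiecewiseBackend(fd_config=fd_config, runnable=forward_runnable)

    # ids_remove_padding drives entry selection: its length (total token count) is
    # padded up via real_shape_to_captured_size, and that padded size must appear
    # in cudagraph_capture_sizes. The first call per size warms up and captures;
    # later calls replay the graph and return the shared output buffer.
    out = backend(ids_remove_padding=ids_remove_padding, **other_model_inputs)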