[Executor] Change cudagraph hashkey from batch size to num_tokens (#3454)

Author: Jundong Liu
Date: 2025-08-18 16:16:48 +08:00
Committed by: GitHub
Parent: ea4a3b479c
Commit: 70ee910cd5
2 changed files with 27 additions and 27 deletions
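
The commit re-keys the piecewise CUDA graph cache on the total number of tokens in the flattened batch (the length of `ids_remove_padding`) rather than on the request batch size. As a rough illustration of the lookup table this relies on, the sketch below pads any token count up to the nearest capture size; the helper name and the `max_num_tokens` argument are illustrative and not part of this PR, and FastDeploy builds the real table inside `fd_config.graph_opt_config`.

```python
from typing import Dict, List


def build_real_shape_to_captured_size(capture_sizes: List[int], max_num_tokens: int) -> Dict[int, int]:
    """Map every possible token count to the smallest capture size that covers it.

    Illustrative only: the exact construction used by FastDeploy is not shown in this diff.
    """
    sizes = sorted(capture_sizes)
    mapping: Dict[int, int] = {}
    for num_tokens in range(1, max_num_tokens + 1):
        # Pad up to the nearest captured shape; the largest capture size is the fallback.
        mapping[num_tokens] = next((s for s in sizes if s >= num_tokens), sizes[-1])
    return mapping


# With capture sizes [1, 2, 4, 8], a flattened batch of 3 tokens replays the graph captured for shape 4.
table = build_real_shape_to_captured_size([1, 2, 4, 8], max_num_tokens=8)
assert table[3] == 4
```

Padding to a small set of capture sizes keeps the number of captured graphs (and their memory) bounded, at the cost of running a slightly larger shape than strictly needed.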


@@ -29,9 +29,9 @@ logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.
 @dataclass
 class ConcreteSizeEntry:
-    """Record the concrete information corresponding to the current batch size"""
+    """Record the concrete information corresponding to the current shape(num_tokens)"""
-    # Concrete batch size
+    # Concrete shape
     runtime_bs: int
     # The size is in cudagraph_capture_sizes
     use_cudagraph: bool = True
@@ -42,7 +42,7 @@ class ConcreteSizeEntry:
     runnable: Callable = None # type: ignore
     # Number of completed warmups
     num_finished_warmup: int = 0
-    # Captured cuda graph object corresponding to the current batch size
+    # Captured cuda graph object corresponding to the current real shape
     cuda_graph: Optional[graphs.CUDAGraph] = None
     # Output buffer of cudagraph
     output_buffer: Optional[paddle.Tensor] = None
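
Read together, the two hunks above leave the entry dataclass looking roughly like the sketch below. This is a reconstruction from the visible context only: the `graphs` import is assumed from the type annotations, and any fields that fall between the two hunks are omitted because the diff does not show them.

```python
from dataclasses import dataclass
from typing import Callable, Optional

import paddle
from paddle.device.cuda import graphs  # assumed import, matching graphs.CUDAGraph below


@dataclass
class ConcreteSizeEntry:
    """Record the concrete information corresponding to the current shape(num_tokens)"""

    # Concrete shape (total token count of the flattened batch)
    runtime_bs: int
    # The size is in cudagraph_capture_sizes
    use_cudagraph: bool = True

    # ... fields between the two hunks are not visible in this diff ...

    runnable: Callable = None  # type: ignore
    # Number of completed warmups
    num_finished_warmup: int = 0
    # Captured cuda graph object corresponding to the current real shape
    cuda_graph: Optional[graphs.CUDAGraph] = None
    # Output buffer of cudagraph
    output_buffer: Optional[paddle.Tensor] = None
```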
@@ -60,33 +60,33 @@ class CudaGraphPiecewiseBackend:
         self.runnable = runnable
         self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
         self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
-        self.batch_size_to_captured_size = fd_config.graph_opt_config.batch_size_to_captured_size
+        self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
-        # Runtime batch size -> ConcreteSizeEntry
+        # Runtime real shape -> ConcreteSizeEntry
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
         for shape in self.cudagraph_capture_sizes:
             self.concrete_size_entries[shape] = ConcreteSizeEntry(runtime_bs=shape)
         logger.info(
-            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all batch sizes entry."
+            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
         )
     def __call__(self, **kwargs):
-        # Get batch size
+        # Get real shape(all num tokens)
         ids_remove_padding: paddle.Tensor = kwargs["ids_remove_padding"]
-        batch_size = ids_remove_padding.shape[0]
-        padding_batch_size = self.batch_size_to_captured_size[batch_size]
+        real_shape = ids_remove_padding.shape[0]
+        padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(
-            f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, "
-            f"The padded batch size is :{padding_batch_size}"
+            f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, "
+            f"The padded shape is :{padding_real_shape}"
         )
-        entry = self.concrete_size_entries.get(padding_batch_size)
-        assert entry is not None, f"Batch size:{padding_batch_size} is not in cuda graph capture list."
+        entry = self.concrete_size_entries.get(padding_real_shape)
+        assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list."
         if entry.runnable is None:
             entry.runnable = self.runnable
-            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}")
+            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}")
         if not entry.use_cudagraph:
             return entry.runnable(**kwargs)
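
The hunk above is where the new key is computed: `ids_remove_padding` is the flattened, padding-free token tensor, so its length tracks how many tokens are in flight rather than how many requests are batched. A toy illustration of that distinction (not FastDeploy's actual pre-processing):

```python
import paddle

# Two requests with 5 and 3 tokens: after padding removal they are concatenated
# into a single 1-D tensor, so the dispatch key is 8 tokens, not a batch size of 2.
request_a = paddle.randint(0, 1000, shape=[5])
request_b = paddle.randint(0, 1000, shape=[3])
ids_remove_padding = paddle.concat([request_a, request_b])

real_shape = ids_remove_padding.shape[0]
assert real_shape == 8  # the value looked up in real_shape_to_captured_size
```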
@@ -98,7 +98,7 @@ class CudaGraphPiecewiseBackend:
                 entry.num_finished_warmup += 1
                 entry.runnable(**kwargs)
                 logger.debug(
-                    f"[CUDA GRAPH] Warm up for batch size {padding_batch_size}, "
+                    f"[CUDA GRAPH] Warm up for real shape {padding_real_shape}, "
                     f"finished ({n + 1}/{entry.num_finished_warmup}) times"
                 )
@@ -122,9 +122,9 @@ class CudaGraphPiecewiseBackend:
             output._clear
             paddle.device.synchronize()
-            logger.debug(f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}")
+            logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}")
         # Replay
         entry.cuda_graph.replay()
-        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}")
+        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
         return entry.output_buffer
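
For readers unfamiliar with the capture/replay pattern that the last two hunks log around, the sketch below shows the bare Paddle CUDA graph API in isolation. It assumes a CUDA build of Paddle and uses a toy layer instead of the backend's runnable and managed buffers.

```python
import paddle
from paddle.device.cuda import graphs

# Toy module standing in for the backend's runnable.
linear = paddle.nn.Linear(16, 16)
x = paddle.ones([8, 16])

# Warm up eagerly first, as the backend does, so lazy allocations happen outside capture.
linear(x)

graph = graphs.CUDAGraph()
graph.capture_begin()
y = linear(x)  # Kernels are recorded into the graph instead of running eagerly.
graph.capture_end()

graph.replay()  # Re-runs the captured kernels on the same input/output buffers.
paddle.device.synchronize()
```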