[CUDAGraph] Add support for custom all-reduce operators under SOT mode (#4386)

This commit is contained in:
Ryan
2025-10-16 19:31:19 +08:00
committed by GitHub
parent 26ff2f8683
commit b87e2c6184
3 changed files with 5 additions and 3 deletions

View File

@@ -133,8 +133,9 @@ class CudaGraphPiecewiseBackend:
self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
self.cuda_graph_manager.batch_size = entry.real_shape
entry.captured = True
-with self.cuda_graph_manager.run_impl_guard():
-    entry.runnable(**kwargs)
+with capture_custom_allreduce():
+    with self.cuda_graph_manager.run_impl_guard():
+        entry.runnable(**kwargs)
# Replay
self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY