custom all reduce support cuda graph (#2938)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* Support enabling CUDA graph capture and custom all-reduce at the same time, and fix the custom all-reduce flag being overwritten

* rename communication_op to communication
This commit is contained in:
zhink
2025-07-21 22:52:03 +08:00
committed by GitHub
parent ff4569f135
commit 0262ef7eb3
21 changed files with 88 additions and 51 deletions

View File

@@ -22,6 +22,8 @@ from typing import Any, Dict, List, Optional
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
cudaStream_t = ctypes.c_void_p
cudaStreamCaptureStatus = ctypes.c_int
class cudaIpcMemHandle_t(ctypes.Structure):
@@ -108,6 +110,14 @@ class CudaRTLibrary:
ctypes.c_uint,
],
),
Function(
"cudaStreamIsCapturing",
cudaError_t,
[
cudaStream_t,
ctypes.POINTER(cudaStreamCaptureStatus)
]
),
]
# class attribute to store the mapping from the path to the library
@@ -187,3 +197,9 @@ class CudaRTLibrary:
self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)
)
return devPtr
def cudaStreamIsCapturing(self, stream: cudaStream_t) -> ctypes.c_int:
    """Return the CUDA graph capture status of `stream`.

    Wraps the runtime's `cudaStreamIsCapturing`: the runtime writes a
    `cudaStreamCaptureStatus` enum value into an output parameter, which
    is returned to the caller as a `ctypes.c_int`. A nonzero value
    indicates the stream is (or was) capturing — NOTE(review): exact enum
    semantics come from the CUDA runtime headers; confirm against callers.

    Raises:
        Whatever `CUDART_CHECK` raises when the runtime call fails.
    """
    is_capturing = ctypes.c_int()
    # Pass the output argument explicitly by reference to match the
    # declared POINTER(cudaStreamCaptureStatus) argtype — consistent with
    # how cudaIpcOpenMemHandle passes ctypes.byref(devPtr) above, instead
    # of relying on ctypes' implicit instance-to-pointer conversion.
    self.CUDART_CHECK(
        self.funcs["cudaStreamIsCapturing"](stream, ctypes.byref(is_capturing))
    )
    return is_capturing