Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
custom all reduce support cuda graph (#2938)
Some checks failed: Deploy GitHub Pages / deploy (push) has been cancelled
* Support enabling CUDA graph and custom all reduce at the same time, and fix the overwritten custom all reduce flag
* Rename communication_op to communication
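The core of the change is making the custom all-reduce path aware of CUDA graph capture: while a stream is being captured, only graph-safe work may be issued on it. The sketch below is a minimal, self-contained illustration of the underlying runtime query, talking to libcudart directly through ctypes; the library soname, the helper name stream_is_capturing, and the error handling are assumptions for illustration, not FastDeploy code.

import ctypes

# Illustrative sketch, not FastDeploy code: ask the CUDA runtime whether a
# stream is currently being captured into a CUDA graph.
libcudart = ctypes.CDLL("libcudart.so")  # assumed soname; adjust for your CUDA install

cudaError_t = ctypes.c_int
cudaStream_t = ctypes.c_void_p
cudaStreamCaptureStatus = ctypes.c_int  # 0 = None, 1 = Active, 2 = Invalidated

libcudart.cudaStreamIsCapturing.restype = cudaError_t
libcudart.cudaStreamIsCapturing.argtypes = [cudaStream_t, ctypes.POINTER(cudaStreamCaptureStatus)]

def stream_is_capturing(stream_ptr: int) -> bool:
    # Returns True when the stream is under active CUDA graph capture.
    status = cudaStreamCaptureStatus()
    err = libcudart.cudaStreamIsCapturing(cudaStream_t(stream_ptr), ctypes.byref(status))
    if err != 0:  # 0 == cudaSuccess
        raise RuntimeError(f"cudaStreamIsCapturing failed with CUDA error {err}")
    return status.value == 1  # cudaStreamCaptureStatusActive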
@@ -22,6 +22,8 @@ from typing import Any, Dict, List, Optional

cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
cudaStream_t = ctypes.c_void_p
cudaStreamCaptureStatus = ctypes.c_int


class cudaIpcMemHandle_t(ctypes.Structure):
@@ -108,6 +110,14 @@ class CudaRTLibrary:
                ctypes.c_uint,
            ],
        ),
        Function(
            "cudaStreamIsCapturing",
            cudaError_t,
            [
                cudaStream_t,
                ctypes.POINTER(cudaStreamCaptureStatus)
            ]
        ),
    ]

    # class attribute to store the mapping from the path to the library
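For context, each Function entry in this table pairs a cudart symbol name with its restype and argtypes. A loader along the following lines applies such a table to a ctypes library handle; this is a hedged sketch of the common pattern, since the actual FastDeploy constructor is not shown in this hunk and its Function type may be defined differently.

from dataclasses import dataclass
from typing import Any, Dict, List
import ctypes

@dataclass
class Function:
    name: str
    restype: Any
    argtypes: List[Any]

def bind(lib: ctypes.CDLL, exported_functions: List[Function]) -> Dict[str, Any]:
    # Resolve each listed symbol and attach its ctypes signature.
    funcs = {}
    for f in exported_functions:
        cfunc = getattr(lib, f.name)
        cfunc.restype = f.restype
        cfunc.argtypes = f.argtypes
        funcs[f.name] = cfunc
    return funcs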
@@ -187,3 +197,9 @@ class CudaRTLibrary:
            self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)
        )
        return devPtr

    def cudaStreamIsCapturing(self, stream: cudaStream_t) -> ctypes.c_int:
        is_capturing = ctypes.c_int()
        self.CUDART_CHECK(
            self.funcs["cudaStreamIsCapturing"](stream, is_capturing)
        )
        return is_capturing
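Once the wrapper method exists, callers can gate graph-sensitive work on the capture status. The usage sketch below is illustrative only: get_current_stream_handle, run_graph_safe_allreduce, and run_default_allreduce are hypothetical placeholders for whatever the surrounding module provides, not FastDeploy APIs.

# Hedged usage sketch built on the CudaRTLibrary wrapper above.
# get_current_stream_handle / run_* are hypothetical placeholders.
cudart = CudaRTLibrary()
stream = cudaStream_t(get_current_stream_handle())  # raw cudaStream_t handle from the framework
status = cudart.cudaStreamIsCapturing(stream)       # ctypes.c_int capture status
if status.value != 0:
    # Stream is under (or invalidated by) CUDA graph capture: take the
    # graph-safe custom all-reduce path; no host synchronization allowed here.
    run_graph_safe_allreduce()
else:
    run_default_allreduce()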