custom all reduce support cuda graph (#2938)

* Support enabling CUDA graph and custom all-reduce at the same time, and fix the custom all-reduce flag being overwritten

* rename communication_op to communication
Author: zhink
Date: 2025-07-21 22:52:03 +08:00
Committed by: GitHub
Parent: ff4569f135
Commit: 0262ef7eb3
21 changed files with 88 additions and 51 deletions
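
The change below introduces a `capturing` flag plus start/stop helpers so the custom all-reduce path can run inside a captured CUDA graph. A minimal usage sketch (not part of this commit), assuming a CustomAllreduce instance `comm` and Paddle's CUDAGraph API; the shape, dtype, and warm-up step are illustrative only:

import paddle
from paddle.device.cuda.graphs import CUDAGraph

def graphed_allreduce(comm, hidden_size=8192):
    # Static buffer reused on every graph replay.
    static_inp = paddle.zeros([hidden_size], dtype=paddle.float16)
    graph = CUDAGraph()
    with comm.capture():  # sets comm.capturing = True for the whole block
        # Warm-up call outside graph capture: mimics the allocation pattern only.
        comm.custom_all_reduce(static_inp)
        graph.capture_begin()
        # Inside capture the registered-buffer path is taken, which is graph-safe.
        static_out = comm.custom_all_reduce(static_inp)
        graph.capture_end()
    # Leaving the context clears the flag and registers the captured graph buffers.
    return graph, static_inp, static_out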


@@ -56,8 +56,7 @@ class CustomAllreduce:
         is bind to a unique device, and all communicators in this group
         are in the same node.
         """
-        self._IS_CAPTURING = False
-        self.disabled = True
+        self.capturing = False
         self.group = group

         if not custom_ar:
@@ -78,8 +77,6 @@ class CustomAllreduce:
         if world_size < 2:
             return

-        self.disabled = False
         # Buffers memory are owned by this Python class and passed to C++.
         # Meta data composes of two parts: meta data for synchronization and a
         # temporary buffer for storing intermediate allreduce results.
@@ -95,13 +92,13 @@ class CustomAllreduce:
         # is enough for 131072 such tuples. The largest model I've seen only
         # needs less than 10000 of registered tuples.
         self.rank_data = paddle.empty([8 * 1024 * 1024], dtype=paddle.uint8)
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = True
         self._ptr = init_custom_all_reduce(self.meta_ptrs, self.rank_data, rank, self.full_nvlink)
         register_buffer(self._ptr, self.buffer_ptrs)
         print("zss init custom allreduce", self._ptr)
+        _instances.append(self)

     @staticmethod
@@ -112,7 +109,6 @@ class CustomAllreduce:
"""
lib = cuda_wrapper.CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes)
# lib.cudaMemset(pointer, 2, size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer)
rank = dist.get_rank(group=group)
handles = []
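
The hunk above ends just before the handles collected per rank are exchanged. A sketch of that exchange pattern, assuming the project's ctypes wrapper also exposes cudaIpcOpenMemHandle and that pointers come back as ctypes.c_void_p (as the cudaFree call below suggests); the helper name is made up for illustration:

import paddle.distributed as dist

def exchange_ipc_pointers(lib, local_pointer, local_handle, group):
    # Gather every rank's IPC handle (paddle's all_gather_object appends to the list).
    handles = []
    dist.all_gather_object(handles, local_handle, group=group)
    rank = dist.get_rank(group=group)
    pointers = []
    for r, handle in enumerate(handles):
        if r == rank:
            # Local allocation: keep the raw device pointer.
            pointers.append(local_pointer.value)
        else:
            # Peer allocation: map it into this process via CUDA IPC (assumed wrapper method).
            pointers.append(lib.cudaIpcOpenMemHandle(handle).value)
    return pointers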
@@ -135,8 +131,8 @@ class CustomAllreduce:
         lib.cudaFree(ctypes.c_void_p(pointers[rank]))

     def should_custom_ar(self, inp: paddle.Tensor):
-        if self.disabled:
-            return False
+        if self.capturing:
+            return True
         inp_size = inp.numel() * inp.element_size()
         # custom allreduce requires input byte size to be multiples of 16
         if inp_size % 16 != 0:
@@ -167,6 +163,19 @@ class CustomAllreduce:
         all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
         return out

+    def start_capture(self):
+        """
+        set CUDA graph flag: True.
+        """
+        self.capturing = True
+
+    def stop_capture(self):
+        """
+        set CUDA graph flag: False and register the graph buffers.
+        """
+        self.capturing = False
+        self.register_graph_buffers()
+
     @contextmanager
     def capture(self):
         """
@@ -175,44 +184,44 @@ class CustomAllreduce:
         It records all the buffer addresses used in the CUDA graph.
         """
         try:
-            self._IS_CAPTURING = True
+            self.capturing = True
             yield
         finally:
-            self._IS_CAPTURING = False
-            if not self.disabled:
-                self.register_graph_buffers()
+            self.capturing = False
+            self.register_graph_buffers()

     def register_graph_buffers(self):
         """
         Register the graph buffers collected CUDA graph during capture.
         """
         handle, offset = get_graph_buffer_ipc_meta(self._ptr)
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
+        all_datas = []
+        all_data = [handle, offset]
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
-        for i, rank in enumerate(ranks):
-            dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu")
+        dist.all_gather_object(all_datas, all_data, group=self.group)
         # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
+        handles = [d[0] for d in all_datas]
+        offsets = [d[1] for d in all_datas]
         register_graph_buffers(self._ptr, handles, offsets)

     def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
-        # When custom allreduce is disabled, this will be None.
-        if self.disabled or not self.should_custom_ar(input):
-            return None
-        if self._IS_CAPTURING:
-            if paddle.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+        if self.capturing:
+            lib = cuda_wrapper.CudaRTLibrary()
+            stream = paddle.device.current_stream()
+            stream_capturing = lib.cudaStreamIsCapturing(stream)
+            if stream_capturing.value == 1:
+                # 1 is cudaStreamCaptureStatusActive: The stream is capturing.
+                return self.all_reduce(input, input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return paddle.empty_like(input)
         else:
-            return self.all_reduce(input, registered=False)
+            return self.all_reduce(input, input, registered=False)

     def close(self):
-        if not self.disabled and self._ptr:
+        if self._ptr:
             dispose(self._ptr)
             self._ptr = 0
             self.free_shared_buffer(self.group, self.meta_ptrs, rank=self.rank)
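
With the per-instance disabled flag gone, the decision to take the custom path sits with the caller. A hedged sketch of how the renamed communication module could gate on should_custom_ar(); the module-level _TP_AR name and the NCCL fallback are assumptions, not code from this commit:

import paddle
import paddle.distributed as dist

_TP_AR = None  # set to a CustomAllreduce instance when custom all-reduce is enabled

def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
    # Prefer the custom kernel when it is available and the input qualifies;
    # should_custom_ar() returns True unconditionally while capturing a graph.
    if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
        return _TP_AR.custom_all_reduce(input_)
    # Fallback: regular allreduce through paddle.distributed (operates in place).
    dist.all_reduce(input_)
    return input_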