mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
custom all reduce support cuda graph (#2938)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
* Support enabling cuda graph and custom all reduce at the same time, and fix the overwritten custom all reduce flag
* Rename communication_op to communication
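At a high level, the change lets the tensor-parallel all-reduce keep using the custom kernel while a CUDA graph is being captured and replayed. A minimal caller-side sketch of the intended flow, assuming a CustomAllreduce instance named comm and a plain NCCL fallback (both names are illustrative, not part of this commit):

import paddle
import paddle.distributed as dist

def tensor_parallel_all_reduce(x: paddle.Tensor, comm) -> paddle.Tensor:
    # comm is a CustomAllreduce instance; custom_all_reduce() returns None
    # whenever the input is not eligible (disabled, size not a multiple of 16, ...).
    out = comm.custom_all_reduce(x)
    if out is not None:
        return out
    dist.all_reduce(x)  # fall back to the regular collective, in-place
    return x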
@@ -56,8 +56,7 @@ class CustomAllreduce:
         is bind to a unique device, and all communicators in this group
         are in the same node.
         """
-        self._IS_CAPTURING = False
         self.disabled = True
+        self.capturing = False
         self.group = group
 
         if not custom_ar:
@@ -78,8 +77,6 @@ class CustomAllreduce:
         if world_size < 2:
             return
 
-        self.disabled = False
-
         # Buffers memory are owned by this Python class and passed to C++.
         # Meta data composes of two parts: meta data for synchronization and a
         # temporary buffer for storing intermediate allreduce results.
@@ -95,13 +92,13 @@ class CustomAllreduce:
         # is enough for 131072 such tuples. The largest model I've seen only
         # needs less than 10000 of registered tuples.
         self.rank_data = paddle.empty([8 * 1024 * 1024], dtype=paddle.uint8)
 
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = True
         self._ptr = init_custom_all_reduce(self.meta_ptrs, self.rank_data, rank, self.full_nvlink)
         register_buffer(self._ptr, self.buffer_ptrs)
-        print("zss init custom allreduce", self._ptr)
+        _instances.append(self)
 
     @staticmethod
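_instances.append(self) points at a module-level registry of live communicators. The diff does not show how the registry is consumed; a plausible sketch is a pair of helpers that flip every registered communicator into and out of capture mode around CUDA graph capture (hypothetical, not part of this commit):

# Hypothetical module-level helpers built on the registry seen in the diff.
_instances = []  # populated by CustomAllreduce.__init__ via _instances.append(self)

def capture_custom_allreduce_begin():
    # Put every live communicator into capture mode before recording a graph.
    for comm in _instances:
        comm.start_capture()

def capture_custom_allreduce_end():
    # Leave capture mode and register the buffers recorded during capture.
    for comm in _instances:
        comm.stop_capture()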
@@ -112,7 +109,6 @@ class CustomAllreduce:
         """
         lib = cuda_wrapper.CudaRTLibrary()
         pointer = lib.cudaMalloc(size_in_bytes)
-        # lib.cudaMemset(pointer, 2, size_in_bytes)
         handle = lib.cudaIpcGetMemHandle(pointer)
         rank = dist.get_rank(group=group)
         handles = []
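The lines above follow the standard CUDA IPC setup: allocate on each rank, export a memory handle, then exchange handles so every rank can map its peers' buffers. A sketch of how the rest of such a helper usually proceeds; the all-gather, cudaIpcOpenMemHandle call, and the .value accesses are assumptions modeled on the visible lines, not shown in this hunk:

import paddle.distributed as dist

def create_shared_buffer_sketch(lib, size_in_bytes: int, group=None):
    # lib is the cuda_wrapper.CudaRTLibrary instance used in the diff above.
    pointer = lib.cudaMalloc(size_in_bytes)            # local device allocation
    handle = lib.cudaIpcGetMemHandle(pointer)          # exportable IPC handle
    rank = dist.get_rank(group=group)
    handles = []
    dist.all_gather_object(handles, handle, group=group)

    pointers = []
    for i, h in enumerate(handles):
        if i == rank:
            pointers.append(pointer.value)                      # this rank's own buffer
        else:
            pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # map a peer's buffer
    return pointers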
@@ -135,8 +131,8 @@ class CustomAllreduce:
             lib.cudaFree(ctypes.c_void_p(pointers[rank]))
 
     def should_custom_ar(self, inp: paddle.Tensor):
         if self.disabled:
             return False
+        if self.capturing:
+            return True
         inp_size = inp.numel() * inp.element_size()
         # custom allreduce requires input byte size to be multiples of 16
         if inp_size % 16 != 0:
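With the new early return, everything goes through the custom path while capturing; otherwise the fast path still requires the input's byte size to be a multiple of 16. A quick arithmetic check of what qualifies, assuming float16 activations:

import paddle

hidden = paddle.zeros([4, 1024], dtype="float16")  # 4 * 1024 * 2 bytes = 8192 -> multiple of 16, eligible
tail = paddle.zeros([3], dtype="float16")          # 3 * 2 bytes = 6 -> not a multiple of 16, falls back
assert (4 * 1024 * 2) % 16 == 0
assert (3 * 2) % 16 != 0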
@@ -167,6 +163,19 @@ class CustomAllreduce:
         all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
         return out
 
+    def start_capture(self):
+        """
+        set CUDA graph flag: True.
+        """
+        self.capturing = True
+
+    def stop_capture(self):
+        """
+        set CUDA graph flag: False and register the graph buffers.
+        """
+        self.capturing = False
+        self.register_graph_buffers()
+
     @contextmanager
     def capture(self):
         """
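The new start_capture/stop_capture pair gives graph-capture code an explicit, non-context-manager way to toggle the flag. A rough usage sketch, assuming Paddle's CUDAGraph capture API and an already constructed CustomAllreduce instance comm; model_forward and static_input are placeholders:

# Sketch, assuming paddle.device.cuda.graphs.CUDAGraph is available.
from paddle.device.cuda.graphs import CUDAGraph

comm.start_capture()                  # should_custom_ar() now returns True while capturing
graph = CUDAGraph()
graph.capture_begin()
out = model_forward(static_input)     # custom_all_reduce() runs the registered=True path here
graph.capture_end()
comm.stop_capture()                   # clears the flag and registers the graph buffers

graph.replay()                        # replays the captured kernels, custom all-reduce included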
@@ -175,44 +184,44 @@ class CustomAllreduce:
         It records all the buffer addresses used in the CUDA graph.
         """
         try:
-            self._IS_CAPTURING = True
+            self.capturing = True
             yield
         finally:
-            self._IS_CAPTURING = False
-            if not self.disabled:
-                self.register_graph_buffers()
+            self.capturing = False
+            self.register_graph_buffers()
 
     def register_graph_buffers(self):
         """
         Register the graph buffers collected CUDA graph during capture.
         """
         handle, offset = get_graph_buffer_ipc_meta(self._ptr)
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
+        all_datas = []
+        all_data = [handle, offset]
 
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
-        for i, rank in enumerate(ranks):
-            dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu")
+        dist.all_gather_object(all_datas, all_data, group=self.group)
 
         # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
+        handles = [d[0] for d in all_datas]
+        offsets = [d[1] for d in all_datas]
         register_graph_buffers(self._ptr, handles, offsets)
 
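The per-rank broadcast loop is replaced by a single all_gather_object, so every rank ends up with the (handle, offset) pair of every rank, in rank order. A standalone sketch of that collective's behaviour (assumes a 2-GPU run launched with paddle.distributed.launch):

import paddle.distributed as dist

dist.init_parallel_env()
local = [f"handle-from-rank-{dist.get_rank()}", 0]  # stand-in for [handle, offset]
gathered = []
dist.all_gather_object(gathered, local)
# On every rank: gathered == [["handle-from-rank-0", 0], ["handle-from-rank-1", 0]]
handles = [d[0] for d in gathered]
offsets = [d[1] for d in gathered]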
     def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
         # When custom allreduce is disabled, this will be None.
         if self.disabled or not self.should_custom_ar(input):
             return None
-        if self._IS_CAPTURING:
-            if paddle.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+        if self.capturing:
+            lib = cuda_wrapper.CudaRTLibrary()
+            stream = paddle.device.current_stream()
+            stream_capturing = lib.cudaStreamIsCapturing(stream)
+            if stream_capturing.value == 1:
+                # 1 is cudaStreamCaptureStatusActive: The stream is capturing.
+                return self.all_reduce(input, input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return paddle.empty_like(input)
         else:
-            return self.all_reduce(input, registered=False)
+            return self.all_reduce(input, input, registered=False)
 
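The capturing branch asks the CUDA runtime whether the current stream is actually being captured; .value == 1 corresponds to cudaStreamCaptureStatusActive, as the in-code comment notes. For reference, a minimal ctypes sketch of the same query without the FastDeploy wrapper (the library name and the raw stream handle are assumptions):

import ctypes

cudart = ctypes.CDLL("libcudart.so")

def stream_is_capturing(stream_ptr: int) -> bool:
    # enum cudaStreamCaptureStatus: 0 = None, 1 = Active, 2 = Invalidated
    status = ctypes.c_int(0)
    err = cudart.cudaStreamIsCapturing(ctypes.c_void_p(stream_ptr), ctypes.byref(status))
    assert err == 0, f"cudaStreamIsCapturing returned error {err}"
    return status.value == 1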
     def close(self):
-        if not self.disabled and self._ptr:
+        if self._ptr:
             dispose(self._ptr)
             self._ptr = 0
             self.free_shared_buffer(self.group, self.meta_ptrs, rank=self.rank)