Add custom op declaration for all_reduce (#3473)

* add custom op declaration

* roll back try except
This commit is contained in:
Ryan
2025-08-20 20:29:58 +08:00
committed by GitHub
parent 33ff0bfe38
commit bcdfc1d6b9
4 changed files with 16 additions and 4 deletions

View File

@@ -158,9 +158,9 @@ class CustomAllreduce:
if out is None:
out = paddle.empty_like(inp)
if registered:
all_reduce(self._ptr, inp, out, 0, 0)
all_reduce(inp, out, self._ptr, 0, 0)
else:
all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
all_reduce(inp, out, self._ptr, self.buffer_ptrs[self.rank], self.max_size)
return out
def start_capture(self):