Fix performance degradation bug of custom_all_reduce (#2981)

This commit is contained in:
zhink
2025-07-23 17:45:44 +08:00
committed by GitHub
parent 850c9d98d4
commit 1272c7ce98
2 changed files with 4 additions and 3 deletions

View File

@@ -133,7 +133,7 @@ class CustomAllreduce:
def should_custom_ar(self, inp: paddle.Tensor):
if self.capturing:
return True
inp_size = inp.numel() * inp.element_size()
inp_size = inp.shape[0] * inp.shape[1] * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False