mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
Fix performance degradation bug of custom_all_reduce (#2981)
This commit is contained in:
@@ -133,7 +133,7 @@ class CustomAllreduce:
|
||||
def should_custom_ar(self, inp: paddle.Tensor):
|
||||
if self.capturing:
|
||||
return True
|
||||
inp_size = inp.numel() * inp.element_size()
|
||||
inp_size = inp.shape[0] * inp.shape[1] * inp.element_size()
|
||||
# custom allreduce requires input byte size to be multiples of 16
|
||||
if inp_size % 16 != 0:
|
||||
return False
|
||||
|
Reference in New Issue
Block a user