Fix performance degradation bug of custom_all_reduce (#2981)

Author: zhink
Date: 2025-07-23 17:45:44 +08:00
Committed by: GitHub
Parent: 850c9d98d4
Commit: 1272c7ce98
2 changed files with 4 additions and 3 deletions

File 1 of 2:

@@ -14,10 +14,11 @@
 # limitations under the License.
 """
+from contextlib import contextmanager, nullcontext
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
-from contextlib import contextmanager, nullcontext
 from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
@@ -53,7 +54,7 @@ try:
     """All-reduce the input tensor across model parallel group."""
     global _TP_AR
     if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
-        _TP_AR.all_reduce(input_, input_)
+        _TP_AR.custom_all_reduce(input_)
     elif paddle.in_dynamic_mode():
         hcg = fleet.get_hybrid_communicate_group()
         mp_group = hcg.get_model_parallel_group()
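
The functional change in this file is the dispatch to the custom kernel: instead of calling _TP_AR.all_reduce(input_, input_) with the tensor passed as both source and destination, the code now goes through the single-argument custom_all_reduce(input_) entry point. A minimal sketch of how the surrounding function reads after the patch; the function name and the dist.all_reduce fallback body are assumptions, since only the lines above appear in the diff:

# Sketch only: reconstructed from the visible context lines; anything not shown
# in the diff (function name, fallback collective, return value) is assumed.
def tensor_model_parallel_all_reduce(input_):
    """All-reduce the input tensor across model parallel group."""
    global _TP_AR
    if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
        # Fast path: fused custom all-reduce kernel; takes the tensor once
        # instead of passing it separately as source and destination.
        _TP_AR.custom_all_reduce(input_)
    elif paddle.in_dynamic_mode():
        # Fallback: regular collective over the model-parallel group.
        hcg = fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()
        dist.all_reduce(input_, group=mp_group)
    return input_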

File 2 of 2:

@@ -133,7 +133,7 @@ class CustomAllreduce:
     def should_custom_ar(self, inp: paddle.Tensor):
         if self.capturing:
             return True
-        inp_size = inp.numel() * inp.element_size()
+        inp_size = inp.shape[0] * inp.shape[1] * inp.element_size()
         # custom allreduce requires input byte size to be multiples of 16
         if inp_size % 16 != 0:
             return False
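
The second file is where the performance degradation came from. In Paddle's dynamic graph mode, Tensor.numel() returns a 0-D Tensor rather than a Python int, so the old byte-size check built small tensors (and the subsequent "% 16" comparison could force device synchronization) on every call to should_custom_ar, i.e. on every all-reduce. Reading inp.shape, a plain list of Python ints, keeps the check entirely on the host; note the new expression also assumes a 2-D input. A small illustration of the difference, using a hypothetical tensor shape not taken from the repo:

import paddle

x = paddle.ones([8, 4096], dtype="float16")

# Old check: numel() yields a 0-D paddle.Tensor, so the multiplication and the
# later "% 16" test run as tensor ops instead of host-side integer arithmetic.
old_size = x.numel() * x.element_size()
print(type(old_size))   # a paddle.Tensor, not an int

# New check: shape entries are Python ints, so everything stays on the host.
new_size = x.shape[0] * x.shape[1] * x.element_size()
print(type(new_size), new_size % 16 == 0)   # int, True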