From 1272c7ce980e7111b9a182133572e2f77c36256a Mon Sep 17 00:00:00 2001
From: zhink <33270771+zhink@users.noreply.github.com>
Date: Wed, 23 Jul 2025 17:45:44 +0800
Subject: [PATCH] Fix performance degradation bug of custom_all_reduce (#2981)

---
 fastdeploy/distributed/communication.py                | 5 +++--
 .../distributed/custom_all_reduce/custom_all_reduce.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/fastdeploy/distributed/communication.py b/fastdeploy/distributed/communication.py
index 5311a77f9..95334f63e 100644
--- a/fastdeploy/distributed/communication.py
+++ b/fastdeploy/distributed/communication.py
@@ -14,10 +14,11 @@
 # limitations under the License.
 """
 
+from contextlib import contextmanager, nullcontext
+
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
-from contextlib import contextmanager, nullcontext
 
 from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
 
@@ -53,7 +54,7 @@ try:
         """All-reduce the input tensor across model parallel group."""
         global _TP_AR
         if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
-            _TP_AR.all_reduce(input_, input_)
+            _TP_AR.custom_all_reduce(input_)
         elif paddle.in_dynamic_mode():
             hcg = fleet.get_hybrid_communicate_group()
             mp_group = hcg.get_model_parallel_group()
diff --git a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
index 1b7b46d9f..4f98b29c4 100644
--- a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
+++ b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py
@@ -133,7 +133,7 @@ class CustomAllreduce:
     def should_custom_ar(self, inp: paddle.Tensor):
         if self.capturing:
             return True
-        inp_size = inp.numel() * inp.element_size()
+        inp_size = inp.shape[0] * inp.shape[1] * inp.element_size()
         # custom allreduce requires input byte size to be multiples of 16
         if inp_size % 16 != 0:
             return False
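
Note (not part of the patch): a minimal, pure-Python sketch of the byte-size gate that should_custom_ar() applies after this change, for a [num_tokens, hidden_size] activation of a 2-byte dtype. The function would_use_custom_ar and its parameters are hypothetical illustrations; the real method also applies further conditions (e.g. the capturing fast path shown above) that are not mirrored here.

    def would_use_custom_ar(num_tokens: int, hidden_size: int, dtype_bytes: int = 2) -> bool:
        # Byte size computed from the static 2-D shape, as in the patched line:
        # inp.shape[0] * inp.shape[1] * inp.element_size()
        inp_size = num_tokens * hidden_size * dtype_bytes
        # "custom allreduce requires input byte size to be multiples of 16"
        return inp_size % 16 == 0

    print(would_use_custom_ar(7, 4096))  # True: 7 * 4096 * 2 = 57344 bytes, a multiple of 16
    print(would_use_custom_ar(3, 100))   # False: 600 bytes is not a multiple of 16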