Fix performance degradation bug of custom_all_reduce (#2981)

Authored by zhink on 2025-07-23 17:45:44 +08:00, committed by GitHub
parent 850c9d98d4
commit 1272c7ce98
2 changed files with 4 additions and 3 deletions


@@ -14,10 +14,11 @@
 # limitations under the License.
 """
+from contextlib import contextmanager, nullcontext
+
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
-from contextlib import contextmanager, nullcontext
 from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
@@ -53,7 +54,7 @@ try:
         """All-reduce the input tensor across model parallel group."""
         global _TP_AR
         if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
-            _TP_AR.all_reduce(input_, input_)
+            _TP_AR.custom_all_reduce(input_)
         elif paddle.in_dynamic_mode():
             hcg = fleet.get_hybrid_communicate_group()
             mp_group = hcg.get_model_parallel_group()
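
With this change the all-reduce entry point hands the tensor to the wrapper's custom_all_reduce method instead of calling the lower-level all_reduce(input_, input_) directly, leaving buffer handling to the wrapper. The sketch below shows roughly how the dispatch reads after the patch; the _TP_AR setup and the final dist.all_reduce call are assumptions filled in from the surrounding context, not lines from this diff.

# Rough sketch of the dispatch after this patch. Assumption: _TP_AR is the
# process-wide CustomAllreduce instance created during tensor-parallel init.
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet

_TP_AR = None  # set elsewhere when the custom all-reduce is available


def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
    """All-reduce the input tensor across the model-parallel group."""
    # Fast path: in-place custom (IPC-buffer) all-reduce when the wrapper
    # accepts this tensor.
    if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
        _TP_AR.custom_all_reduce(input_)
    # Fallback: the regular collective over the model-parallel group.
    # (Static-graph branch omitted in this sketch.)
    elif paddle.in_dynamic_mode():
        hcg = fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()
        dist.all_reduce(input_, group=mp_group)
    return input_

Because _TP_AR stays None here, the sketch always takes the fallback path; it only illustrates the control flow, not the module itself.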


@@ -133,7 +133,7 @@ class CustomAllreduce:
     def should_custom_ar(self, inp: paddle.Tensor):
         if self.capturing:
             return True
-        inp_size = inp.numel() * inp.element_size()
+        inp_size = inp.shape[0] * inp.shape[1] * inp.element_size()
         # custom allreduce requires input byte size to be multiples of 16
         if inp_size % 16 != 0:
             return False
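
The last hunk is the performance fix itself: in Paddle dynamic mode, Tensor.numel() returns a 0-D Tensor rather than a Python int, so the old byte-size check did tensor arithmetic and, on GPU, a device-to-host sync every time should_custom_ar ran, which appears to be the source of the reported degradation. The new expression works on the Python ints in Tensor.shape and stays on the host, at the cost of assuming a 2-D activation tensor. The micro-benchmark below is illustrative and not part of the PR; the tensor shape and loop count are arbitrary.

# Illustrative comparison (not part of the PR) of the two size checks.
import time

import paddle


def check_with_numel(inp: paddle.Tensor) -> bool:
    # Tensor.numel() yields a 0-D Tensor, so this does tensor arithmetic
    # and needs a device-to-host copy to produce a Python bool.
    inp_size = inp.numel() * inp.element_size()
    return bool(inp_size % 16 == 0)


def check_with_shape(inp: paddle.Tensor) -> bool:
    # Tensor.shape is a list of Python ints; everything stays on the host.
    inp_size = inp.shape[0] * inp.shape[1] * inp.element_size()
    return inp_size % 16 == 0


x = paddle.randn([128, 4096])

for fn in (check_with_numel, check_with_shape):
    start = time.perf_counter()
    for _ in range(1000):
        fn(x)
    print(f"{fn.__name__}: {time.perf_counter() - start:.4f}s")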