commit (#5452)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled

This commit is contained in:
周周周
2025-12-09 19:36:58 +08:00
committed by GitHub
parent b491dcd23c
commit e9174f25e8

View File

@@ -56,6 +56,8 @@ try:
group_: paddle.distributed.communication.group.Group = None,
) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group."""
if input_.shape[0] == 0:
return input_
global _TP_AR
if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
# TODO: supports different_group custom allreduce
@@ -90,6 +92,8 @@ try:
@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group on calc stream."""
if input_.shape[0] == 0:
return input_
if paddle.in_dynamic_mode():
hcg = dist.fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()