diff --git a/fastdeploy/distributed/communication.py b/fastdeploy/distributed/communication.py
index a85815956..922fbb3df 100644
--- a/fastdeploy/distributed/communication.py
+++ b/fastdeploy/distributed/communication.py
@@ -56,6 +56,8 @@ try:
         group_: paddle.distributed.communication.group.Group = None,
     ) -> paddle.Tensor:
         """All-reduce the input tensor across model parallel group."""
+        if input_.shape[0] == 0:
+            return input_
         global _TP_AR
         if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
             # TODO: supports different_group custom allreduce
@@ -90,6 +92,8 @@ try:
     @paddle.jit.marker.unified
     def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
         """All-reduce the input tensor across model parallel group on calc stream."""
+        if input_.shape[0] == 0:
+            return input_
        if paddle.in_dynamic_mode():
            hcg = dist.fleet.get_hybrid_communicate_group()
            mp_group = hcg.get_model_parallel_group()
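
For context, both hunks add the same early-return guard: a zero-row input tensor is returned as-is instead of being passed to the all-reduce. The sketch below is not FastDeploy code; it is a minimal standalone illustration of the guard pattern using NumPy, where `fake_all_reduce` is a hypothetical stand-in for the real collective call.

    import numpy as np

    def fake_all_reduce(x: np.ndarray) -> np.ndarray:
        # Stand-in for the actual collective; the real code would call
        # paddle.distributed.all_reduce (or the custom all-reduce op).
        return x

    def all_reduce_with_guard(input_: np.ndarray) -> np.ndarray:
        # Guard added by the patch: an empty batch (zero rows) has nothing
        # to reduce, so return it immediately instead of issuing the collective.
        if input_.shape[0] == 0:
            return input_
        return fake_all_reduce(input_)

    # Usage: an empty input passes through untouched.
    empty = np.zeros((0, 8), dtype=np.float32)
    assert all_reduce_with_guard(empty) is empty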