From e9174f25e8b62aece6083f0dbeb736147f2fa4d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com>
Date: Tue, 9 Dec 2025 19:36:58 +0800
Subject: [PATCH] commit (#5452)

---
 fastdeploy/distributed/communication.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/fastdeploy/distributed/communication.py b/fastdeploy/distributed/communication.py
index a85815956..922fbb3df 100644
--- a/fastdeploy/distributed/communication.py
+++ b/fastdeploy/distributed/communication.py
@@ -56,6 +56,8 @@ try:
         group_: paddle.distributed.communication.group.Group = None,
     ) -> paddle.Tensor:
         """All-reduce the input tensor across model parallel group."""
+        if input_.shape[0] == 0:
+            return input_
         global _TP_AR
         if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
             # TODO: supports different_group custom allreduce
@@ -90,6 +92,8 @@ try:
     @paddle.jit.marker.unified
     def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
         """All-reduce the input tensor across model parallel group on calc stream."""
+        if input_.shape[0] == 0:
+            return input_
         if paddle.in_dynamic_mode():
             hcg = dist.fleet.get_hybrid_communicate_group()
             mp_group = hcg.get_model_parallel_group()
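
A minimal sketch of the guard this patch adds, assuming a paddle-like tensor exposing .shape: the all-reduce is skipped entirely when the leading dimension is zero. The names all_reduce_with_empty_guard and fake_collective are hypothetical, for illustration only; they are not FastDeploy or Paddle APIs. Only the input_.shape[0] == 0 early return mirrors the patched code.

# Sketch only, not part of the patch. `all_reduce_with_empty_guard` and
# `fake_collective` are hypothetical names; the early return for a zero-size
# leading dimension is the behavior introduced by the patch.
import paddle


def all_reduce_with_empty_guard(input_, collective_all_reduce):
    """Return the input untouched when it has no rows; otherwise run the collective."""
    if input_.shape[0] == 0:
        # Mirrors the patch: nothing to reduce, so skip the collective call.
        return input_
    return collective_all_reduce(input_)


if __name__ == "__main__":
    def fake_collective(t):
        # Stand-in for the real tensor-parallel all-reduce; only reached for non-empty input.
        raise RuntimeError("collective should not run on an empty tensor")

    empty = paddle.zeros([0, 8], dtype="float32")
    out = all_reduce_with_empty_guard(empty, fake_collective)
    print(out.shape)  # [0, 8] -- the guard returned the input without calling the collective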