From 20756cd2bb78968363a30a8e27dbc4aae32b255a Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Tue, 28 Oct 2025 22:11:03 +0800 Subject: [PATCH] fix import jit.marker.unified (#4622) --- fastdeploy/distributed/communication.py | 35 ++++++++++++++----------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/fastdeploy/distributed/communication.py b/fastdeploy/distributed/communication.py index c05671aae..a85815956 100644 --- a/fastdeploy/distributed/communication.py +++ b/fastdeploy/distributed/communication.py @@ -77,22 +77,25 @@ except: from paddle.distributed.communication import stream from paddle.distributed.communication.reduce import ReduceOp +try: -def all_reduce( - tensor, - op, - group, - sync_op: bool = True, -): - return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True) + def all_reduce( + tensor, + op, + group, + sync_op: bool = True, + ): + return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True) + @paddle.jit.marker.unified + def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor: + """All-reduce the input tensor across model parallel group on calc stream.""" + if paddle.in_dynamic_mode(): + hcg = dist.fleet.get_hybrid_communicate_group() + mp_group = hcg.get_model_parallel_group() + all_reduce(input_, op=ReduceOp.SUM, group=mp_group) + else: + dist.all_reduce(input_) -@paddle.jit.marker.unified -def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor: - """All-reduce the input tensor across model parallel group on calc stream.""" - if paddle.in_dynamic_mode(): - hcg = dist.fleet.get_hybrid_communicate_group() - mp_group = hcg.get_model_parallel_group() - all_reduce(input_, op=ReduceOp.SUM, group=mp_group) - else: - dist.all_reduce(input_) +except: + tensor_model_parallel_all_reduce_custom = None