Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-11-03 02:53:26 +08:00
fix import jit.marker.unified (#4622)
@@ -77,22 +77,25 @@ except:
 from paddle.distributed.communication import stream
 from paddle.distributed.communication.reduce import ReduceOp
 
-
-def all_reduce(
-    tensor,
-    op,
-    group,
-    sync_op: bool = True,
-):
-    return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True)
-
-
-@paddle.jit.marker.unified
-def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
-    """All-reduce the input tensor across model parallel group on calc stream."""
-    if paddle.in_dynamic_mode():
-        hcg = dist.fleet.get_hybrid_communicate_group()
-        mp_group = hcg.get_model_parallel_group()
-        all_reduce(input_, op=ReduceOp.SUM, group=mp_group)
-    else:
-        dist.all_reduce(input_)
+try:
+
+    def all_reduce(
+        tensor,
+        op,
+        group,
+        sync_op: bool = True,
+    ):
+        return stream.all_reduce(tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=True)
+
+    @paddle.jit.marker.unified
+    def tensor_model_parallel_all_reduce_custom(input_: paddle.Tensor) -> paddle.Tensor:
+        """All-reduce the input tensor across model parallel group on calc stream."""
+        if paddle.in_dynamic_mode():
+            hcg = dist.fleet.get_hybrid_communicate_group()
+            mp_group = hcg.get_model_parallel_group()
+            all_reduce(input_, op=ReduceOp.SUM, group=mp_group)
+        else:
+            dist.all_reduce(input_)
+
+except:
+    tensor_model_parallel_all_reduce_custom = None
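The change wraps the definitions in try/except so that on Paddle builds where paddle.jit.marker.unified (or the stream all-reduce API) is missing, the module still imports and tensor_model_parallel_all_reduce_custom is left as None. A minimal caller-side sketch of that sentinel pattern; the dispatch function mp_all_reduce and the commented import path are illustrative assumptions, not part of this commit:

import paddle
import paddle.distributed as dist

# Illustrative import; the real FastDeploy module path is not shown in this diff.
# from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce_custom

def mp_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
    """Dispatch to the custom calc-stream all-reduce when it was importable."""
    if tensor_model_parallel_all_reduce_custom is not None:
        tensor_model_parallel_all_reduce_custom(input_)  # in-place, on calc stream
    else:
        dist.all_reduce(input_)  # generic collective as a fallback
    return input_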