polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -17,24 +17,31 @@
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
_TP_AR = None
def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
global _TP_AR
if get_tensor_model_parallel_world_size() > 1 and paddle.is_compiled_with_cuda():
from fastdeploy.distributed.custom_all_reduce import CustomAllreduce
_TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
try:
@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
def tensor_model_parallel_all_reduce(
input_: paddle.Tensor,
) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group."""
global _TP_AR
if _TP_AR is not None and _TP_AR.should_custom_ar(input_) :
if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
_TP_AR.all_reduce(input_, input_)
elif paddle.in_dynamic_mode():
hcg = fleet.get_hybrid_communicate_group()
@@ -42,5 +49,6 @@ try:
dist.all_reduce(input_, group=mp_group)
else:
dist.all_reduce(input_)
except:
tensor_model_parallel_all_reduce=None
tensor_model_parallel_all_reduce = None