Files
FastDeploy/fastdeploy/distributed/communication.py
chen 63a03ee152 [feature]2.2 custom_allreduce support cudagraph recapture (#4307)
* custom_allreduce support cudagraph recapture

* delete code

* add shut_down/restart default group
2025-09-29 18:14:21 +08:00

75 lines
2.3 KiB
Python

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from contextlib import contextmanager, nullcontext
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
# Module-wide custom all-reduce handle. It stays None until
# use_custom_allreduce() assigns a CustomAllreduce instance to it; the other
# helpers in this module check for None before touching it.
_TP_AR = None
@contextmanager
def capture_custom_allreduce():
    """Run the enclosed block inside the custom all-reduce capture context.

    If no custom all-reduce instance has been installed (``_TP_AR`` is
    ``None``), the block runs under a no-op context manager. Otherwise it
    runs under the context manager returned by ``_TP_AR.capture()``.
    """
    # Only a read of the module global is needed, so no ``global`` statement.
    capture_ctx = nullcontext() if _TP_AR is None else _TP_AR.capture()
    with capture_ctx:
        yield
def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
    """Install a ``CustomAllreduce`` instance as the module-wide handle.

    The instance is built for the model-parallel group obtained from the
    fleet hybrid communicate group and stored in ``_TP_AR``, replacing any
    previously stored value.

    Args:
        custom_all_reduce_max_bytes: Forwarded unchanged as the second
            argument of ``CustomAllreduce``. Defaults to ``8192 * 1024``.
            Presumably the largest tensor size, in bytes, handled by the
            custom path -- TODO confirm against ``CustomAllreduce``.
    """
    global _TP_AR

    tp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()

    # Deferred import: resolved only when this function is actually called.
    from fastdeploy.distributed.custom_all_reduce import CustomAllreduce

    _TP_AR = CustomAllreduce(tp_group, custom_all_reduce_max_bytes)
def custom_ar_clear_ipc_handles():
    """Clear the IPC handles of the installed custom all-reduce, if any.

    Does nothing when no instance has been installed (``_TP_AR`` is
    ``None``); otherwise delegates to ``_TP_AR.clear_ipc_handles()``.
    """
    if _TP_AR is None:
        return
    _TP_AR.clear_ipc_handles()
try:

    @paddle.jit.marker.unified
    def tensor_model_parallel_all_reduce(
        input_: paddle.Tensor,
        group_: paddle.distributed.communication.group.Group = None,
    ) -> paddle.Tensor:
        """All-reduce the input tensor across model parallel group.

        Dispatch order:

        1. If a custom all-reduce instance is installed and its
           ``should_custom_ar(input_)`` check passes, call
           ``_TP_AR.custom_all_reduce(input_)``. ``group_`` is ignored on
           this path (see the TODO below).
        2. Otherwise, in dynamic mode, call ``dist.all_reduce`` on
           ``group_`` when it is given, or on the fleet model-parallel
           group when it is ``None``.
        3. Otherwise (not in dynamic mode), call ``dist.all_reduce`` with
           no explicit group; ``group_`` is ignored here as well.

        Args:
            input_: Tensor to reduce.
            group_: Optional communication group, honoured only on the
                dynamic-mode ``dist.all_reduce`` path.

        Returns:
            The function body contains no ``return`` statement, so callers
            receive ``None`` despite the annotation. The reduced values are
            presumably written into ``input_`` in place -- TODO confirm for
            the custom all-reduce path.
        """
        global _TP_AR
        if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
            # TODO: supports different_group custom allreduce
            _TP_AR.custom_all_reduce(input_)
        elif paddle.in_dynamic_mode():
            if group_ is not None:
                dist.all_reduce(input_, group=group_)
            else:
                hcg = fleet.get_hybrid_communicate_group()
                mp_group = hcg.get_model_parallel_group()
                dist.all_reduce(input_, group=mp_group)
        else:
            dist.all_reduce(input_)

except Exception:
    # Best-effort fallback: if the definition above cannot be created (for
    # example the decorator or an annotation fails to resolve), expose the
    # name as None instead of failing the module import. ``Exception`` is
    # used rather than a bare ``except`` so that KeyboardInterrupt and
    # SystemExit raised during import still propagate.
    tensor_model_parallel_all_reduce = None