mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] support custom all-reduce (#2758)
* [Feature] support custom all-reduce * add vllm adapted
This commit is contained in:
@@ -16,13 +16,28 @@
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from paddle.distributed import fleet
|
||||
from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
|
||||
_TP_AR = None
|
||||
|
||||
def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
model_parallel_group = hcg.get_model_parallel_group()
|
||||
global _TP_AR
|
||||
if get_tensor_model_parallel_world_size() > 1 and paddle.is_compiled_with_cuda():
|
||||
from fastdeploy.distributed.custom_all_reduce import CustomAllreduce
|
||||
_TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
|
||||
|
||||
try:
|
||||
@paddle.jit.marker.unified
|
||||
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
if paddle.in_dynamic_mode():
|
||||
hcg = dist.fleet.get_hybrid_communicate_group()
|
||||
global _TP_AR
|
||||
if _TP_AR is not None and _TP_AR.should_custom_ar(input_) :
|
||||
_TP_AR.all_reduce(input_, input_)
|
||||
elif paddle.in_dynamic_mode():
|
||||
hcg = fleet.get_hybrid_communicate_group()
|
||||
mp_group = hcg.get_model_parallel_group()
|
||||
dist.all_reduce(input_, group=mp_group)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user