custom all reduce support cuda graph (#2938)

* Support enabling CUDA graph and custom all-reduce at the same time, and fix the custom all-reduce flag being overwritten

* Rename communication_op to communication
zhink
2025-07-21 22:52:03 +08:00
committed by GitHub
parent ff4569f135
commit 0262ef7eb3
21 changed files with 88 additions and 51 deletions

View File

@@ -201,6 +201,8 @@ class ParallelConfig:
         # disable any whitespace for guided decoding
         self.disable_any_whitespace: bool = True
         self.pod_ip: str = None
+        # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
+        self.enable_custom_all_reduce: bool = False
         for key, value in args.items():
             if hasattr(self, key):
                 setattr(self, key, value)
@@ -213,8 +215,6 @@ class ParallelConfig:
             self.moe_phase = MoEPhase.DECODER
         else:
             raise NotImplementedError
-        # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce).
-        self.enable_custom_all_reduce: bool = False
         # pd_disaggregation
         use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
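The reordering in the two hunks above is the actual flag fix from the commit message: the default used to be assigned after the `for key, value in args.items()` override loop, so any user-supplied `enable_custom_all_reduce` was immediately clobbered back to `False`. Moving the default above the loop restores the usual default-then-override order. A minimal sketch of the pattern, using a hypothetical config class rather than the real ParallelConfig:

    class ExampleConfig:
        def __init__(self, args: dict):
            # 1) declare defaults first ...
            self.enable_custom_all_reduce: bool = False
            # 2) ... then let user-supplied args override them
            for key, value in args.items():
                if hasattr(self, key):
                    setattr(self, key, value)

    cfg = ExampleConfig({"enable_custom_all_reduce": True})
    assert cfg.enable_custom_all_reduce is True  # no longer clobbered back to False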

View File

@@ -17,12 +17,23 @@
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from contextlib import contextmanager, nullcontext
 from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size

 _TP_AR = None

+
+@contextmanager
+def capture_custom_allreduce():
+    global _TP_AR
+    ar_context = nullcontext()
+    if _TP_AR is not None:
+        ar_context = _TP_AR.capture()
+    with ar_context:
+        yield
+
 def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
     hcg = fleet.get_hybrid_communicate_group()
     model_parallel_group = hcg.get_model_parallel_group()
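How the new helper is meant to be used, based on the CudaGraphPiecewiseBackend change later in this commit (a minimal sketch; `runnable` and its keyword arguments are placeholders): wrapping the capture region in `capture_custom_allreduce()` lets the active CustomAllreduce instance, if one was created by `use_custom_allreduce()`, record the all-reduce buffers used inside the graph and register them once capture ends.

    import paddle
    from paddle.device.cuda import graphs

    from fastdeploy.distributed.communication import capture_custom_allreduce

    def capture(runnable, **kwargs):
        graph = graphs.CUDAGraph()
        paddle.device.synchronize()
        # If no CustomAllreduce instance exists (_TP_AR is None), this is a
        # plain nullcontext and graph capture proceeds unchanged.
        with capture_custom_allreduce():
            graph.capture_begin()
            output = runnable(**kwargs)
            graph.capture_end()
        return graph, output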

View File

@@ -22,6 +22,8 @@ from typing import Any, Dict, List, Optional
 cudaError_t = ctypes.c_int
 cudaMemcpyKind = ctypes.c_int
+cudaStream_t = ctypes.c_void_p
+cudaStreamCaptureStatus = ctypes.c_int


 class cudaIpcMemHandle_t(ctypes.Structure):
@@ -108,6 +110,14 @@ class CudaRTLibrary:
                 ctypes.c_uint,
             ],
         ),
+        Function(
+            "cudaStreamIsCapturing",
+            cudaError_t,
+            [
+                cudaStream_t,
+                ctypes.POINTER(cudaStreamCaptureStatus)
+            ]
+        ),
     ]

     # class attribute to store the mapping from the path to the library
@@ -187,3 +197,9 @@ class CudaRTLibrary:
self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess) self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)
) )
return devPtr return devPtr
def cudaStreamIsCapturing(self, stream: cudaStream_t) -> ctypes.c_int:
is_capturing = ctypes.c_int()
self.CUDART_CHECK(
self.funcs["cudaStreamIsCapturing"](stream, is_capturing)
)
return is_capturing
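For reference, the new binding can be exercised on its own (a sketch: the import path of `cuda_wrapper` is assumed, and the stream object is passed exactly as the CustomAllreduce code below does). `cudaStreamIsCapturing` fills a `cudaStreamCaptureStatus` value: 0 (`cudaStreamCaptureStatusNone`) means not capturing, 1 (`cudaStreamCaptureStatusActive`) means the stream is being captured, 2 means the capture was invalidated; only status 1 is treated as capturing.

    import paddle
    from fastdeploy.distributed.custom_all_reduce import cuda_wrapper  # module path assumed

    lib = cuda_wrapper.CudaRTLibrary()
    stream = paddle.device.current_stream()
    status = lib.cudaStreamIsCapturing(stream)  # returns a ctypes.c_int
    if status.value == 1:  # cudaStreamCaptureStatusActive
        print("current stream is being captured into a CUDA graph")
    else:
        print("current stream is not capturing")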

View File

@@ -56,8 +56,7 @@ class CustomAllreduce:
         is bind to a unique device, and all communicators in this group
         are in the same node.
         """
-        self._IS_CAPTURING = False
-        self.disabled = True
+        self.capturing = False
         self.group = group

         if not custom_ar:
@@ -78,8 +77,6 @@ class CustomAllreduce:
         if world_size < 2:
             return

-        self.disabled = False
         # Buffers memory are owned by this Python class and passed to C++.
         # Meta data composes of two parts: meta data for synchronization and a
         # temporary buffer for storing intermediate allreduce results.
@@ -95,13 +92,13 @@ class CustomAllreduce:
         # is enough for 131072 such tuples. The largest model I've seen only
         # needs less than 10000 of registered tuples.
         self.rank_data = paddle.empty([8 * 1024 * 1024], dtype=paddle.uint8)
         self.max_size = max_size
+        self.rank = rank
         self.world_size = world_size
         self.full_nvlink = True
         self._ptr = init_custom_all_reduce(self.meta_ptrs, self.rank_data, rank, self.full_nvlink)
         register_buffer(self._ptr, self.buffer_ptrs)
-        print("zss init custom allreduce", self._ptr)
         _instances.append(self)

     @staticmethod
@@ -112,7 +109,6 @@ class CustomAllreduce:
""" """
lib = cuda_wrapper.CudaRTLibrary() lib = cuda_wrapper.CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes) pointer = lib.cudaMalloc(size_in_bytes)
# lib.cudaMemset(pointer, 2, size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer) handle = lib.cudaIpcGetMemHandle(pointer)
rank = dist.get_rank(group=group) rank = dist.get_rank(group=group)
handles = [] handles = []
@@ -135,8 +131,8 @@ class CustomAllreduce:
             lib.cudaFree(ctypes.c_void_p(pointers[rank]))

     def should_custom_ar(self, inp: paddle.Tensor):
-        if self.disabled:
-            return False
+        if self.capturing:
+            return True
         inp_size = inp.numel() * inp.element_size()
         # custom allreduce requires input byte size to be multiples of 16
         if inp_size % 16 != 0:
@@ -167,6 +163,19 @@ class CustomAllreduce:
         all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
         return out

+    def start_capture(self):
+        """
+        set CUDA graph flag: True.
+        """
+        self.capturing = True
+
+    def stop_capture(self):
+        """
+        set CUDA graph flag: False and register the graph buffers.
+        """
+        self.capturing = False
+        self.register_graph_buffers()
+
     @contextmanager
     def capture(self):
         """
@@ -175,44 +184,44 @@ class CustomAllreduce:
         It records all the buffer addresses used in the CUDA graph.
         """
         try:
-            self._IS_CAPTURING = True
+            self.capturing = True
             yield
         finally:
-            self._IS_CAPTURING = False
-            if not self.disabled:
-                self.register_graph_buffers()
+            self.capturing = False
+            self.register_graph_buffers()

     def register_graph_buffers(self):
+        """
+        Register the graph buffers collected CUDA graph during capture.
+        """
         handle, offset = get_graph_buffer_ipc_meta(self._ptr)
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
-        for i, rank in enumerate(ranks):
-            dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu")
-        # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
+        all_datas = []
+        all_data = [handle, offset]
+        dist.all_gather_object(all_datas, all_data, group=self.group)
+        handles = [d[0] for d in all_datas]
+        offsets = [d[1] for d in all_datas]
         register_graph_buffers(self._ptr, handles, offsets)

     def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]:
         """The main allreduce API that provides support for cuda graph."""
-        # When custom allreduce is disabled, this will be None.
-        if self.disabled or not self.should_custom_ar(input):
-            return None
-        if self._IS_CAPTURING:
-            if paddle.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+        if self.capturing:
+            lib = cuda_wrapper.CudaRTLibrary()
+            stream = paddle.device.current_stream()
+            stream_capturing = lib.cudaStreamIsCapturing(stream)
+            if stream_capturing.value == 1:
+                # 1 is cudaStreamCaptureStatusActive: The stream is capturing.
+                return self.all_reduce(input, input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return paddle.empty_like(input)
         else:
-            return self.all_reduce(input, registered=False)
+            return self.all_reduce(input, input, registered=False)

     def close(self):
-        if not self.disabled and self._ptr:
+        if self._ptr:
             dispose(self._ptr)
             self._ptr = 0
             self.free_shared_buffer(self.group, self.meta_ptrs, rank=self.rank)
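Taken together, the new flow for a tensor-parallel all-reduce under CUDA graphs is roughly the following (an illustrative sketch, not the actual `tensor_model_parallel_all_reduce` implementation; `comm` stands for an initialized CustomAllreduce instance). Warm-up calls made while `capturing` is True but the stream is not yet recording return `paddle.empty_like(x)` so allocations match the graph, calls made during real capture take the registered-buffer path, and the buffer IPC handles are all-gathered in `register_graph_buffers()` when capture ends.

    import paddle
    import paddle.distributed as dist

    def tp_all_reduce(comm, x: paddle.Tensor) -> paddle.Tensor:
        # Use the custom kernel when the input qualifies (or a graph is being
        # captured); otherwise fall back to NCCL via dist.all_reduce.
        if comm is not None and comm.should_custom_ar(x):
            return comm.custom_all_reduce(x)
        dist.all_reduce(x)
        return x

During `with comm.capture():` (or between `start_capture()` and `stop_capture()`), `should_custom_ar()` returns True unconditionally, so every all-reduce recorded into the graph goes through the custom kernel and its buffers get registered for replay.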

View File

@@ -898,6 +898,5 @@ class EngineArgs:
             graph_optimization_config=graph_opt_cfg,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
-            enable_custom_all_reduce=self.enable_custom_all_reduce,
             enable_logprob=self.enable_logprob,
         )

View File

@@ -656,7 +656,6 @@ class Config:
         reasoning_parser: str = None,
         guided_decoding_backend: Optional[str] = None,
         disable_any_whitespace: bool = False,
-        enable_custom_all_reduce: bool = False,
         enable_logprob: bool = False,
     ):
         """

View File

@@ -1024,7 +1024,7 @@ class LLMEngine:
"do_profile": self.do_profile, "do_profile": self.do_profile,
"dynamic_load_weight": self.cfg.model_config.dynamic_load_weight, "dynamic_load_weight": self.cfg.model_config.dynamic_load_weight,
"disable_any_whitespace": self.cfg.disable_any_whitespace, "disable_any_whitespace": self.cfg.disable_any_whitespace,
"enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce, "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce,
"enable_logprob": self.cfg.enable_logprob, "enable_logprob": self.cfg.enable_logprob,
"enable_mm": self.cfg.enable_mm, "enable_mm": self.cfg.enable_mm,
} }

View File

@@ -22,6 +22,7 @@ from paddle.device.cuda import graphs
 from fastdeploy.config import FDConfig
 from fastdeploy.utils import get_logger
+from fastdeploy.distributed.communication import capture_custom_allreduce

 logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log")
@@ -109,9 +110,11 @@ class CudaGraphPiecewiseBackend:
                 paddle.device.synchronize()

                 # Capture
-                new_grpah.capture_begin()
-                output = entry.runnable(**kwargs)
-                new_grpah.capture_end()
+                with capture_custom_allreduce():
+                    new_grpah.capture_begin()
+                    output = entry.runnable(**kwargs)
+                    new_grpah.capture_end()

                 # Store output buffer
                 entry.cuda_graph = new_grpah
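The nesting order matters here: `CustomAllreduce.capture()` resets the flag and calls `register_graph_buffers()` only when the `with` block exits, i.e. after `capture_end()`, once all graph-owned buffers exist. The same bracketing can be written with the new explicit methods instead of the context manager; a sketch (the `comm`, `new_graph`, and `runnable` handles are placeholders for whatever the backend holds):

    def capture_with_explicit_flags(comm, new_graph, runnable, **kwargs):
        comm.start_capture()       # capturing = True: all-reduces take the custom path
        new_graph.capture_begin()
        output = runnable(**kwargs)
        new_graph.capture_end()
        comm.stop_capture()        # capturing = False, then register_graph_buffers()
        return output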

View File

@@ -17,7 +17,7 @@
 import paddle
 from paddle import nn

-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.utils import ceil_div

View File

@@ -190,7 +190,7 @@ class GCUFusedMoeMethod(MoEMethodBase):
         fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size])

         if layer.tp_size > 1:
-            from fastdeploy.distributed.communication_op import (
+            from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce,
             )

View File

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn

 from fastdeploy.config import FDConfig
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.platforms import current_platform

 from .utils import _set_var_distributed, divide, get_tensor

View File

@@ -20,7 +20,7 @@ from paddle.nn.quant import weight_quantize
 from paddleformers.utils.log import logger

 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.platforms import current_platform

 from ..utils import create_and_set_parameter, get_tensor

View File

@@ -19,7 +19,7 @@ from paddle import nn
 from paddleformers.utils.log import logger

 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm

View File

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn

 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.ops.gpu import (
     MoeWna16MarlinGemmApi,
     tritonmoe_preprocess_func,

View File

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn

 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.utils import create_and_set_parameter, get_tensor
 from fastdeploy.utils import ceil_div

View File

@@ -18,7 +18,7 @@ import paddle
 from paddle import nn

 import fastdeploy
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.utils import ceil_div

 from ..quantization.quant_base import QuantMethodBase

View File

@@ -82,7 +82,7 @@ class XPUMoEMethod(MoEMethodBase):
             False, # moe group, used in deepseek
         )
         if layer.tp_size > 1:
-            from fastdeploy.distributed.communication_op import (
+            from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce,
             )
@@ -210,7 +210,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase):
             False, # moe group, used in deepseek
         )
         if layer.tp_size > 1:
-            from fastdeploy.distributed.communication_op import (
+            from fastdeploy.distributed.communication import (
                 tensor_model_parallel_all_reduce,
             )

View File

@@ -25,7 +25,7 @@ from paddleformers.transformers import PretrainedModel
 from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.activation import SiluAndMul
 from fastdeploy.model_executor.layers.attention.attention import Attention

View File

@@ -28,7 +28,7 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig
 from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
-from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce
+from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.graph_optimization.decorator import (
     support_graph_optimization,
 )

View File

@@ -62,7 +62,7 @@ class GpuWorker(WorkerBase):
         gc.collect()
         paddle.device.cuda.empty_cache()
         if self.parallel_config.enable_custom_all_reduce:
-            from fastdeploy.distributed.communication_op import use_custom_allreduce
+            from fastdeploy.distributed.communication import use_custom_allreduce

             use_custom_allreduce()
         else:

View File

@@ -498,7 +498,7 @@ def parse_args():
help="enable prefix cache", help="enable prefix cache",
) )
parser.add_argument( parser.add_argument(
"--enable-custom-all-reduce", "--enable_custom_all_reduce",
action="store_true", action="store_true",
help="enable custom all-reduce", help="enable custom all-reduce",
) )
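With `action="store_true"`, switching the flag to the underscore form keeps the argparse destination, the `"enable_custom_all_reduce"` key the LLMEngine now emits (earlier hunk), and the `ParallelConfig.enable_custom_all_reduce` attribute spelled identically. A standalone sketch of that round trip (plain argparse, not the FastDeploy entrypoint):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--enable_custom_all_reduce",
        action="store_true",
        help="enable custom all-reduce",
    )
    args = parser.parse_args(["--enable_custom_all_reduce"])

    assert args.enable_custom_all_reduce is True
    # vars(args) can feed ParallelConfig's `for key, value in args.items()` loop directly.
    print(vars(args))  # {'enable_custom_all_reduce': True}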