diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 98d118896..7190639cb 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -201,6 +201,8 @@ class ParallelConfig: # disable any whitespace for guided decoding self.disable_any_whitespace: bool = True self.pod_ip: str = None + # Enable the custom all-reduce kernel; otherwise fall back to NCCL (dist.all_reduce). + self.enable_custom_all_reduce: bool = False for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) @@ -213,8 +215,6 @@ class ParallelConfig: self.moe_phase = MoEPhase.DECODER else: raise NotImplementedError - # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). - self.enable_custom_all_reduce: bool = False # pd_disaggregation use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0)) diff --git a/fastdeploy/distributed/communication_op.py b/fastdeploy/distributed/communication.py similarity index 88% rename from fastdeploy/distributed/communication_op.py rename to fastdeploy/distributed/communication.py index a54e58f87..5311a77f9 100644 --- a/fastdeploy/distributed/communication_op.py +++ b/fastdeploy/distributed/communication.py @@ -17,12 +17,23 @@ import paddle import paddle.distributed as dist from paddle.distributed import fleet +from contextlib import contextmanager, nullcontext from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size _TP_AR = None +@contextmanager +def capture_custom_allreduce(): + global _TP_AR + ar_context = nullcontext() + if _TP_AR is not None: + ar_context = _TP_AR.capture() + with ar_context: + yield + + def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024): hcg = fleet.get_hybrid_communicate_group() model_parallel_group = hcg.get_model_parallel_group() diff --git a/fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py b/fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py index 22195364b..ac321a589 100644 --- 
a/fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py +++ b/fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py @@ -22,6 +22,8 @@ from typing import Any, Dict, List, Optional cudaError_t = ctypes.c_int cudaMemcpyKind = ctypes.c_int +cudaStream_t = ctypes.c_void_p +cudaStreamCaptureStatus = ctypes.c_int class cudaIpcMemHandle_t(ctypes.Structure): @@ -108,6 +110,14 @@ class CudaRTLibrary: ctypes.c_uint, ], ), + Function( + "cudaStreamIsCapturing", + cudaError_t, + [ + cudaStream_t, + ctypes.POINTER(cudaStreamCaptureStatus) + ] + ), ] # class attribute to store the mapping from the path to the library @@ -187,3 +197,9 @@ class CudaRTLibrary: self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess) ) return devPtr + def cudaStreamIsCapturing(self, stream: cudaStream_t) -> ctypes.c_int: + is_capturing = ctypes.c_int() + self.CUDART_CHECK( + self.funcs["cudaStreamIsCapturing"](stream, is_capturing) + ) + return is_capturing diff --git a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py index 818b8bd98..1b7b46d9f 100644 --- a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py +++ b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py @@ -56,8 +56,7 @@ class CustomAllreduce: is bind to a unique device, and all communicators in this group are in the same node. """ - self._IS_CAPTURING = False - self.disabled = True + self.capturing = False self.group = group if not custom_ar: @@ -78,8 +77,6 @@ class CustomAllreduce: if world_size < 2: return - self.disabled = False - # Buffers memory are owned by this Python class and passed to C++. # Meta data composes of two parts: meta data for synchronization and a # temporary buffer for storing intermediate allreduce results. @@ -95,13 +92,13 @@ class CustomAllreduce: # is enough for 131072 such tuples. The largest model I've seen only # needs less than 10000 of registered tuples. 
self.rank_data = paddle.empty([8 * 1024 * 1024], dtype=paddle.uint8) + self.max_size = max_size - self.rank = rank self.world_size = world_size self.full_nvlink = True self._ptr = init_custom_all_reduce(self.meta_ptrs, self.rank_data, rank, self.full_nvlink) register_buffer(self._ptr, self.buffer_ptrs) - print("zss init custom allreduce", self._ptr) + _instances.append(self) @staticmethod @@ -112,7 +109,6 @@ class CustomAllreduce: """ lib = cuda_wrapper.CudaRTLibrary() pointer = lib.cudaMalloc(size_in_bytes) - # lib.cudaMemset(pointer, 2, size_in_bytes) handle = lib.cudaIpcGetMemHandle(pointer) rank = dist.get_rank(group=group) handles = [] @@ -135,8 +131,8 @@ class CustomAllreduce: lib.cudaFree(ctypes.c_void_p(pointers[rank])) def should_custom_ar(self, inp: paddle.Tensor): - if self.disabled: - return False + if self.capturing: + return True inp_size = inp.numel() * inp.element_size() # custom allreduce requires input byte size to be multiples of 16 if inp_size % 16 != 0: @@ -167,6 +163,19 @@ class CustomAllreduce: all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size) return out + def start_capture(self): + """ + set CUDA graph flag: True. + """ + self.capturing = True + + def stop_capture(self): + """ + set CUDA graph flag: False and register the graph buffers. + """ + self.capturing = False + self.register_graph_buffers() + @contextmanager def capture(self): """ @@ -175,44 +184,44 @@ class CustomAllreduce: It records all the buffer addresses used in the CUDA graph. """ try: - self._IS_CAPTURING = True + self.capturing = True yield finally: - self._IS_CAPTURING = False - if not self.disabled: - self.register_graph_buffers() + self.capturing = False + self.register_graph_buffers() def register_graph_buffers(self): + """ + Register the graph buffers collected during CUDA graph capture. 
+ """ handle, offset = get_graph_buffer_ipc_meta(self._ptr) - all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] - all_data[self.rank] = [handle, offset] + all_datas = [] + all_data = [handle, offset] - ranks = sorted(dist.get_process_group_ranks(group=self.group)) - for i, rank in enumerate(ranks): - dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu") + dist.all_gather_object(all_datas, all_data, group=self.group) - # Unpack list of tuples to tuple of lists. - handles = [d[0] for d in all_data] # type: ignore - offsets = [d[1] for d in all_data] # type: ignore + handles = [d[0] for d in all_datas] + offsets = [d[1] for d in all_datas] register_graph_buffers(self._ptr, handles, offsets) def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]: """The main allreduce API that provides support for cuda graph.""" - # When custom allreduce is disabled, this will be None. - if self.disabled or not self.should_custom_ar(input): - return None - if self._IS_CAPTURING: - if paddle.cuda.is_current_stream_capturing(): - return self.all_reduce(input, registered=True) + if self.capturing: + lib = cuda_wrapper.CudaRTLibrary() + stream = paddle.device.current_stream() + stream_capturing = lib.cudaStreamIsCapturing(stream) + if stream_capturing.value == 1: + # 1 is cudaStreamCaptureStatusActive: The stream is capturing. + return self.all_reduce(input, input, registered=True) else: # If warm up, mimic the allocation pattern since custom # allreduce is out-of-place. 
return paddle.empty_like(input) else: - return self.all_reduce(input, registered=False) + return self.all_reduce(input, input, registered=False) def close(self): - if not self.disabled and self._ptr: + if self._ptr: dispose(self._ptr) self._ptr = 0 self.free_shared_buffer(self.group, self.meta_ptrs, rank=self.rank) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index ba3ce8312..546d9de52 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -898,6 +898,5 @@ class EngineArgs: graph_optimization_config=graph_opt_cfg, guided_decoding_backend=self.guided_decoding_backend, disable_any_whitespace=self.guided_decoding_disable_any_whitespace, - enable_custom_all_reduce=self.enable_custom_all_reduce, enable_logprob=self.enable_logprob, ) diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index 96c860d87..7c4473177 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -656,7 +656,6 @@ class Config: reasoning_parser: str = None, guided_decoding_backend: Optional[str] = None, disable_any_whitespace: bool = False, - enable_custom_all_reduce: bool = False, enable_logprob: bool = False, ): """ diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index d0e6defae..5b555465e 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -1024,7 +1024,7 @@ class LLMEngine: "do_profile": self.do_profile, "dynamic_load_weight": self.cfg.model_config.dynamic_load_weight, "disable_any_whitespace": self.cfg.disable_any_whitespace, - "enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce, + "enable_custom_all_reduce": self.cfg.parallel_config.enable_custom_all_reduce, "enable_logprob": self.cfg.enable_logprob, "enable_mm": self.cfg.enable_mm, } diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index 
bbe95feb4..c93a3c5a4 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -22,6 +22,7 @@ from paddle.device.cuda import graphs from fastdeploy.config import FDConfig from fastdeploy.utils import get_logger +from fastdeploy.distributed.communication import capture_custom_allreduce logger = get_logger("cudagrpah_piecewise_backend", "cudagraph_piecewise_backend.log") @@ -109,9 +110,11 @@ class CudaGraphPiecewiseBackend: paddle.device.synchronize() # Capture - new_grpah.capture_begin() - output = entry.runnable(**kwargs) - new_grpah.capture_end() + with capture_custom_allreduce(): + new_grpah.capture_begin() + output = entry.runnable(**kwargs) + new_grpah.capture_end() + # Store output buffer entry.cuda_graph = new_grpah diff --git a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py index 9dd45ab95..0a6c31b06 100644 --- a/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py +++ b/fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py @@ -17,7 +17,7 @@ import paddle from paddle import nn -from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase from fastdeploy.utils import ceil_div diff --git a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py index 7bc1850c7..89c0efc37 100644 --- a/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py +++ b/fastdeploy/model_executor/layers/backends/gcu/moe/fused_moe_method_gcu_backend.py @@ -190,7 +190,7 @@ class 
GCUFusedMoeMethod(MoEMethodBase): fused_moe_out = fused_moe_out.reshape_([token_num, hidden_size]) if layer.tp_size > 1: - from fastdeploy.distributed.communication_op import ( + from fastdeploy.distributed.communication import ( tensor_model_parallel_all_reduce, ) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 970167ae6..622855469 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -18,7 +18,7 @@ import paddle from paddle import nn from fastdeploy.config import FDConfig -from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.platforms import current_platform from .utils import _set_var_distributed, divide, get_tensor diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 67a87cc22..1ced9176a 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -20,7 +20,7 @@ from paddle.nn.quant import weight_quantize from paddleformers.utils.log import logger import fastdeploy -from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.platforms import current_platform from ..utils import create_and_set_parameter, get_tensor diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index b80db3114..f7259645b 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -19,7 +19,7 @@ from paddle import nn from paddleformers.utils.log import 
logger import fastdeploy -from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.layers.utils import get_tensor from fastdeploy.model_executor.ops.gpu import count_tokens_per_expert_func, deep_gemm diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py index 69c58a549..d086bbeef 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_marlin_backend.py @@ -18,7 +18,7 @@ import paddle from paddle import nn import fastdeploy -from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.ops.gpu import ( MoeWna16MarlinGemmApi, tritonmoe_preprocess_func, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py index 1715cd60a..430f3104b 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py @@ -18,7 +18,7 @@ import paddle from paddle import nn import fastdeploy -from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.layers.utils import create_and_set_parameter, get_tensor from fastdeploy.utils import ceil_div diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index e54734901..cc2932d4e 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ 
b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -18,7 +18,7 @@ import paddle from paddle import nn import fastdeploy -from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.utils import ceil_div from ..quantization.quant_base import QuantMethodBase diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py index 03331e46b..c320ed481 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_xpu_backend.py @@ -82,7 +82,7 @@ class XPUMoEMethod(MoEMethodBase): False, # moe group, used in deepseek ) if layer.tp_size > 1: - from fastdeploy.distributed.communication_op import ( + from fastdeploy.distributed.communication import ( tensor_model_parallel_all_reduce, ) @@ -210,7 +210,7 @@ class XPUWeightOnlyMoEMethod(QuantMethodBase): False, # moe group, used in deepseek ) if layer.tp_size > 1: - from fastdeploy.distributed.communication_op import ( + from fastdeploy.distributed.communication import ( tensor_model_parallel_all_reduce, ) diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index b1ebd98ec..2c2b0efe1 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -25,7 +25,7 @@ from paddleformers.transformers import PretrainedModel from paddleformers.utils.log import logger from fastdeploy.config import FDConfig -from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.activation import SiluAndMul from 
fastdeploy.model_executor.layers.attention.attention import Attention diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index 6a1499e20..bcf7db223 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -28,7 +28,7 @@ from paddleformers.transformers.configuration_utils import PretrainedConfig from paddleformers.utils.log import logger from fastdeploy.config import FDConfig -from fastdeploy.distributed.communication_op import tensor_model_parallel_all_reduce +from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce from fastdeploy.model_executor.graph_optimization.decorator import ( support_graph_optimization, ) diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index 7dcdcbe8f..45aa96b4c 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -62,7 +62,7 @@ class GpuWorker(WorkerBase): gc.collect() paddle.device.cuda.empty_cache() if self.parallel_config.enable_custom_all_reduce: - from fastdeploy.distributed.communication_op import use_custom_allreduce + from fastdeploy.distributed.communication import use_custom_allreduce use_custom_allreduce() else: diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index eba0250cc..d656f60ae 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -498,7 +498,7 @@ def parse_args(): help="enable prefix cache", ) parser.add_argument( - "--enable-custom-all-reduce", + "--enable_custom_all_reduce", action="store_true", help="enable custom all-reduce", )