From b89180f1cdd9fe00bb7607ed67d111c83eef1902 Mon Sep 17 00:00:00 2001 From: zhink <33270771+zhink@users.noreply.github.com> Date: Wed, 9 Jul 2025 16:00:27 +0800 Subject: [PATCH] [Feature] support custom all-reduce (#2758) * [Feature] support custom all-reduce * add vllm adapted --- custom_ops/gpu_ops/cpp_extensions.cc | 45 ++ .../gpu_ops/custom_all_reduce/all_reduce.cu | 165 ++++++ .../gpu_ops/custom_all_reduce/all_reduce.cuh | 526 ++++++++++++++++++ custom_ops/setup_ops.py | 1 + docs/parameters.md | 1 + docs/zh/parameters.md | 1 + fastdeploy/config.py | 2 + fastdeploy/distributed/communication_op.py | 19 +- .../distributed/custom_all_reduce/__init__.py | 17 + .../custom_all_reduce/cuda_wrapper.py | 167 ++++++ .../custom_all_reduce/custom_all_reduce.py | 226 ++++++++ fastdeploy/engine/args_utils.py | 15 + fastdeploy/engine/config.py | 3 + fastdeploy/engine/engine.py | 1 + fastdeploy/worker/gpu_worker.py | 3 + fastdeploy/worker/worker_process.py | 4 + 16 files changed, 1194 insertions(+), 2 deletions(-) create mode 100644 custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu create mode 100644 custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh create mode 100644 fastdeploy/distributed/custom_all_reduce/__init__.py create mode 100644 fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py create mode 100644 fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 9927f31a9..35e02e014 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -493,6 +493,31 @@ paddle::Tensor FusedHadamardQuantFp8Func( const float scale); #endif +int64_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, + paddle::Tensor& rank_data, int64_t rank, bool full_nvlink); + +void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out, + int64_t reg_buffer, int64_t reg_buffer_sz_bytes); + +void dispose(int64_t _fa); + +int64_t meta_size(); + +void register_buffer(int64_t _fa, const std::vector& fake_ipc_ptrs); + +std::tuple, std::vector> get_graph_buffer_ipc_meta(int64_t _fa); + +void register_graph_buffers(int64_t _fa, + const std::vector>& handles, + const std::vector>& offsets); + +std::tuple allocate_shared_buffer_and_handle( + int64_t size); + +int64_t open_mem_handle(paddle::Tensor& mem_handle); + +void free_shared_buffer(int64_t buffer); + PYBIND11_MODULE(fastdeploy_ops, m) { m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"), @@ -785,4 +810,24 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func, py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function"); #endif + + m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function"); + + m.def("all_reduce", &all_reduce, "all reduce function"); + + m.def("dispose", &dispose, "del function for python"); + + m.def("meta_size", &meta_size, "meta_size function for Signal struct"); + + m.def("register_buffer", ®ister_buffer, "register ipc buffer"); + + m.def("register_graph_buffers", ®ister_graph_buffers, "register_graph_buffers"); + + m.def("allocate_shared_buffer_and_handle", &allocate_shared_buffer_and_handle, "allocate_shared_buffer_and_handle"); + + m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer"); + + m.def("open_mem_handle", &open_mem_handle, "open_mem_handle"); + + m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta"); } diff --git 
a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu new file mode 100644 index 000000000..7c6d4cec7 --- /dev/null +++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cu @@ -0,0 +1,165 @@ +// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu + +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "all_reduce.cuh" + +// Fake pointer type, must match fptr_t type in ops.h. +// We use this type alias to indicate when pointers are passed in as int64_t. +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +fptr_t init_custom_all_reduce(const std::vector& fake_ipc_ptrs, + paddle::Tensor& rank_data, int64_t rank, + bool full_nvlink) { + int world_size = fake_ipc_ptrs.size(); + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + + paddle::Signal* ipc_ptrs[8]; + for (int i = 0; i < world_size; i++) { + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + } + return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(), + rank_data.numel(), rank, world_size, + full_nvlink); +} + +/** + * Performs an out-of-place allreduce and stores result in out. + * + * If _reg_buffer is null, assumes inp.data() is already IPC-registered. + * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first + * copied into _reg_buffer. 
+ */ +void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out, + fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { + auto fa = reinterpret_cast(_fa); + auto stream = inp.stream(); + + auto input_size = inp.numel() * 2; + auto reg_buffer = reinterpret_cast(_reg_buffer); + if (reg_buffer) { + cudaMemcpyAsync(reg_buffer, inp.data(), input_size, + cudaMemcpyDeviceToDevice, stream); + } else { + reg_buffer = inp.data(); + } + switch (out.dtype()) { + case phi::DataType::FLOAT32: { + fa->allreduce(stream, reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), + out.numel()); + break; + } + case phi::DataType::FLOAT16: { + fa->allreduce(stream, reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), out.numel()); + break; + } +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) + case phi::DataType::BFLOAT16: { + fa->allreduce( + stream, reinterpret_cast(reg_buffer), + reinterpret_cast(out.data()), out.numel()); + break; + } +#endif + default: + throw std::runtime_error( + "custom allreduce only supports float32, float16 and bfloat16"); + } +} + +void dispose(fptr_t _fa) { + delete reinterpret_cast(_fa); +} + +int64_t meta_size() { return sizeof(paddle::Signal); } + +void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) { + auto fa = reinterpret_cast(_fa); + void* ipc_ptrs[8]; + for (int i = 0; i < fake_ipc_ptrs.size(); i++) { + ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + } + fa->register_buffer(ipc_ptrs); +} + +// Use vector to represent byte data for python binding compatibility. +std::tuple, std::vector> +get_graph_buffer_ipc_meta(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + auto [handle, offsets] = fa->get_graph_buffer_ipc_meta(); + std::vector bytes(handle.begin(), handle.end()); + return std::make_tuple(bytes, offsets); +} + +// Use vector to represent byte data for python binding compatibility. +void register_graph_buffers(fptr_t _fa, + const std::vector>& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); + std::vector bytes; + bytes.reserve(handles.size()); + for (int i = 0; i < handles.size(); i++) { + bytes.emplace_back(handles[i].begin(), handles[i].end()); + } + bytes.reserve(handles.size()); + fa->register_graph_buffers(bytes, offsets); +} + +std::tuple allocate_shared_buffer_and_handle( + int64_t size) { + + auto device_index = phi::backends::gpu::GetCurrentDeviceId(); + void* buffer; + cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; + auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream(); + CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + + // Allocate buffer + CUDACHECK(cudaMalloc((void**)&buffer, size)); + CUDACHECK(cudaMemsetAsync(buffer, 0, size, stream)); + CUDACHECK(cudaStreamSynchronize(stream)); + CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + + // Create IPC memhandle for the allocated buffer. + // Will use it in open_mem_handle. 
+ auto handle = + paddle::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index)); + CUDACHECK( + cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer)); + + return std::make_tuple(reinterpret_cast(buffer), handle); +} + + +fptr_t open_mem_handle(paddle::Tensor& mem_handle) { + void* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle( + (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()), + cudaIpcMemLazyEnablePeerAccess)); + return reinterpret_cast(ipc_ptr); +} + +void free_shared_buffer(fptr_t buffer) { + CUDACHECK(cudaFree(reinterpret_cast(buffer))); +} diff --git a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh new file mode 100644 index 000000000..2dd52871a --- /dev/null +++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh @@ -0,0 +1,526 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace paddle { + +constexpr int kMaxBlocks = 36; +// Counter may overflow, but it's fine since unsigned int overflow is +// well-defined behavior. +using FlagType = uint32_t; +struct Signal { + alignas(128) FlagType self_counter[kMaxBlocks][8]; + // Two sets of peer counters are needed for two syncs. The reason is that + // it's possible for peer GPU block to arrive at the second sync point while + // the current GPU block haven't passed the first sync point. Thus, peer GPU + // may write counter+1 while current GPU is busy waiting for counter. We use + // alternating counter array to avoid this possibility. 
+ alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; +}; + +struct __align__(16) RankData { + const void* __restrict__ ptrs[8]; +}; + +struct __align__(16) RankSignals { + Signal* signals[8]; +}; + +// like std::array, but aligned +template +struct __align__(alignof(T) * sz) array_t { + T data[sz]; + using type = T; + static constexpr int size = sz; +}; + +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct packed_t { + // the (P)acked type for load/store + using P = array_t; + // the (A)ccumulator type for reduction + using A = array_t; +}; + +#define DINLINE __device__ __forceinline__ + +// scalar cast functions +DINLINE float upcast_s(half val) { return __half2float(val); } + +template +DINLINE T downcast_s(float val); +template <> +DINLINE half downcast_s(float val) { + return __float2half(val); +} + +// scalar add functions +// for some reason when compiling with Paddle, the + operator for half and +// bfloat is disabled so we call the intrinsics directly +DINLINE half& assign_add(half& a, half b) { + a = __hadd(a, b); + return a; +} +DINLINE float& assign_add(float& a, float b) { return a += b; } + +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) +DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } +template <> +DINLINE nv_bfloat16 downcast_s(float val) { + return __float2bfloat16(val); +} +DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { + a = __hadd(a, b); + return a; +} +#endif + +template +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { +#pragma unroll + for (int i = 0; i < N; i++) { + assign_add(a.data[i], b.data[i]); + } + return a; +} + +template +DINLINE array_t upcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + array_t out; +#pragma unroll + for (int i = 0; i < N; i++) { + out.data[i] = upcast_s(val.data[i]); + } + return out; + } +} + +template +DINLINE O downcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + O out; +#pragma unroll + for (int i = 0; i < O::size; i++) { + out.data[i] = downcast_s(val.data[i]); + } + return out; + } +} + +static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#else + asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#endif +} + +static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { + FlagType flag; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); +#else + asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" + : "=r"(flag) + : "l"(flag_addr)); +#endif + return flag; +} + +static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +} + +static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { + FlagType flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); + return flag; +} + +// is_start: whether this is the very first synchronization barrier. +// need_fence: whether a memory fence is needed. If true, a release-acquire +// semantic is used to enforce memory access order before and after this +// barrier. 
+template +DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, + int rank) { + if constexpr (!is_start) __syncthreads(); + static_assert( + !(is_start && need_fence)); // Start barrier shouldn't need fence. + if (threadIdx.x < ngpus) { + // Increment the counter. Technically we only need one counter, but we use + // multiple per block to eliminate the need to share the counter via smem. + auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + // Write the expected counter value to peer and wait for correct value from + // peer. + auto peer_counter_ptr = + &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; + auto self_counter_ptr = + &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; + if constexpr (need_fence) { + st_flag_release(peer_counter_ptr, val); + while (ld_flag_acquire(self_counter_ptr) != val); + } else { + st_flag_volatile(peer_counter_ptr, val); + while (ld_flag_volatile(self_counter_ptr) != val); + } + } + if constexpr (is_start || need_fence) __syncthreads(); +} + +template +DINLINE P packed_reduce(const P* ptrs[], int idx) { + A tmp = upcast(ptrs[0][idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) { + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + return downcast
<P>
(tmp); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + auto dp = *_dp; + multi_gpu_barrier(sg, self_sg, rank); + // do the actual reduction + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); + } + multi_gpu_barrier(sg, self_sg, rank); +} + +template +DINLINE P* get_tmp_buf(Signal* sg) { + return (P*)(((Signal*)sg) + 1); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; + int largest_part = part + size % ngpus; + const P* ptrs[ngpus]; + P* tmps[ngpus]; +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; + ptrs[i] = (const P*)_dp->ptrs[target]; + tmps[i] = get_tmp_buf
<P>
(sg.signals[target]); + } + auto tmp_out = tmps[0]; + multi_gpu_barrier(sg, self_sg, rank); + // stage 1: reduce scatter + for (int idx = start + tid; idx < end; idx += stride) { + tmp_out[idx - start] = packed_reduce(ptrs, idx); + } + multi_gpu_barrier(sg, self_sg, rank); + + // stage 2: allgather. Note: it's important to match the tid between + // the two stages, because visibility across devices is only guaranteed + // between threads that have the same tid. If thread i computes the sum of + // start + i in the first stage, then thread i also gathers start + i from all + // ranks. + for (int idx = tid; idx < largest_part; idx += stride) { +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int gather_from_rank = ((rank + i) % ngpus); + if (gather_from_rank == ngpus - 1 || idx < part) { + int dst_idx = gather_from_rank * part + idx; + ((P*)result)[dst_idx] = tmps[i][idx]; + } + } + } +} + +using IPC_KEY = std::array; +static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t)); +static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t)); + +class CustomAllreduce { + public: + int rank_; + int world_size_; + bool full_nvlink_; + + RankSignals sg_; + // Stores an map from a pointer to its peer pointters from all ranks. + std::unordered_map buffers_; + Signal* self_sg_; + + // Stores rank data from all ranks. This is mainly for cuda graph purposes. + // For cuda graph to work, all kernel arguments must be fixed during graph + // capture time. However, the peer pointers are not known during graph capture + // time. Therefore, during capture, we increment the rank data pointer and use + // that as the argument to the kernel. The kernel arguments are stored in + // graph_unreg_buffers_. The actual peer pointers will be filled in at the + // memory pointed to by the pointers in graph_unreg_buffers_ when + // the IPC handles are exchanged between ranks. + // + // The overall process looks like this: + // 1. Graph capture. + // 2. Each rank obtains the IPC handles for each addresses used during cuda + // graph capture using get_graph_buffer_ipc_meta. + // 3. (In Python) all gather the IPC handles. + // 4. Obtain the peer pointers by opening the IPC handles, and store them in + // the rank data array at corresponding positions. + RankData *d_rank_data_base_, *d_rank_data_end_; + std::vector graph_unreg_buffers_; + // a map from IPC handles to opened IPC pointers + std::map ipc_handles_; + + /** + * Signals are an array of ipc-enabled buffers from all ranks. + * For each of the buffer, the layout is as follows: + * | -- sizeof(Signal) -- | ------ a few MB ----- | + * The first section is for allreduce synchronization, and the second section + * is for storing the intermediate results required by some allreduce algos. + * + * Note: this class does not own any device memory. Any required buffers + * are passed in from the constructor. 
+ */ + CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz, + int rank, int world_size, bool full_nvlink = true) + : rank_(rank), + world_size_(world_size), + full_nvlink_(full_nvlink), + self_sg_(signals[rank]), + d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { + for (int i = 0; i < world_size_; i++) { + sg_.signals[i] = signals[i]; + } + } + + char* open_ipc_handle(const void* ipc_handle) { + auto [it, new_handle] = + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), + cudaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; + } + + std::pair> get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::string handles(handle_sz * num_buffers, static_cast(0)); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; + void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, + (CUdeviceptr)ptr) != CUDA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + return std::make_pair(handles, offsets); + } + + void check_rank_data_capacity(size_t num = 1) { + if (d_rank_data_base_ + num > d_rank_data_end_) + throw std::runtime_error( + "Rank data buffer is overflowed by " + + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + + /** + * Register already-shared IPC pointers. + */ + void register_buffer(void** ptrs) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { + data.ptrs[i] = ptrs[i]; + } + auto d_data = d_rank_data_base_++; + CUDACHECK( + cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + buffers_[ptrs[rank_]] = d_data; + } + + // Note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, + // rank 1 may get the same input address for the second allreduce, but rank 2 + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. + void register_graph_buffers( + const std::vector& handles, + const std::vector>& offsets) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = graph_unreg_buffers_[i]; + auto& rd = rank_data[i]; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { + char* handle = + open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), + sizeof(RankData) * num_buffers, + cudaMemcpyHostToDevice)); + d_rank_data_base_ += num_buffers; + graph_unreg_buffers_.clear(); + } + + /** + * Performs allreduce, assuming input has already been registered. 
+ * + * Block and grid default configs are results after careful grid search. Using + * 36 blocks give the best or close to the best runtime on the devices I + * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only + * take a small amount of SMs. Not quite sure the underlying reason, but my + * guess is that too many SMs will cause contention on NVLink bus. + */ + template + void allreduce(cudaStream_t stream, T* input, T* output, int size, + int threads = 512, int block_limit = 36) { + auto d = packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "custom allreduce currently requires input length to be multiple " + "of " + + std::to_string(d)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + + std::to_string(kMaxBlocks) + ". Got " + + std::to_string(block_limit)); + + RankData* ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + name<<>>(ptrs, sg_, self_sg_, output, \ + rank_, size); + +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (full_nvlink_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). 
Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + + ~CustomAllreduce() { + for (auto [_, ptr] : ipc_handles_) { + CUDACHECK(cudaIpcCloseMemHandle(ptr)); + } + } +}; +} // namespace paddle diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 3470d9534..ac7e5433f 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -276,6 +276,7 @@ elif paddle.is_compiled_with_cuda(): "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", + "gpu_ops/custom_all_reduce/all_reduce.cu", ] # pd_disaggregation diff --git a/docs/parameters.md b/docs/parameters.md index 21487f365..27076394c 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -36,6 +36,7 @@ When using FastDeploy to deploy models (including offline inference and service | ```enable_static_graph_inference``` | `bool` | Whether to use static graph inference mode, default: False | | ```use_cudagraph``` | `bool` | Whether to use cuda graph, default: False | | ```max_capture_batch_size``` | `int` | When cuda graph is enabled, maximum batch size of captured cuda graph, default: 64 | +| ```enable_custom_all_reduce``` | `bool` | Enable Custom all-reduce, default: False | | ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] | | ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None | | ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` | diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md index a4bdd4de7..7561f5d9b 100644 --- a/docs/zh/parameters.md +++ b/docs/zh/parameters.md @@ -35,6 +35,7 @@ | ```enable_static_graph_inference```| `bool` | 是否使用静态图推理模式,默认False | | ```use_cudagraph``` | `bool` | 是否使用cuda graph,默认False | | ```max_capture_batch_size``` | `int` | 开启 cuda graph 时,捕获的 cuda graph的最大batch size,默认为64 | +| ```enable_custom_all_reduce``` | `bool` | 开启Custom all-reduce,默认False | | ```splitwise_role``` | `str` | 是否开启splitwise推理,默认值mixed, 支持参数为["mixed", "decode", "prefill"] | | ```innode_prefill_ports``` | `str` | prefill 实例内部引擎启动端口 (仅单机PD分离需要),默认值None | | ```guided_decoding_backend``` | `str` | 指定要使用的guided decoding后端,支持 `auto`、`xgrammar`、`off`, 默认为 `off` | diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 4d513a21b..ceeb7c4a8 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -207,6 +207,8 @@ class ParallelConfig: guided_decoding_backend: str = None # disable any whitespace for guided decoding disable_any_whitespace: bool = True + # enable the custom all-reduce kernel and fall back to NCCL(dist.all_reduce). 
+ enable_custom_all_reduce: str = "store_true" @dataclass diff --git a/fastdeploy/distributed/communication_op.py b/fastdeploy/distributed/communication_op.py index 70a368e98..fb397df0f 100644 --- a/fastdeploy/distributed/communication_op.py +++ b/fastdeploy/distributed/communication_op.py @@ -16,13 +16,28 @@ import paddle import paddle.distributed as dist +from paddle.distributed import fleet +from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size + +_TP_AR = None + +def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024): + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + global _TP_AR + if get_tensor_model_parallel_world_size() > 1 and paddle.is_compiled_with_cuda(): + from fastdeploy.distributed.custom_all_reduce import CustomAllreduce + _TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes) try: @paddle.jit.marker.unified def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor: """All-reduce the input tensor across model parallel group.""" - if paddle.in_dynamic_mode(): - hcg = dist.fleet.get_hybrid_communicate_group() + global _TP_AR + if _TP_AR is not None and _TP_AR.should_custom_ar(input_) : + _TP_AR.all_reduce(input_, input_) + elif paddle.in_dynamic_mode(): + hcg = fleet.get_hybrid_communicate_group() mp_group = hcg.get_model_parallel_group() dist.all_reduce(input_, group=mp_group) else: diff --git a/fastdeploy/distributed/custom_all_reduce/__init__.py b/fastdeploy/distributed/custom_all_reduce/__init__.py new file mode 100644 index 000000000..054074cf9 --- /dev/null +++ b/fastdeploy/distributed/custom_all_reduce/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .custom_all_reduce import CustomAllreduce + +__all__ = ["CustomAllreduce"] \ No newline at end of file diff --git a/fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py b/fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py new file mode 100644 index 000000000..af5cc487d --- /dev/null +++ b/fastdeploy/distributed/custom_all_reduce/cuda_wrapper.py @@ -0,0 +1,167 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This file is a pure Python wrapper for the cudart library. +It avoids the need to compile a separate shared library, and is +convenient for use when we just need to call a few functions. 
+""" + +import ctypes +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +cudaError_t = ctypes.c_int +cudaMemcpyKind = ctypes.c_int + + +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith( + lib_name + ), f"Unexpected filename: {filename} for library {lib_name}" + return path + + +class CudaRTLibrary: + exported_functions = [ + # ​cudaError_t cudaSetDevice ( int device ) + Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), + # cudaError_t cudaDeviceSynchronize ( void ) + Function("cudaDeviceSynchronize", cudaError_t, []), + # ​cudaError_t cudaDeviceReset ( void ) + Function("cudaDeviceReset", cudaError_t, []), + # const char* cudaGetErrorString ( cudaError_t error ) + Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), + # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) + Function("cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + # ​cudaError_t cudaFree ( void* devPtr ) + Function("cudaFree", cudaError_t, [ctypes.c_void_p]), + # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) + Function("cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa + Function("cudaMemcpy", cudaError_t, [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind]), + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa + Function("cudaIpcGetMemHandle", cudaError_t, [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa + Function( + "cudaIpcOpenMemHandle", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint] + ), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + if so_file is None: + so_file = find_loaded_library("libcudart") + if so_file is None: + pass + # so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var + assert so_file is not None, ( + "libcudart is not loaded in the current process, " "try setting VLLM_CUDART_SO_PATH" + ) + if so_file not in CudaRTLibrary.path_to_library_cache: + lib = ctypes.CDLL(so_file) + CudaRTLibrary.path_to_library_cache[so_file] = lib + 
self.lib = CudaRTLibrary.path_to_library_cache[so_file] + + if so_file not in CudaRTLibrary.path_to_dict_mapping: + _funcs = {} + for func in CudaRTLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] + + def CUDART_CHECK(self, result: cudaError_t) -> None: + if result != 0: + error_str = self.cudaGetErrorString(result) + raise RuntimeError(f"CUDART error: {error_str}") + + def cudaGetErrorString(self, error: cudaError_t) -> str: + return self.funcs["cudaGetErrorString"](error).decode("utf-8") + + def cudaSetDevice(self, device: int) -> None: + self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) + + def cudaDeviceSynchronize(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) + + def cudaDeviceReset(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) + + def cudaMalloc(self, size: int) -> ctypes.c_void_p: + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) + return devPtr + + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: + self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) + + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None: + self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) + + def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int) -> None: + cudaMemcpyDefault = 4 + kind = cudaMemcpyDefault + self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) + + def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + handle = cudaIpcMemHandle_t() + self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)) + return handle + + def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + cudaIpcMemLazyEnablePeerAccess = 1 + devPtr = ctypes.c_void_p() + self.CUDART_CHECK( + self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess) + ) + return devPtr diff --git a/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py new file mode 100644 index 000000000..3b6de6ea9 --- /dev/null +++ b/fastdeploy/distributed/custom_all_reduce/custom_all_reduce.py @@ -0,0 +1,226 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from contextlib import contextmanager +import atexit +import ctypes +from typing import List, Optional + +import paddle +import paddle.distributed as dist +from paddle.distributed.communication.group import Group +from fastdeploy.model_executor.ops.gpu import ( + all_reduce, + dispose, + init_custom_all_reduce, + meta_size, + register_buffer, + get_graph_buffer_ipc_meta, + register_graph_buffers, +) + +from fastdeploy.distributed.custom_all_reduce import cuda_wrapper + +try: + meta_size() + custom_ar = True +except Exception: + custom_ar = False + +_instances = [] + + +class CustomAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + # max_size: max supported allreduce size + def __init__(self, group: Group, max_size: int=8192 * 1024) -> None: + """ + Args: + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + self.group = group + + if not custom_ar: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + + rank = dist.get_rank(group=self.group) + self.rank = rank + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + return + + if world_size < 2: + return + + self.disabled = False + + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. + self.meta_ptrs = self.create_shared_buffer(group, meta_size() + max_size) + + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer_ptrs = self.create_shared_buffer(group, max_size) + + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = paddle.empty([8 * 1024 * 1024], dtype=paddle.uint8) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + self.full_nvlink = True + self._ptr = init_custom_all_reduce(self.meta_ptrs, self.rank_data, rank, self.full_nvlink) + register_buffer(self._ptr, self.buffer_ptrs) + print("zss init custom allreduce", self._ptr) + _instances.append(self) + + @staticmethod + def create_shared_buffer(group: Group, size_in_bytes: int) -> List[int]: + """ + Creates a shared buffer and returns a list of pointers + representing the buffer on all processes in the group. 
+ """ + lib = cuda_wrapper.CudaRTLibrary() + pointer = lib.cudaMalloc(size_in_bytes) + # lib.cudaMemset(pointer, 2, size_in_bytes) + handle = lib.cudaIpcGetMemHandle(pointer) + rank = dist.get_rank(group=group) + handles = [] + dist.all_gather_object(handles, handle, group=group) + + pointers: List[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer.value) # type: ignore + else: + pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore + + return pointers + + @staticmethod + def free_shared_buffer(group: Group, pointers: List[int], rank: Optional[int] = None) -> None: + if rank is None: + rank = dist.get_rank(group=group) + lib = cuda_wrapper.CudaRTLibrary() + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + + def should_custom_ar(self, inp: paddle.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False + + def all_reduce(self, inp: paddle.Tensor, out: paddle.Tensor = None, registered: bool = False): + """Performs an out-of-place all reduce. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. + """ + if out is None: + out = paddle.empty_like(inp) + if registered: + all_reduce(self._ptr, inp, out, 0, 0) + else: + all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size) + return out + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def register_graph_buffers(self): + handle, offset = get_graph_buffer_ipc_meta(self._ptr) + all_data = [[None, None] + for _ in range(dist.get_world_size(group=self.group))] + all_data[self.rank] = [handle, offset] + + ranks = sorted(dist.get_process_group_ranks(group=self.group)) + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore + register_graph_buffers(self._ptr, handles, offsets) + + def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]: + """The main allreduce API that provides support for cuda graph.""" + # When custom allreduce is disabled, this will be None. + if self.disabled or not self.should_custom_ar(input): + return None + if self._IS_CAPTURING: + if paddle.cuda.is_current_stream_capturing(): + return self.all_reduce(input, registered=True) + else: + # If warm up, mimic the allocation pattern since custom + # allreduce is out-of-place. 
+ return paddle.empty_like(input) + else: + return self.all_reduce(input, registered=False) + + def close(self): + if not self.disabled and self._ptr: + dispose(self._ptr) + self._ptr = 0 + self.free_shared_buffer(self.group, self.meta_ptrs, rank=self.rank) + self.free_shared_buffer(self.group, self.buffer_ptrs, rank=self.rank) + + +def _cleanup_instances(): + for instance in _instances: + instance.close() + + +atexit.register(_cleanup_instances) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 2611214cf..40f8888bd 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -147,6 +147,12 @@ class EngineArgs: """ Flag to enable prefix caching. """ + + enable_custom_all_reduce: bool = False + """ + Flag to enable the custom all-reduce kernel. + """ + engine_worker_queue_port: int = 8002 """ Port for worker queue communication. @@ -421,6 +427,10 @@ class EngineArgs: type=int, default=EngineArgs.tensor_parallel_size, help="Degree of tensor parallelism.") + parallel_group.add_argument("--enable-custom-all-reduce", + action='store_true', + default=EngineArgs.enable_custom_all_reduce, + help="Flag to enable custom all-reduce.") parallel_group.add_argument( "--max-num-seqs", type=int, @@ -733,6 +743,7 @@ class EngineArgs: tensor_parallel_size=self.tensor_parallel_size, enable_expert_parallel=self.enable_expert_parallel, data_parallel_size=self.data_parallel_size, + enable_custom_all_reduce=self.enable_custom_all_reduce ) def create_engine_config(self) -> Config: @@ -755,6 +766,9 @@ class EngineArgs: assert not (self.use_cudagraph and self.enable_prefix_caching), \ "Prefix caching cannot be used with CUDA graph" + assert not (self.tensor_parallel_size<=1 and self.enable_custom_all_reduce), \ + "enable_custom_all_reduce must be used with tensor_parallel_size>1" + return Config( model_name_or_path=self.model, model_config=model_cfg, @@ -784,4 +798,5 @@ class EngineArgs: max_capture_batch_size=self.max_capture_batch_size, guided_decoding_backend=self.guided_decoding_backend, disable_any_whitespace=self.guided_decoding_disable_any_whitespace, + enable_custom_all_reduce=self.enable_custom_all_reduce, ) diff --git a/fastdeploy/engine/config.py b/fastdeploy/engine/config.py index bac38bfb8..ca76fa9c5 100644 --- a/fastdeploy/engine/config.py +++ b/fastdeploy/engine/config.py @@ -447,6 +447,7 @@ class ParallelConfig: tensor_parallel_size: int = 1, data_parallel_size: int = 1, enable_expert_parallel: bool = False, + enable_custom_all_reduce: bool = False, ): """ Initialize the ParallelConfig class. @@ -462,6 +463,7 @@ class ParallelConfig: self.enable_expert_parallel = enable_expert_parallel self.expert_parallel_size = data_parallel_size self.local_data_parallel_id = 0 + self.enable_custom_all_reduce = enable_custom_all_reduce def print(self): """ @@ -587,6 +589,7 @@ class Config: max_capture_batch_size: int = 64, guided_decoding_backend: Optional[str] = None, disable_any_whitespace: bool = False, + enable_custom_all_reduce: bool = False, ): """ Initialize the Config class. 
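A minimal sketch of how the Python-side CustomAllreduce wrapper added above is expected to be driven (illustrative only; it assumes the distributed environment is already initialized and that tp_group is the existing model-parallel group on a single multi-GPU node):

    import paddle
    import paddle.distributed as dist
    from fastdeploy.distributed.custom_all_reduce import CustomAllreduce

    # tp_group is an assumed, already-created model-parallel group.
    comm = CustomAllreduce(tp_group, max_size=8192 * 1024)

    x = paddle.ones([1024, 1024], dtype="float16")
    if comm.should_custom_ar(x):
        # Eager path: x is first copied into the pre-registered IPC buffer,
        # then reduced by the custom kernel.
        out = comm.all_reduce(x, registered=False)
    else:
        # Large or unsupported inputs fall back to the regular NCCL all-reduce.
        dist.all_reduce(x, group=tp_group)

Under CUDA graph capture, the capture region would be wrapped with comm.capture() so that graph buffers are registered once capture finishes, and comm.custom_all_reduce(x) would be called inside the graph.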
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index e95d0d0b1..f5dc29540 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -1048,6 +1048,7 @@ class LLMEngine(object): self.cfg.enable_static_graph_inference, "use_cudagraph": self.cfg.use_cudagraph, "disable_any_whitespace": self.cfg.disable_any_whitespace, + "enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce, } for worker_flag, value in worker_append_flag.items(): if value: diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index ec359b04a..0386485fa 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -60,6 +60,9 @@ class GpuWorker(WorkerBase): gc.collect() paddle.device.cuda.empty_cache() + if self.parallel_config.enable_custom_all_reduce: + from fastdeploy.distributed.communication_op import use_custom_allreduce + use_custom_allreduce() else: raise RuntimeError( f"Not support device type: {self.device_config.device}") diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 400e7097c..0e933a607 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -507,6 +507,9 @@ def parse_args(): parser.add_argument("--enable_prefix_caching", action='store_true', help="enable prefix cache") + parser.add_argument("--enable-custom-all-reduce", + action='store_true', + help="enable custom all-reduce") parser.add_argument("--splitwise_role", type=str, default="mixed", @@ -659,6 +662,7 @@ def initialize_fd_config(config_or_args) -> FDConfig: parallel_config.enable_chunked_prefill = getattr(config_or_args, 'enable_chunked_prefill', False) parallel_config.max_num_batched_tokens = getattr(config_or_args, 'max_num_batched_tokens', 0) parallel_config.enable_prefix_caching = getattr(config_or_args, 'enable_prefix_caching', False) + parallel_config.enable_custom_all_reduce = getattr(config_or_args, 'enable_custom_all_reduce', False) parallel_config.use_ep = getattr(config_or_args, 'enable_expert_parallell', False) parallel_config.tensor_parallel_degree = getattr(config_or_args, 'tensor_parallel_size', 1) parallel_config.expert_parallel_degree = getattr(config_or_args, 'expert_parallel_size', 1)
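End to end, the new path is opted into through the engine arguments and validated against the tensor-parallel degree. A rough offline sketch (the model path is a placeholder, not taken from this patch):

    from fastdeploy.engine.args_utils import EngineArgs

    args = EngineArgs(
        model="/path/to/model",      # placeholder
        tensor_parallel_size=2,      # custom all-reduce requires tensor_parallel_size > 1
        enable_custom_all_reduce=True,
    )
    config = args.create_engine_config()

Once the flag reaches the GPU worker, use_custom_allreduce() in fastdeploy/distributed/communication_op.py builds the CustomAllreduce instance for the model-parallel group, and tensor_model_parallel_all_reduce() then prefers the custom kernel for inputs that pass should_custom_ar(), falling back to dist.all_reduce otherwise.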