[Feature] support custom all-reduce (#2758)

* [Feature] support custom all-reduce

* add vllm adapted
This commit is contained in:
zhink
2025-07-09 16:00:27 +08:00
committed by GitHub
parent be21ef5047
commit b89180f1cd
16 changed files with 1194 additions and 2 deletions

View File

@@ -493,6 +493,31 @@ paddle::Tensor FusedHadamardQuantFp8Func(
const float scale);
#endif
int64_t init_custom_all_reduce(const std::vector<int64_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce(int64_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
int64_t reg_buffer, int64_t reg_buffer_sz_bytes);
void dispose(int64_t _fa);
int64_t meta_size();
void register_buffer(int64_t _fa, const std::vector<int64_t>& fake_ipc_ptrs);
std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(int64_t _fa);
void register_graph_buffers(int64_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets);
std::tuple<int64_t, paddle::Tensor> allocate_shared_buffer_and_handle(
int64_t size);
int64_t open_mem_handle(paddle::Tensor& mem_handle);
void free_shared_buffer(int64_t buffer);
PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -785,4 +810,24 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func, m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function"); py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
#endif #endif
m.def("init_custom_all_reduce", &init_custom_all_reduce, "init all reduce class function");
m.def("all_reduce", &all_reduce, "all reduce function");
m.def("dispose", &dispose, "del function for python");
m.def("meta_size", &meta_size, "meta_size function for Signal struct");
m.def("register_buffer", &register_buffer, "register ipc buffer");
m.def("register_graph_buffers", &register_graph_buffers, "register_graph_buffers");
m.def("allocate_shared_buffer_and_handle", &allocate_shared_buffer_and_handle, "allocate_shared_buffer_and_handle");
m.def("free_shared_buffer", &free_shared_buffer, "free_shared_buffer");
m.def("open_mem_handle", &open_mem_handle, "open_mem_handle");
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "get_graph_buffer_ipc_meta");
}
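The bindings above expose only raw int64 "fake pointer" handles, so the expected call pattern is easiest to see from the Python side. Below is a minimal sketch of the eager-mode dispatch, assuming fa, buffer_ptrs, rank, and max_size were produced earlier by init_custom_all_reduce and register_buffer; the full wrapper appears later in this commit.

import paddle
from fastdeploy.model_executor.ops.gpu import all_reduce

def custom_all_reduce_once(fa, inp, buffer_ptrs, rank, max_size, registered=False):
    # Out-of-place: the reduced result is written into `out`.
    out = paddle.empty_like(inp)
    if registered:
        # inp.data() is already an IPC-registered pointer, no staging copy needed.
        all_reduce(fa, inp, out, 0, 0)
    else:
        # Stage inp through this rank's pre-registered scratch buffer.
        all_reduce(fa, inp, out, buffer_ptrs[rank], max_size)
    return out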

View File

@@ -0,0 +1,165 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/custom_all_reduce.cu
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "all_reduce.cuh"
// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_all_reduce(const std::vector<fptr_t>& fake_ipc_ptrs,
paddle::Tensor& rank_data, int64_t rank,
bool full_nvlink) {
int world_size = fake_ipc_ptrs.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");
paddle::Signal* ipc_ptrs[8];
for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<paddle::Signal*>(fake_ipc_ptrs[i]);
}
return (fptr_t) new paddle::CustomAllreduce(ipc_ptrs, rank_data.data(),
rank_data.numel(), rank, world_size,
full_nvlink);
}
/**
* Performs an out-of-place allreduce and stores result in out.
*
* If _reg_buffer is null, assumes inp.data() is already IPC-registered.
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void all_reduce(fptr_t _fa, paddle::Tensor& inp, paddle::Tensor& out,
fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto stream = inp.stream();
// NOTE: the staging copy below assumes a 2-byte element type (fp16/bf16);
// float32 inputs should be passed through an already-registered buffer (reg_buffer == 0).
auto input_size = inp.numel() * 2;
auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
if (reg_buffer) {
cudaMemcpyAsync(reg_buffer, inp.data(), input_size,
cudaMemcpyDeviceToDevice, stream);
} else {
reg_buffer = inp.data();
}
switch (out.dtype()) {
case phi::DataType::FLOAT32: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
reinterpret_cast<float*>(out.data()),
out.numel());
break;
}
case phi::DataType::FLOAT16: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
reinterpret_cast<half*>(out.data()), out.numel());
break;
}
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
case phi::DataType::BFLOAT16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
}
}
void dispose(fptr_t _fa) {
delete reinterpret_cast<paddle::CustomAllreduce*>(_fa);
}
int64_t meta_size() { return sizeof(paddle::Signal); }
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
void* ipc_ptrs[8];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
}
fa->register_buffer(ipc_ptrs);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
std::vector<int64_t> bytes(handle.begin(), handle.end());
return std::make_tuple(bytes, offsets);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
void register_graph_buffers(fptr_t _fa,
const std::vector<std::vector<int64_t>>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<paddle::CustomAllreduce*>(_fa);
std::vector<std::string> bytes;
bytes.reserve(handles.size());
for (int i = 0; i < handles.size(); i++) {
bytes.emplace_back(handles[i].begin(), handles[i].end());
}
fa->register_graph_buffers(bytes, offsets);
}
std::tuple<fptr_t, paddle::Tensor> allocate_shared_buffer_and_handle(
int64_t size) {
auto device_index = phi::backends::gpu::GetCurrentDeviceId();
void* buffer;
cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
auto stream = paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream();
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Allocate buffer
CUDACHECK(cudaMalloc((void**)&buffer, size));
CUDACHECK(cudaMemsetAsync(buffer, 0, size, stream));
CUDACHECK(cudaStreamSynchronize(stream));
CUDACHECK(cudaThreadExchangeStreamCaptureMode(&mode));
// Create IPC memhandle for the allocated buffer.
// Will use it in open_mem_handle.
auto handle =
paddle::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, paddle::DataType::UINT8, paddle::GPUPlace(device_index));
CUDACHECK(
cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data(), buffer));
return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}
fptr_t open_mem_handle(paddle::Tensor& mem_handle) {
void* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle(
(void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data()),
cudaIpcMemLazyEnablePeerAccess));
return reinterpret_cast<fptr_t>(ipc_ptr);
}
void free_shared_buffer(fptr_t buffer) {
CUDACHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}

View File

@@ -0,0 +1,526 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include <array>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
namespace paddle {
constexpr int kMaxBlocks = 36;
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
using FlagType = uint32_t;
struct Signal {
alignas(128) FlagType self_counter[kMaxBlocks][8];
// Two sets of peer counters are needed for two syncs. The reason is that
// it's possible for a peer GPU block to arrive at the second sync point while
// the current GPU block hasn't passed the first sync point. Thus, the peer GPU
// may write counter+1 while the current GPU is busy waiting for counter. We use
// an alternating counter array to avoid this possibility.
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
};
struct __align__(16) RankData {
const void* __restrict__ ptrs[8];
};
struct __align__(16) RankSignals {
Signal* signals[8];
};
// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
T data[sz];
using type = T;
static constexpr int size = sz;
};
// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
// the (P)acked type for load/store
using P = array_t<T, 16 / sizeof(T)>;
// the (A)ccumulator type for reduction
using A = array_t<float, 16 / sizeof(T)>;
};
#define DINLINE __device__ __forceinline__
// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }
template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
return __float2half(val);
}
// scalar add functions
// for some reason when compiling with Paddle, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800)
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
}
return a;
}
template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
if constexpr (std::is_same<T, float>::value) {
return val;
} else {
array_t<float, N> out;
#pragma unroll
for (int i = 0; i < N; i++) {
out.data[i] = upcast_s(val.data[i]);
}
return out;
}
}
template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
if constexpr (std::is_same<typename O::type, float>::value) {
return val;
} else {
O out;
#pragma unroll
for (int i = 0; i < O::size; i++) {
out.data[i] = downcast_s<typename O::type>(val.data[i]);
}
return out;
}
}
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
#else
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
#endif
}
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
FlagType flag;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
: "=r"(flag)
: "l"(flag_addr));
#else
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
: "=r"(flag)
: "l"(flag_addr));
#endif
return flag;
}
static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) {
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}
static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
FlagType flag;
asm volatile("ld.volatile.global.u32 %0, [%1];"
: "=r"(flag)
: "l"(flag_addr));
return flag;
}
// is_start: whether this is the very first synchronization barrier.
// need_fence: whether a memory fence is needed. If true, a release-acquire
// semantic is used to enforce memory access order before and after this
// barrier.
template <int ngpus, bool is_start, bool need_fence = false>
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
int rank) {
if constexpr (!is_start) __syncthreads();
static_assert(
!(is_start && need_fence)); // Start barrier shouldn't need fence.
if (threadIdx.x < ngpus) {
// Increment the counter. Technically we only need one counter, but we use
// multiple per block to eliminate the need to share the counter via smem.
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
// Write the expected counter value to peer and wait for correct value from
// peer.
auto peer_counter_ptr =
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
auto self_counter_ptr =
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val);
} else {
st_flag_volatile(peer_counter_ptr, val);
while (ld_flag_volatile(self_counter_ptr) != val);
}
}
if constexpr (is_start || need_fence) __syncthreads();
}
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
packed_assign_add(tmp, upcast(ptrs[i][idx]));
}
return downcast<P>(tmp);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
}
template <typename P>
DINLINE P* get_tmp_buf(Signal* sg) {
return (P*)(((Signal*)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
class CustomAllreduce {
public:
int rank_;
int world_size_;
bool full_nvlink_;
RankSignals sg_;
// Stores a map from a pointer to its peer pointers from all ranks.
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
// 2. Each rank obtains the IPC handles for each address used during cuda
// graph capture using get_graph_buffer_ipc_meta.
// 3. (In Python) all gather the IPC handles.
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
// the rank data array at corresponding positions.
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_;
/**
* Signals are an array of ipc-enabled buffers from all ranks.
* For each buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* is for storing the intermediate results required by some allreduce algos.
*
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
*/
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
int rank, int world_size, bool full_nvlink = true)
: rank_(rank),
world_size_(world_size),
full_nvlink_(full_nvlink),
self_sg_(signals[rank]),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
sg_.signals[i] = signals[i];
}
}
char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
*((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
return it->second;
}
std::pair<std::string, std::vector<int64_t>> get_graph_buffer_ipc_meta() {
auto num_buffers = graph_unreg_buffers_.size();
auto handle_sz = sizeof(cudaIpcMemHandle_t);
std::string handles(handle_sz * num_buffers, static_cast<char>(0));
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void* base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
(CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
}
return std::make_pair(handles, offsets);
}
void check_rank_data_capacity(size_t num = 1) {
if (d_rank_data_base_ + num > d_rank_data_end_)
throw std::runtime_error(
"Rank data buffer is overflowed by " +
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
/**
* Register already-shared IPC pointers.
*/
void register_buffer(void** ptrs) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
data.ptrs[i] = ptrs[i];
}
auto d_data = d_rank_data_base_++;
CUDACHECK(
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
buffers_[ptrs[rank_]] = d_data;
}
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
// rank 1 may get the same input address for the second allreduce, but rank 2
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(
const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char* handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
} else {
rd.ptrs[j] = self_ptr;
}
}
}
CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
sizeof(RankData) * num_buffers,
cudaMemcpyHostToDevice));
d_rank_data_base_ += num_buffers;
graph_unreg_buffers_.clear();
}
/**
* Performs allreduce, assuming input has already been registered.
*
* Block and grid default configs are the results of a careful grid search. Using
* 36 blocks gives the best or close to the best runtime on the devices I
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
* use a small number of SMs. Not quite sure of the underlying reason, but my
* guess is that too many SMs will cause contention on the NVLink bus.
*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error("max supported block limit is " +
std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));
RankData* ptrs;
cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " +
std::to_string(reinterpret_cast<uint64_t>(input)) +
" is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage); \
} else if (full_nvlink_) { \
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
(world_size_ <= 8 && bytes < 256 * 1024)) { \
KL(ngpus, cross_device_reduce_1stage); \
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
~CustomAllreduce() {
for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(cudaIpcCloseMemHandle(ptr));
}
}
};
} // namespace paddle
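The dispatch at the end of allreduce() above is compact but easy to misread. The following Python mirror of the REDUCE_CASE selection is illustrative only (the C++ macro remains the source of truth); the byte thresholds are copied from the kernel launch code.

def pick_kernel(world_size: int, nbytes: int, full_nvlink: bool) -> str:
    """Mirror of the 1-stage vs 2-stage selection in CustomAllreduce::allreduce."""
    if world_size == 2:
        return "cross_device_reduce_1stage"
    if full_nvlink:
        if (world_size <= 4 and nbytes < 512 * 1024) or \
           (world_size <= 8 and nbytes < 256 * 1024):
            return "cross_device_reduce_1stage"
        return "cross_device_reduce_2stage"
    # Without full NVLink (and world_size > 2) no custom kernel is launched;
    # the Python layer falls back to NCCL via should_custom_ar().
    return "nccl_fallback"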

View File

@@ -276,6 +276,7 @@ elif paddle.is_compiled_with_cuda():
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu", "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu", "gpu_ops/noaux_tc.cu",
"gpu_ops/custom_all_reduce/all_reduce.cu",
] ]
# pd_disaggregation # pd_disaggregation

View File

@@ -36,6 +36,7 @@ When using FastDeploy to deploy models (including offline inference and service
| ```enable_static_graph_inference``` | `bool` | Whether to use static graph inference mode, default: False |
| ```use_cudagraph``` | `bool` | Whether to use cuda graph, default: False |
| ```max_capture_batch_size``` | `int` | When cuda graph is enabled, maximum batch size of captured cuda graph, default: 64 |
| ```enable_custom_all_reduce``` | `bool` | Enable Custom all-reduce, default: False |
| ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] |
| ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |

View File

@@ -35,6 +35,7 @@
| ```enable_static_graph_inference``` | `bool` | Whether to use static graph inference mode, default: False |
| ```use_cudagraph``` | `bool` | Whether to use cuda graph, default: False |
| ```max_capture_batch_size``` | `int` | When cuda graph is enabled, maximum batch size of captured cuda graph, default: 64 |
| ```enable_custom_all_reduce``` | `bool` | Enable Custom all-reduce, default: False |
| ```splitwise_role``` | `str` | Whether to enable splitwise inference, default value: mixed, supported parameters: ["mixed", "decode", "prefill"] |
| ```innode_prefill_ports``` | `str` | Internal engine startup ports for prefill instances (only required for single-machine PD separation), default: None |
| ```guided_decoding_backend``` | `str` | Specify the guided decoding backend to use, supports `auto`, `xgrammar`, `off`, default: `off` |

View File

@@ -207,6 +207,8 @@ class ParallelConfig:
guided_decoding_backend: str = None
# disable any whitespace for guided decoding
disable_any_whitespace: bool = True
# enable the custom all-reduce kernel; fall back to NCCL (dist.all_reduce) when it cannot be used.
enable_custom_all_reduce: bool = False
@dataclass

View File

@@ -16,13 +16,28 @@
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from fastdeploy.distributed.parallel_state import get_tensor_model_parallel_world_size
_TP_AR = None
def use_custom_allreduce(custom_all_reduce_max_bytes: int = 8192 * 1024):
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
global _TP_AR
if get_tensor_model_parallel_world_size() > 1 and paddle.is_compiled_with_cuda():
from fastdeploy.distributed.custom_all_reduce import CustomAllreduce
_TP_AR = CustomAllreduce(model_parallel_group, custom_all_reduce_max_bytes)
try:
@paddle.jit.marker.unified
def tensor_model_parallel_all_reduce(input_: paddle.Tensor) -> paddle.Tensor:
"""All-reduce the input tensor across model parallel group."""
global _TP_AR
if _TP_AR is not None and _TP_AR.should_custom_ar(input_):
_TP_AR.all_reduce(input_, input_)
elif paddle.in_dynamic_mode():
hcg = fleet.get_hybrid_communicate_group()
mp_group = hcg.get_model_parallel_group()
dist.all_reduce(input_, group=mp_group)
else:
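A hedged sketch of how this hook is driven: the GPU worker (changed later in this commit) calls use_custom_allreduce() once after the hybrid communicate group exists, and model code keeps calling tensor_model_parallel_all_reduce unchanged. The tensor shape is illustrative only.

import paddle
from fastdeploy.distributed.communication_op import (
    tensor_model_parallel_all_reduce, use_custom_allreduce)

# One-time setup per worker, after fleet/model-parallel initialization:
use_custom_allreduce(custom_all_reduce_max_bytes=8192 * 1024)

# Hot path: reduces in place through the custom kernel when eligible,
# otherwise falls back to dist.all_reduce on the model-parallel group.
hidden = paddle.rand([16, 8192], dtype="float16")
tensor_model_parallel_all_reduce(hidden)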

View File

@@ -0,0 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .custom_all_reduce import CustomAllreduce
__all__ = ["CustomAllreduce"]

View File

@@ -0,0 +1,167 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""
import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int
class cudaIpcMemHandle_t(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]
@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]
def find_loaded_library(lib_name) -> Optional[str]:
"""
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
""" # noqa
found = False
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found = True
break
if not found:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = line.index("/")
path = line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(
lib_name
), f"Unexpected filename: {filename} for library {lib_name}"
return path
class CudaRTLibrary:
exported_functions = [
# cudaError_t cudaSetDevice ( int device )
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
# cudaError_t cudaDeviceSynchronize ( void )
Function("cudaDeviceSynchronize", cudaError_t, []),
# cudaError_t cudaDeviceReset ( void )
Function("cudaDeviceReset", cudaError_t, []),
# const char* cudaGetErrorString ( cudaError_t error )
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function("cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
# cudaError_t cudaFree ( void* devPtr )
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function("cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function("cudaMemcpy", cudaError_t, [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind]),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function("cudaIpcGetMemHandle", cudaError_t, [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function(
"cudaIpcOpenMemHandle", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint]
),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
def __init__(self, so_file: Optional[str] = None):
if so_file is None:
so_file = find_loaded_library("libcudart")
if so_file is None:
pass
# so_file = envs.VLLM_CUDART_SO_PATH # fallback to env var
assert so_file is not None, (
"libcudart is not loaded in the current process, " "try setting VLLM_CUDART_SO_PATH"
)
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
self.lib = CudaRTLibrary.path_to_library_cache[so_file]
if so_file not in CudaRTLibrary.path_to_dict_mapping:
_funcs = {}
for func in CudaRTLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
def CUDART_CHECK(self, result: cudaError_t) -> None:
if result != 0:
error_str = self.cudaGetErrorString(result)
raise RuntimeError(f"CUDART error: {error_str}")
def cudaGetErrorString(self, error: cudaError_t) -> str:
return self.funcs["cudaGetErrorString"](error).decode("utf-8")
def cudaSetDevice(self, device: int) -> None:
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
def cudaDeviceSynchronize(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
def cudaDeviceReset(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
def cudaMalloc(self, size: int) -> ctypes.c_void_p:
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
return devPtr
def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int) -> None:
cudaMemcpyDefault = 4
kind = cudaMemcpyDefault
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
handle = cudaIpcMemHandle_t()
self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr))
return handle
def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
cudaIpcMemLazyEnablePeerAccess = 1
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(
self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)
)
return devPtr
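A small sketch of how this wrapper is used by create_shared_buffer in the next file: allocate device memory through ctypes-level cudart, export an IPC handle, and let peer ranks map the same allocation. The device index is a placeholder.

from fastdeploy.distributed.custom_all_reduce import cuda_wrapper

lib = cuda_wrapper.CudaRTLibrary()
lib.cudaSetDevice(0)                       # placeholder device index
ptr = lib.cudaMalloc(8192 * 1024)          # ctypes.c_void_p to 8 MB of device memory
handle = lib.cudaIpcGetMemHandle(ptr)      # 128-byte cudaIpcMemHandle_t
# A peer rank that receives `handle` (create_shared_buffer exchanges it with
# dist.all_gather_object) maps the allocation into its own address space with:
#   peer_ptr = lib.cudaIpcOpenMemHandle(handle)
lib.cudaFree(ptr)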

View File

@@ -0,0 +1,226 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import atexit
import ctypes
from typing import List, Optional
import paddle
import paddle.distributed as dist
from paddle.distributed.communication.group import Group
from fastdeploy.model_executor.ops.gpu import (
all_reduce,
dispose,
init_custom_all_reduce,
meta_size,
register_buffer,
get_graph_buffer_ipc_meta,
register_graph_buffers,
)
from fastdeploy.distributed.custom_all_reduce import cuda_wrapper
try:
meta_size()
custom_ar = True
except Exception:
custom_ar = False
_instances = []
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
# max_size: max supported allreduce size
def __init__(self, group: Group, max_size: int=8192 * 1024) -> None:
"""
Args:
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
self.group = group
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
rank = dist.get_rank(group=self.group)
self.rank = rank
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
return
if world_size < 2:
return
self.disabled = False
# Buffer memory is owned by this Python class and passed to C++.
# Metadata consists of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(group, meta_size() + max_size)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(group, max_size)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs fewer than 10000 registered tuples.
self.rank_data = paddle.empty([8 * 1024 * 1024], dtype=paddle.uint8)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.full_nvlink = True
self._ptr = init_custom_all_reduce(self.meta_ptrs, self.rank_data, rank, self.full_nvlink)
register_buffer(self._ptr, self.buffer_ptrs)
print("zss init custom allreduce", self._ptr)
_instances.append(self)
@staticmethod
def create_shared_buffer(group: Group, size_in_bytes: int) -> List[int]:
"""
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
"""
lib = cuda_wrapper.CudaRTLibrary()
pointer = lib.cudaMalloc(size_in_bytes)
# lib.cudaMemset(pointer, 2, size_in_bytes)
handle = lib.cudaIpcGetMemHandle(pointer)
rank = dist.get_rank(group=group)
handles = []
dist.all_gather_object(handles, handle, group=group)
pointers: List[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer.value) # type: ignore
else:
pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore
return pointers
@staticmethod
def free_shared_buffer(group: Group, pointers: List[int], rank: Optional[int] = None) -> None:
if rank is None:
rank = dist.get_rank(group=group)
lib = cuda_wrapper.CudaRTLibrary()
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
def should_custom_ar(self, inp: paddle.Tensor):
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
def all_reduce(self, inp: paddle.Tensor, out: paddle.Tensor = None, registered: bool = False):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if out is None:
out = paddle.empty_like(inp)
if registered:
all_reduce(self._ptr, inp, out, 0, 0)
else:
all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size)
return out
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = get_graph_buffer_ipc_meta(self._ptr)
all_data = [[None, None]
for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
register_graph_buffers(self._ptr, handles, offsets)
def custom_all_reduce(self, input: paddle.Tensor) -> Optional[paddle.Tensor]:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if paddle.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=True)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return paddle.empty_like(input)
else:
return self.all_reduce(input, registered=False)
def close(self):
if not self.disabled and self._ptr:
dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.group, self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.group, self.buffer_ptrs, rank=self.rank)
def _cleanup_instances():
for instance in _instances:
instance.close()
atexit.register(_cleanup_instances)
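An end-to-end sketch of the class above, assuming the script is launched with one process per GPU and the model-parallel group is already initialized (as the worker below guarantees):

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from fastdeploy.distributed.custom_all_reduce import CustomAllreduce

group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
ar = CustomAllreduce(group, max_size=8192 * 1024)

x = paddle.rand([16, 8192], dtype="float16")
out = ar.custom_all_reduce(x)        # None when the input is not eligible
if out is None:
    dist.all_reduce(x, group=group)  # NCCL fallback, reduces in place

# Around CUDA graph capture: buffers used inside the graph are recorded and
# exchanged via register_graph_buffers() when the context exits.
with ar.capture():
    out = ar.custom_all_reduce(x)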

View File

@@ -147,6 +147,12 @@ class EngineArgs:
""" """
Flag to enable prefix caching. Flag to enable prefix caching.
""" """
enable_custom_all_reduce: bool = False
"""
Flag to enable the custom all-reduce kernel.
"""
engine_worker_queue_port: int = 8002 engine_worker_queue_port: int = 8002
""" """
Port for worker queue communication. Port for worker queue communication.
@@ -421,6 +427,10 @@ class EngineArgs:
type=int, type=int,
default=EngineArgs.tensor_parallel_size, default=EngineArgs.tensor_parallel_size,
help="Degree of tensor parallelism.") help="Degree of tensor parallelism.")
parallel_group.add_argument("--enable-custom-all-reduce",
action='store_true',
default=EngineArgs.enable_custom_all_reduce,
help="Flag to enable custom all-reduce.")
parallel_group.add_argument(
"--max-num-seqs",
type=int,
@@ -733,6 +743,7 @@ class EngineArgs:
tensor_parallel_size=self.tensor_parallel_size,
enable_expert_parallel=self.enable_expert_parallel,
data_parallel_size=self.data_parallel_size,
enable_custom_all_reduce=self.enable_custom_all_reduce
)
def create_engine_config(self) -> Config:
@@ -755,6 +766,9 @@ class EngineArgs:
assert not (self.use_cudagraph and self.enable_prefix_caching), \
"Prefix caching cannot be used with CUDA graph"
assert not (self.tensor_parallel_size <= 1 and self.enable_custom_all_reduce), \
"enable_custom_all_reduce must be used with tensor_parallel_size > 1"
return Config(
model_name_or_path=self.model,
model_config=model_cfg,
@@ -784,4 +798,5 @@ class EngineArgs:
max_capture_batch_size=self.max_capture_batch_size,
guided_decoding_backend=self.guided_decoding_backend,
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
enable_custom_all_reduce=self.enable_custom_all_reduce,
)
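For completeness, a hedged sketch of enabling the flag from Python; the import path is an assumption about where EngineArgs lives in FastDeploy, and the model path is a placeholder. On the command line the same switch is --enable-custom-all-reduce, as registered above.

from fastdeploy.engine.args_utils import EngineArgs  # import path assumed

args = EngineArgs(
    model="path/to/model",            # placeholder
    tensor_parallel_size=2,           # custom all-reduce requires TP > 1
    enable_custom_all_reduce=True,
)
config = args.create_engine_config()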

View File

@@ -447,6 +447,7 @@ class ParallelConfig:
tensor_parallel_size: int = 1,
data_parallel_size: int = 1,
enable_expert_parallel: bool = False,
enable_custom_all_reduce: bool = False,
):
"""
Initialize the ParallelConfig class.
@@ -462,6 +463,7 @@ class ParallelConfig:
self.enable_expert_parallel = enable_expert_parallel
self.expert_parallel_size = data_parallel_size
self.local_data_parallel_id = 0
self.enable_custom_all_reduce = enable_custom_all_reduce
def print(self):
"""
@@ -587,6 +589,7 @@ class Config:
max_capture_batch_size: int = 64,
guided_decoding_backend: Optional[str] = None,
disable_any_whitespace: bool = False,
enable_custom_all_reduce: bool = False,
):
"""
Initialize the Config class.

View File

@@ -1048,6 +1048,7 @@ class LLMEngine(object):
self.cfg.enable_static_graph_inference,
"use_cudagraph": self.cfg.use_cudagraph,
"disable_any_whitespace": self.cfg.disable_any_whitespace,
"enable-custom-all-reduce": self.cfg.parallel_config.enable_custom_all_reduce,
}
for worker_flag, value in worker_append_flag.items():
if value:

View File

@@ -60,6 +60,9 @@ class GpuWorker(WorkerBase):
gc.collect()
paddle.device.cuda.empty_cache()
if self.parallel_config.enable_custom_all_reduce:
from fastdeploy.distributed.communication_op import use_custom_allreduce
use_custom_allreduce()
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")

View File

@@ -507,6 +507,9 @@ def parse_args():
parser.add_argument("--enable_prefix_caching", parser.add_argument("--enable_prefix_caching",
action='store_true', action='store_true',
help="enable prefix cache") help="enable prefix cache")
parser.add_argument("--enable-custom-all-reduce",
action='store_true',
help="enable custom all-reduce")
parser.add_argument("--splitwise_role", parser.add_argument("--splitwise_role",
type=str, type=str,
default="mixed", default="mixed",
@@ -659,6 +662,7 @@ def initialize_fd_config(config_or_args) -> FDConfig:
parallel_config.enable_chunked_prefill = getattr(config_or_args, 'enable_chunked_prefill', False) parallel_config.enable_chunked_prefill = getattr(config_or_args, 'enable_chunked_prefill', False)
parallel_config.max_num_batched_tokens = getattr(config_or_args, 'max_num_batched_tokens', 0) parallel_config.max_num_batched_tokens = getattr(config_or_args, 'max_num_batched_tokens', 0)
parallel_config.enable_prefix_caching = getattr(config_or_args, 'enable_prefix_caching', False) parallel_config.enable_prefix_caching = getattr(config_or_args, 'enable_prefix_caching', False)
parallel_config.enable_custom_all_reduce = getattr(config_or_args, 'enable_custom_all_reduce', False)
parallel_config.use_ep = getattr(config_or_args, 'enable_expert_parallell', False) parallel_config.use_ep = getattr(config_or_args, 'enable_expert_parallell', False)
parallel_config.tensor_parallel_degree = getattr(config_or_args, 'tensor_parallel_size', 1) parallel_config.tensor_parallel_degree = getattr(config_or_args, 'tensor_parallel_size', 1)
parallel_config.expert_parallel_degree = getattr(config_or_args, 'expert_parallel_size', 1) parallel_config.expert_parallel_degree = getattr(config_or_args, 'expert_parallel_size', 1)