From abf53b17ea5e647e3e2ac6ced68218ded2242b8b Mon Sep 17 00:00:00 2001
From: chen <103103266+ckl117@users.noreply.github.com>
Date: Fri, 19 Dec 2025 20:04:39 +0800
Subject: [PATCH] [BugFix] Fix custom_all_reduce overflow (#5662) (#5667)

* check

* check

* code style
---
 .../gpu_ops/custom_all_reduce/all_reduce.cuh  | 82 ++++++++++++-------
 1 file changed, 53 insertions(+), 29 deletions(-)

diff --git a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh
index fea3d63fe..b17ece590 100644
--- a/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh
+++ b/custom_ops/gpu_ops/custom_all_reduce/all_reduce.cuh
@@ -18,21 +18,23 @@
 #include 
 #include 
-#include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 
-#define CUDACHECK(cmd) \
-  do { \
-    cudaError_t e = cmd; \
-    if (e != cudaSuccess) { \
-      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
-             cudaGetErrorString(e)); \
-      exit(EXIT_FAILURE); \
-    } \
+#define CUDACHECK(cmd)                                  \
+  do {                                                  \
+    cudaError_t e = cmd;                                \
+    if (e != cudaSuccess) {                             \
+      printf("Failed: Cuda error %s:%d '%s'\n",         \
+             __FILE__,                                  \
+             __LINE__,                                  \
+             cudaGetErrorString(e));                    \
+      exit(EXIT_FAILURE);                               \
+    }                                                   \
   } while (0)
 
 namespace paddle {
 
@@ -188,7 +190,8 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
 // semantic is used to enforce memory access order before and after this
 // barrier.
 template <int ngpus, bool is_start, bool need_fence = false>
-DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
+DINLINE void multi_gpu_barrier(const RankSignals& sg,
+                               Signal* self_sg,
                                int rank) {
   if constexpr (!is_start) __syncthreads();
   static_assert(
@@ -205,10 +208,12 @@ DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
           &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
     if constexpr (need_fence) {
       st_flag_release(peer_counter_ptr, val);
-      while (ld_flag_acquire(self_counter_ptr) != val);
+      while (ld_flag_acquire(self_counter_ptr) != val)
+        ;
     } else {
       st_flag_volatile(peer_counter_ptr, val);
-      while (ld_flag_volatile(self_counter_ptr) != val);
+      while (ld_flag_volatile(self_counter_ptr) != val)
+        ;
     }
   }
   if constexpr (is_start || need_fence) __syncthreads();
 }
@@ -226,8 +231,12 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+    cross_device_reduce_1stage(RankData* _dp,
+                               RankSignals sg,
+                               Signal* self_sg,
+                               T* __restrict__ result,
+                               int rank,
+                               int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
@@ -249,8 +258,12 @@ DINLINE P* get_tmp_buf(Signal* sg) {
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+    cross_device_reduce_2stage(RankData* _dp,
+                               RankSignals sg,
+                               Signal* self_sg,
+                               T* __restrict__ result,
+                               int rank,
+                               int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
@@ -323,7 +336,7 @@ class CustomAllreduce {
   // 3. (In Python) all gather the IPC handles.
   // 4. Obtain the peer pointers by opening the IPC handles, and store them in
   //    the rank data array at corresponding positions.
-  RankData *d_rank_data_base_, *d_rank_data_end_;
+  RankData *d_rank_data_base_origin_, *d_rank_data_base_, *d_rank_data_end_;
   std::vector graph_unreg_buffers_;
   // a map from IPC handles to opened IPC pointers
   std::map ipc_handles_;
@@ -338,8 +351,12 @@ class CustomAllreduce {
    * Note: this class does not own any device memory. Any required buffers
    * are passed in from the constructor.
    */
-  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
-                  int rank, int world_size, bool full_nvlink = true)
+  CustomAllreduce(Signal** signals,
+                  void* rank_data,
+                  size_t rank_data_sz,
+                  int rank,
+                  int world_size,
+                  bool full_nvlink = true)
       : rank_(rank),
         world_size_(world_size),
         full_nvlink_(full_nvlink),
@@ -349,6 +366,7 @@ class CustomAllreduce {
     for (int i = 0; i < world_size_; i++) {
       sg_.signals[i] = signals[i];
     }
+    d_rank_data_base_origin_ = d_rank_data_base_;
   }
 
   char* open_ipc_handle(const void* ipc_handle) {
@@ -405,6 +423,7 @@ class CustomAllreduce {
     CUDACHECK(
         cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
     buffers_[ptrs[rank_]] = d_data;
+    d_rank_data_base_origin_ = d_rank_data_base_;
   }
 
   // Note: when registering graph buffers, we intentionally choose to not
@@ -434,7 +453,8 @@ class CustomAllreduce {
         }
       }
     }
-    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
+    CUDACHECK(cudaMemcpy(d_rank_data_base_,
+                         rank_data.data(),
                          sizeof(RankData) * num_buffers,
                          cudaMemcpyHostToDevice));
     d_rank_data_base_ += num_buffers;
@@ -451,8 +471,12 @@ class CustomAllreduce {
    * guess is that too many SMs will cause contention on NVLink bus.
    */
   template <typename T>
-  void allreduce(cudaStream_t stream, T* input, T* output, int size,
-                 int threads = 512, int block_limit = 36) {
+  void allreduce(cudaStream_t stream,
+                 T* input,
+                 T* output,
+                 int size,
+                 int threads = 512,
+                 int block_limit = 36) {
     auto d = packed_t<T>::P::size;
     if (size % d != 0)
       throw std::runtime_error(
@@ -483,9 +507,9 @@ class CustomAllreduce {
     size /= d;
     auto bytes = size * sizeof(typename packed_t<T>::P);
     int blocks = std::min(block_limit, (size + threads - 1) / threads);
-#define KL(ngpus, name) \
-  name<<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
-                                       rank_, size);
+#define KL(ngpus, name)                                              \
+  name<<<blocks, threads, 0, stream>>>(                              \
+      ptrs, sg_, self_sg_, output, rank_, size);
 
 #define REDUCE_CASE(ngpus)                                           \
   case ngpus: {                                                      \
@@ -517,15 +541,15 @@ class CustomAllreduce {
 #undef KL
   }
 
-  void clear_ipc_handles(){
+  void clear_ipc_handles() {
     for (auto [_, ptr] : ipc_handles_) {
       CUDACHECK(cudaIpcCloseMemHandle(ptr));
     }
+    ipc_handles_.clear();
+    d_rank_data_base_ = d_rank_data_base_origin_;
   }
 
-  ~CustomAllreduce() {
-    clear_ipc_handles();
-  }
+  ~CustomAllreduce() { clear_ipc_handles(); }
 };
 }  // namespace paddle
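
Reviewer note (not part of the patch): the overflow comes from treating the preallocated rank_data region as a bump allocator. Each registration advances d_rank_data_base_ toward d_rank_data_end_, but clear_ipc_handles() previously only closed the IPC handles and never rewound the pointer, so repeated register/clear cycles eventually walked past the end of the region. The patch records a rewind point in d_rank_data_base_origin_ (set in the constructor and advanced after explicit register_buffer calls), resets d_rank_data_base_ to it in clear_ipc_handles(), and empties ipc_handles_ so stale handles are not closed twice. Below is a minimal, hypothetical sketch of that bump-allocate-and-rewind pattern; RankDataArena, register_one(), clear(), and the fixed ptrs[8] size are invented for illustration, and only the three d_rank_data_* member names come from the patch.

// Hypothetical sketch: bump allocation over a fixed RankData region with a
// recorded rewind point, mirroring the pointer bookkeeping the fix adds.
#include <cstddef>
#include <stdexcept>

struct RankData {
  void* ptrs[8];  // peer pointers, one per rank (size chosen arbitrarily here)
};

class RankDataArena {
 public:
  RankDataArena(RankData* buf, size_t n)
      : d_rank_data_base_(buf), d_rank_data_end_(buf + n) {
    d_rank_data_base_origin_ = d_rank_data_base_;  // remember where we started
  }

  // Each registration consumes one slot and advances the base pointer,
  // similar to the pointer increment performed when a buffer is registered.
  RankData* register_one(const RankData& data) {
    if (d_rank_data_base_ >= d_rank_data_end_) {
      throw std::runtime_error("rank data buffer overflow");
    }
    RankData* slot = d_rank_data_base_++;
    *slot = data;  // the real code copies the entry to device memory instead
    return slot;
  }

  // Before the fix, clearing left the base pointer where it was, so every
  // register/clear cycle leaked slots; rewinding makes the region reusable.
  void clear() { d_rank_data_base_ = d_rank_data_base_origin_; }

 private:
  RankData *d_rank_data_base_origin_, *d_rank_data_base_, *d_rank_data_end_;
};

In this simplified model, a driver that repeatedly registers buffers (for example around CUDA graph re-captures) and then tears them down would call register_one() during setup and clear() on teardown; without the rewind, each cycle permanently consumes slots until the capacity check fires.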