Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[BugFix] Fix custom_all_reduce overflow (#5662)
* check
* check
* code style
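What the hunks below change functionally: CustomAllreduce gains a d_rank_data_base_origin_ member that records the current write cursor into the pre-allocated rank-data region (set at the end of the constructor and again after a regular buffer registration), and clear_ipc_handles() now rewinds d_rank_data_base_ to that recorded position after closing and clearing the IPC handles. Graph-buffer registration keeps bumping the cursor (d_rank_data_base_ += num_buffers;), so without the rewind, repeated register/clear cycles keep advancing it toward d_rank_data_end_, which is presumably the overflow named in the title. Every other hunk is a formatting change (one argument per line, spin-wait bodies on their own lines, macro reflow) with no behavioral effect.

The following is a minimal, self-contained sketch of this bump-pointer-plus-rewind pattern, not FastDeploy's code: RankEntry, Arena, register_buffers and clear are hypothetical names, and the sketch keeps a single rewind point at the start of the region instead of the per-registration bookkeeping the real class does.

#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the per-buffer RankData slots.
struct RankEntry {
  void* ptrs[8];
};

// Minimal sketch of a bump-pointer region with a rewind point, mirroring the
// base/origin/end bookkeeping the diff adds (names are illustrative only).
class Arena {
 public:
  explicit Arena(size_t capacity)
      : storage_(capacity),
        base_(storage_.data()),
        end_(storage_.data() + capacity),
        origin_(storage_.data()) {}  // remember where reclaimable entries start

  // Each registration consumes n slots, like `d_rank_data_base_ += num_buffers;`.
  RankEntry* register_buffers(size_t n) {
    if (base_ + n > end_) throw std::runtime_error("rank data overflow");
    RankEntry* slots = base_;
    base_ += n;
    return slots;
  }

  // Without this rewind, repeated register/clear cycles walk base_ past end_,
  // which is the kind of overflow this commit guards against.
  void clear() { base_ = origin_; }

 private:
  std::vector<RankEntry> storage_;
  RankEntry* base_;
  RankEntry* end_;
  RankEntry* origin_;
};

int main() {
  Arena arena(16);
  for (int capture = 0; capture < 1000; ++capture) {
    arena.register_buffers(8);  // e.g. one graph capture's worth of buffers
    arena.clear();              // rewind instead of permanently consuming slots
  }
  return 0;
}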
@@ -18,21 +18,23 @@
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
 
-#include <iostream>
 #include <array>
+#include <iostream>
 #include <limits>
 #include <map>
 #include <unordered_map>
 #include <vector>
 
-#define CUDACHECK(cmd)                                              \
-  do {                                                              \
-    cudaError_t e = cmd;                                            \
-    if (e != cudaSuccess) {                                         \
-      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
-             cudaGetErrorString(e));                                \
-      exit(EXIT_FAILURE);                                           \
-    }                                                               \
+#define CUDACHECK(cmd)                          \
+  do {                                          \
+    cudaError_t e = cmd;                        \
+    if (e != cudaSuccess) {                     \
+      printf("Failed: Cuda error %s:%d '%s'\n", \
+             __FILE__,                          \
+             __LINE__,                          \
+             cudaGetErrorString(e));            \
+      exit(EXIT_FAILURE);                       \
+    }                                           \
   } while (0)
 
 namespace paddle {
@@ -188,7 +190,8 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
 // semantic is used to enforce memory access order before and after this
 // barrier.
 template <int ngpus, bool is_start, bool need_fence = false>
-DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
+DINLINE void multi_gpu_barrier(const RankSignals& sg,
+                               Signal* self_sg,
                                int rank) {
   if constexpr (!is_start) __syncthreads();
   static_assert(
@@ -205,10 +208,12 @@ DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
         &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
     if constexpr (need_fence) {
       st_flag_release(peer_counter_ptr, val);
-      while (ld_flag_acquire(self_counter_ptr) != val);
+      while (ld_flag_acquire(self_counter_ptr) != val)
+        ;
     } else {
       st_flag_volatile(peer_counter_ptr, val);
-      while (ld_flag_volatile(self_counter_ptr) != val);
+      while (ld_flag_volatile(self_counter_ptr) != val)
+        ;
     }
   }
   if constexpr (is_start || need_fence) __syncthreads();
@@ -226,8 +231,12 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+    cross_device_reduce_1stage(RankData* _dp,
+                               RankSignals sg,
+                               Signal* self_sg,
+                               T* __restrict__ result,
+                               int rank,
+                               int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
@@ -249,8 +258,12 @@ DINLINE P* get_tmp_buf(Signal* sg) {
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
-    cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+    cross_device_reduce_2stage(RankData* _dp,
+                               RankSignals sg,
+                               Signal* self_sg,
+                               T* __restrict__ result,
+                               int rank,
+                               int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
@@ -323,7 +336,7 @@ class CustomAllreduce {
   // 3. (In Python) all gather the IPC handles.
   // 4. Obtain the peer pointers by opening the IPC handles, and store them in
   //    the rank data array at corresponding positions.
-  RankData *d_rank_data_base_, *d_rank_data_end_;
+  RankData *d_rank_data_base_origin_, *d_rank_data_base_, *d_rank_data_end_;
   std::vector<void*> graph_unreg_buffers_;
   // a map from IPC handles to opened IPC pointers
   std::map<IPC_KEY, char*> ipc_handles_;
@@ -338,8 +351,12 @@ class CustomAllreduce {
   * Note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor.
   */
-  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
-                  int rank, int world_size, bool full_nvlink = true)
+  CustomAllreduce(Signal** signals,
+                  void* rank_data,
+                  size_t rank_data_sz,
+                  int rank,
+                  int world_size,
+                  bool full_nvlink = true)
       : rank_(rank),
         world_size_(world_size),
         full_nvlink_(full_nvlink),
@@ -349,6 +366,7 @@ class CustomAllreduce {
     for (int i = 0; i < world_size_; i++) {
       sg_.signals[i] = signals[i];
    }
+    d_rank_data_base_origin_ = d_rank_data_base_;
  }
 
   char* open_ipc_handle(const void* ipc_handle) {
@@ -405,6 +423,7 @@ class CustomAllreduce {
     CUDACHECK(
         cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
     buffers_[ptrs[rank_]] = d_data;
+    d_rank_data_base_origin_ = d_rank_data_base_;
   }
 
   // Note: when registering graph buffers, we intentionally choose to not
@@ -434,7 +453,8 @@ class CustomAllreduce {
         }
       }
     }
-    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
+    CUDACHECK(cudaMemcpy(d_rank_data_base_,
+                         rank_data.data(),
                          sizeof(RankData) * num_buffers,
                          cudaMemcpyHostToDevice));
     d_rank_data_base_ += num_buffers;
@@ -451,8 +471,12 @@ class CustomAllreduce {
   * guess is that too many SMs will cause contention on NVLink bus.
   */
   template <typename T>
-  void allreduce(cudaStream_t stream, T* input, T* output, int size,
-                 int threads = 512, int block_limit = 36) {
+  void allreduce(cudaStream_t stream,
+                 T* input,
+                 T* output,
+                 int size,
+                 int threads = 512,
+                 int block_limit = 36) {
     auto d = packed_t<T>::P::size;
     if (size % d != 0)
       throw std::runtime_error(
@@ -483,9 +507,9 @@ class CustomAllreduce {
     size /= d;
     auto bytes = size * sizeof(typename packed_t<T>::P);
     int blocks = std::min(block_limit, (size + threads - 1) / threads);
-#define KL(ngpus, name)                                                       \
-  name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
-                                                 rank_, size);
+#define KL(ngpus, name)                           \
+  name<T, ngpus><<<blocks, threads, 0, stream>>>( \
+      ptrs, sg_, self_sg_, output, rank_, size);
 
 #define REDUCE_CASE(ngpus) \
   case ngpus: {            \
@@ -517,15 +541,15 @@ class CustomAllreduce {
 #undef KL
   }
 
-  void clear_ipc_handles(){
+  void clear_ipc_handles() {
     for (auto [_, ptr] : ipc_handles_) {
       CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
 
     ipc_handles_.clear();
+    d_rank_data_base_ = d_rank_data_base_origin_;
   }
 
-  ~CustomAllreduce() {
-    clear_ipc_handles();
-  }
+  ~CustomAllreduce() { clear_ipc_handles(); }
 };
 }  // namespace paddle
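A side note on the barrier code the second and third hunks touch: the spin loops poll flag words through two pairs of helpers whose bodies are outside this diff. st_flag_release / ld_flag_acquire are used when need_fence is true, so the flag exchange also orders the surrounding data accesses; st_flag_volatile / ld_flag_volatile are plain uncached polling for the cases where no ordering is needed. The sketch below shows one common way such helpers are written for sm_70 or newer; it is an illustration under stated assumptions, not the actual FastDeploy definitions. Assumptions: FlagType is uint32_t (as the hunk context and .u32 accesses suggest), DINLINE expands to __device__ __forceinline__, and a "memory" clobber is added here to keep the compiler from reordering around the flag accesses.

#include <cstdint>

using FlagType = uint32_t;

// Release store: prior writes become visible system-wide before the flag lands.
static __device__ __forceinline__ void st_flag_release(FlagType* flag_addr,
                                                        FlagType flag) {
  asm volatile("st.release.sys.global.u32 [%1], %0;"
               :
               : "r"(flag), "l"(flag_addr)
               : "memory");
}

// Acquire load: later reads cannot be hoisted above the flag observation.
static __device__ __forceinline__ FlagType ld_flag_acquire(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr)
               : "memory");
  return flag;
}

// Volatile store/load: bypass caches so peers see fresh values, no ordering.
static __device__ __forceinline__ void st_flag_volatile(FlagType* flag_addr,
                                                         FlagType flag) {
  asm volatile("st.volatile.global.u32 [%1], %0;"
               :
               : "r"(flag), "l"(flag_addr)
               : "memory");
}

static __device__ __forceinline__ FlagType ld_flag_volatile(FlagType* flag_addr) {
  FlagType flag;
  asm volatile("ld.volatile.global.u32 %0, [%1];"
               : "=r"(flag)
               : "l"(flag_addr)
               : "memory");
  return flag;
}

The .sys scope is the relevant detail: the flags live in peer GPUs' memory mapped through CUDA IPC, so the ordering has to hold across devices, not just within a single GPU.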