[Metax] modify warpSize to WARP_SIZE (#5442)

This commit is contained in:
xiaozude
2025-12-09 17:44:02 +08:00
committed by GitHub
parent e397c4fba6
commit df67379bc3
4 changed files with 406 additions and 228 deletions

View File

@@ -30,8 +30,8 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cuda/atomic>
#include <curand_kernel.h>
#include <cuda/atomic>
#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
@@ -40,20 +40,21 @@
// Fails via PD_CHECK when `x.is_gpu()` is false; the message starts with the
// stringified argument expression.
#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
// Compile-time warp width used in this file (e.g. as the default shuffle
// width of CudaShuffleDownSync).
#define WARP_SIZE 32
// Mask with all 32 bits set, passed to warp intrinsics such as __ballot_sync.
#define FINAL_MASK 0xFFFFFFFF
// Emits a single `case (dim): { ... } break` arm for a switch over a runtime
// block size. Inside the arm, `kBlockDim` is a constexpr equal to `dim`, so
// the statement passed through the variadic arguments can use it as a
// template argument or in a launch configuration.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

// Expands to one FIXED_BLOCK_DIM_BASE arm for each of the block sizes
// 1024, 512, 256, 128, 64 and 32. Intended to be placed inside a `switch`;
// the caller supplies the trailing `;` and any `default:` branch.
#define FIXED_BLOCK_DIM(...)                 \
  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template <typename T, typename IdxT = int, typename AccT = T>
@@ -123,7 +124,8 @@ __device__ T twiddleOut(typename cub::Traits<T>::UnsignedBits bits,
return reinterpret_cast<T &>(bits);
}
// Returns 2^BitsPerPass, i.e. the number of distinct values (buckets) that
// BitsPerPass bits can address. Usable in constant expressions on both host
// and device.
template <int BitsPerPass>
__host__ __device__ constexpr int calcNumBuckets() {
  return 1 << BitsPerPass;
}
@@ -161,12 +163,12 @@ __device__ void scan(IdxT volatile *histogram, IdxT *histogramOut) {
if constexpr (numBuckets >= BlockSize) {
static_assert(numBuckets % BlockSize == 0);
int constexpr itemsPerThread = numBuckets / BlockSize;
typedef cub::BlockLoad<IdxT, BlockSize, itemsPerThread,
cub::BLOCK_LOAD_TRANSPOSE>
BlockLoad;
typedef cub::BlockStore<IdxT, BlockSize, itemsPerThread,
cub::BLOCK_STORE_TRANSPOSE>
BlockStore;
typedef cub::
BlockLoad<IdxT, BlockSize, itemsPerThread, cub::BLOCK_LOAD_TRANSPOSE>
BlockLoad;
typedef cub::
BlockStore<IdxT, BlockSize, itemsPerThread, cub::BLOCK_STORE_TRANSPOSE>
BlockStore;
typedef cub::BlockScan<IdxT, BlockSize> BlockScan;
__shared__ union {
@@ -203,12 +205,19 @@ __device__ void scan(IdxT volatile *histogram, IdxT *histogramOut) {
}
template <typename T, int BitsPerPass, int NumBuckets, int Pass>
__device__ __forceinline__ void
filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer,
int *out_idx_buffer, T *out_scores, int64_t *out_ids,
int previous_len, Counter<T> *counter, T *histogram,
int *count_histogram, T *histogram_shm,
int *count_histogram_shm, const bool early_stop) {
__device__ __forceinline__ void filterAndHistogram(const T *in_buffer,
const int *in_idx_buffer,
T *out_buffer,
int *out_idx_buffer,
T *out_scores,
int64_t *out_ids,
int previous_len,
Counter<T> *counter,
T *histogram,
int *count_histogram,
T *histogram_shm,
int *count_histogram_shm,
const bool early_stop) {
// scan and filter
constexpr int start_bit = calcStartBit<T, BitsPerPass, Pass>();
const uint32_t mask = calcMask<T, BitsPerPass, Pass>();
@@ -220,7 +229,8 @@ filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer,
T array[VecSize];
} vec;
for (int i = (blockIdx.x * blockDim.x + threadIdx.x);
i < ceilDiv(previous_len, VecSize); i += blockDim.x * gridDim.x) {
i < ceilDiv(previous_len, VecSize);
i += blockDim.x * gridDim.x) {
vec.v = reinterpret_cast<const VecT *>(in_buffer)[i];
if constexpr (Pass == 0) {
#pragma unroll
@@ -254,8 +264,8 @@ filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer,
out_buffer[pos] = vec.array[j];
out_idx_buffer[pos] = in_idx_buffer ? in_idx_buffer[idx] : idx;
}
int bucket = calcBucket<T, BitsPerPass>(vec.array[j], start_bit,
mask, false);
int bucket = calcBucket<T, BitsPerPass>(
vec.array[j], start_bit, mask, false);
atomicAdd(histogram_shm + bucket, vec.array[j]);
atomicAdd(count_histogram_shm + bucket, 1);
}
@@ -276,12 +286,18 @@ filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer,
}
template <typename T, int BitsPerPass, int BlockSize, int NumBuckets, int Pass>
__global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
int *count_histograms, T *out, int64_t *ids,
T *buf1, int *idx_buf1, T *buf2,
int *idx_buf2, int *count_iter,
int *count_iter_begin, const int buf_len) {
__global__ void air_topp_sampling(Counter<T> *counters,
T *histograms,
int *count_histograms,
T *out,
int64_t *ids,
T *buf1,
int *idx_buf1,
T *buf2,
int *idx_buf2,
int *count_iter,
int *count_iter_begin,
const int buf_len) {
/***
* calc - filter - scan -find
* TODO: calc - scan - find - filter
@@ -352,10 +368,19 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
}
__syncthreads();
filterAndHistogram<T, BitsPerPass, NumBuckets, Pass>(
in_buf, in_idx_buf, out_buf, out_idx_buf, out, ids, previous_len, counter,
histogram, count_histogram, histogram_shm, count_histogram_shm,
early_stop);
filterAndHistogram<T, BitsPerPass, NumBuckets, Pass>(in_buf,
in_idx_buf,
out_buf,
out_idx_buf,
out,
ids,
previous_len,
counter,
histogram,
count_histogram,
histogram_shm,
count_histogram_shm,
early_stop);
__syncthreads();
__threadfence();
@@ -391,16 +416,16 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
__syncthreads();
// Acquire the summation of each 32 buckets
for (int i = threadIdx.x; i < NumBuckets; i += BlockSize) {
reduce_store_async(warp, warpSum + i / WARP_SIZE, histogram[i],
cg::plus<float>{});
reduce_store_async(
warp, warpSum + i / WARP_SIZE, histogram[i], cg::plus<float>{});
}
__syncthreads();
// Acquire the summation of all the 2048 buckets
if (threadIdx.x < WARP_SIZE) {
reduce_store_async(warp, blockSum, warpSum[threadIdx.x],
cg::plus<float>{});
reduce_update_async(warp, blockSum, warpSum[threadIdx.x + WARP_SIZE],
cg::plus<float>{});
reduce_store_async(
warp, blockSum, warpSum[threadIdx.x], cg::plus<float>{});
reduce_update_async(
warp, blockSum, warpSum[threadIdx.x + WARP_SIZE], cg::plus<float>{});
}
__syncthreads();
@@ -435,9 +460,9 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
}
}
counter->sum =
current_sum - prev; // how many values still are there to find
counter->len = count_histogram[targetIdx]; // cur - prev; // number of
// values in next pass
current_sum - prev; // how many values still are there to find
counter->len = count_histogram[targetIdx]; // cur - prev; // number of
// values in next pass
typename cub::Traits<T>::UnsignedBits bucket = targetIdx;
int startBit = calcStartBit<T, BitsPerPass, Pass>();
counter->kthValueBits |= bucket << startBit;
@@ -473,10 +498,15 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
}
template <typename T, int BitsPerPass>
__global__ void air_topp_init(Counter<T> *counters, T *histograms,
int *count_histograms, const T *in, const T *ps,
curandState_t *curandstate, const int bsz,
const int vocab_size, const int buf_len,
__global__ void air_topp_init(Counter<T> *counters,
T *histograms,
int *count_histograms,
const T *in,
const T *ps,
curandState_t *curandstate,
const int bsz,
const int vocab_size,
const int buf_len,
const int num_buckets) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
@@ -517,7 +547,8 @@ struct SegmentOffsetIter {
int num_cols_;
};
template <typename T> struct Pair {
template <typename T>
struct Pair {
__device__ __forceinline__ Pair() {}
__device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
@@ -557,7 +588,8 @@ template <typename T> struct Pair {
// Divides `a` by `n` after adding `n - 1` to the numerator, which rounds the
// quotient up for positive operands.
inline int div_up(int a, int n) {
  const int padded = a + n - 1;
  return padded / n;
}
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T> &p,
__device__ __forceinline__ void AddTo(Pair<T> topk[],
const Pair<T> &p,
int beam_size) {
for (int k = beam_size - 2; k >= 0; k--) {
if (topk[k] < p) {
@@ -571,8 +603,8 @@ __device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T> &p,
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T *src, int idx,
int dim, int beam_size) {
__device__ __forceinline__ void GetTopK(
Pair<T> topk[], const T *src, int idx, int dim, int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
@@ -583,8 +615,11 @@ __device__ __forceinline__ void GetTopK(Pair<T> topk[], const T *src, int idx,
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T *src, int idx,
int dim, const Pair<T> &max,
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
const T *src,
int idx,
int dim,
const Pair<T> &max,
int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
@@ -598,10 +633,15 @@ __device__ __forceinline__ void GetTopK(Pair<T> topk[], const T *src, int idx,
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void
ThreadGetTopK(Pair<T> topk[], int *beam, int beam_size, const T *src,
bool *firstStep, bool *is_empty, Pair<T> *max, int dim,
const int tid) {
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
int *beam,
int beam_size,
const T *src,
bool *firstStep,
bool *is_empty,
Pair<T> *max,
int dim,
const int tid) {
if (*beam > 0) {
int length = (*beam) < beam_size ? *beam : beam_size;
if (*firstStep) {
@@ -616,22 +656,20 @@ ThreadGetTopK(Pair<T> topk[], int *beam, int beam_size, const T *src,
}
}
if (!(*is_empty)) {
GetTopK<T, BlockSize>(topk + MaxLength - *beam, src, tid, dim, *max,
length);
GetTopK<T, BlockSize>(
topk + MaxLength - *beam, src, tid, dim, *max, length);
}
}
*max = topk[MaxLength - 1];
if ((*max).id == -1)
*is_empty = true;
if ((*max).id == -1) *is_empty = true;
*beam = 0;
}
}
// Thin typed wrapper around the __shfl_down_sync warp intrinsic.
//
//   mask  - lane mask forwarded unchanged to the intrinsic.
//   val   - value contributed by the calling lane.
//   delta - lane offset; cast to unsigned before being forwarded.
//   width - shuffle segment width; defaults to the compile-time WARP_SIZE
//           macro rather than the built-in `warpSize` variable.
//
// Returns whatever __shfl_down_sync returns for these arguments.
template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask,
                                                 T val,
                                                 int delta,
                                                 int width = WARP_SIZE) {
  return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
}
@@ -650,9 +688,15 @@ __forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void
BlockReduce(Pair<T> shared_max[], Pair<T> topk[], Pair<T> beam_max[], int *beam,
int *k, int *count, const int tid, const int wid, const int lane) {
__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
Pair<T> topk[],
Pair<T> beam_max[],
int *beam,
int *k,
int *count,
const int tid,
const int wid,
const int lane) {
while (true) {
__syncthreads();
Pair<T> input_now = topk[0];
@@ -667,8 +711,7 @@ BlockReduce(Pair<T> shared_max[], Pair<T> topk[], Pair<T> beam_max[], int *beam,
: Pair<T>(std::numeric_limits<T>::min(), -1);
if (wid == 0) {
input_now = WarpReduce(input_now);
if (lane == 0)
shared_max[0] = input_now;
if (lane == 0) shared_max[0] = input_now;
}
__syncthreads();
if (tid == 0) {
@@ -679,8 +722,7 @@ BlockReduce(Pair<T> shared_max[], Pair<T> topk[], Pair<T> beam_max[], int *beam,
if (tid == tid_max) {
(*beam)++;
}
if (--(*k) == 0)
break;
if (--(*k) == 0) break;
__syncthreads();
if (tid == tid_max) {
@@ -690,8 +732,7 @@ BlockReduce(Pair<T> shared_max[], Pair<T> topk[], Pair<T> beam_max[], int *beam,
}
if (MaxLength < 5) {
if (*beam >= MaxLength)
break;
if (*beam >= MaxLength) break;
} else {
unsigned mask = 0u;
mask = __ballot_sync(FINAL_MASK, true);
@@ -721,13 +762,18 @@ __device__ inline T exponential_transform(T val, T lambda) {
}
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopK(const T *src, const T *threshold,
curandState_t *states, T *top_ps,
int64_t *out_id, // topk id
T *out_val, // topk val
int64_t *topk_ids, T *topk_scores,
int vocab_size, int *count_iter,
int *count_iter_begin, const int k,
__global__ void KeMatrixTopPBeamTopK(const T *src,
const T *threshold,
curandState_t *states,
T *top_ps,
int64_t *out_id, // topk id
T *out_val, // topk val
int64_t *topk_ids,
T *topk_scores,
int vocab_size,
int *count_iter,
int *count_iter_begin,
const int k,
const bool need_batch_random) {
const int tid = threadIdx.x;
const int wid = tid / 32;
@@ -761,11 +807,17 @@ __global__ void KeMatrixTopPBeamTopK(const T *src, const T *threshold,
}
while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, TopPBeamTopK,
src + offset, &firststep, &is_empty,
&max, vocab_size, tid);
BlockReduce<T, MaxLength, BlockSize>(shared_max, topk, beam_max, &beam,
&top_num, &count, tid, wid, lane);
ThreadGetTopK<T, MaxLength, BlockSize>(topk,
&beam,
TopPBeamTopK,
src + offset,
&firststep,
&is_empty,
&max,
vocab_size,
tid);
BlockReduce<T, MaxLength, BlockSize>(
shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
}
if (tid == 0) {
// printf("offset: %d\n", (int)seed_offset);
@@ -817,13 +869,18 @@ __global__ void KeMatrixTopPBeamTopK(const T *src, const T *threshold,
}
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopKFt(const T *src, const T *threshold,
curandState_t *states, T *top_ps,
int64_t *out_id, // topk id
T *out_val, // topk val
int64_t *topk_ids, T *topk_scores,
int vocab_size, int *count_iter,
int *count_iter_begin, const int k,
__global__ void KeMatrixTopPBeamTopKFt(const T *src,
const T *threshold,
curandState_t *states,
T *top_ps,
int64_t *out_id, // topk id
T *out_val, // topk val
int64_t *topk_ids,
T *topk_scores,
int vocab_size,
int *count_iter,
int *count_iter_begin,
const int k,
const bool need_batch_random) {
const int tid = threadIdx.x;
const int wid = tid / 32;
@@ -856,11 +913,17 @@ __global__ void KeMatrixTopPBeamTopKFt(const T *src, const T *threshold,
}
while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, TopPBeamTopK,
src + bid * vocab_size, &firststep,
&is_empty, &max, vocab_size, tid);
BlockReduce<T, MaxLength, BlockSize>(shared_max, topk, beam_max, &beam,
&top_num, &count, tid, wid, lane);
ThreadGetTopK<T, MaxLength, BlockSize>(topk,
&beam,
TopPBeamTopK,
src + bid * vocab_size,
&firststep,
&is_empty,
&max,
vocab_size,
tid);
BlockReduce<T, MaxLength, BlockSize>(
shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
}
if (tid == 0) {
count_iter_begin[bid] = count_iter[bid];
@@ -925,14 +988,20 @@ __global__ void FillIndex(T *indices, T num_rows, T num_cols) {
}
template <typename T, int TopKMaxLength, int TopPBeamTopK>
void DispatchKeMatrixTopPBeamTopK(const T *src, const T *threshold,
curandState_t *states, T *top_ps,
int64_t *out_id, // topk id
T *out_val, // topk val
int64_t *topk_ids, T *topk_scores,
int vocab_size, int *count_iter,
int *count_iter_begin, const int k,
const int bs, const bool need_batch_random,
void DispatchKeMatrixTopPBeamTopK(const T *src,
const T *threshold,
curandState_t *states,
T *top_ps,
int64_t *out_id, // topk id
T *out_val, // topk val
int64_t *topk_ids,
T *topk_scores,
int vocab_size,
int *count_iter,
int *count_iter_begin,
const int k,
const int bs,
const bool need_batch_random,
const std::string &mode,
cudaStream_t stream) {
int BlockSize = GetBlockSize(vocab_size);
@@ -940,23 +1009,43 @@ void DispatchKeMatrixTopPBeamTopK(const T *src, const T *threshold,
switch (BlockSize) {
FIXED_BLOCK_DIM(
KeMatrixTopPBeamTopKFt<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
<<<bs, kBlockDim, 0, stream>>>(
src, threshold, states, top_ps, out_id, out_val, topk_ids,
topk_scores, vocab_size, count_iter, count_iter_begin, k,
need_batch_random));
default:
PD_THROW("the input data shape has error in the topp_beam_topk kernel.");
<<<bs, kBlockDim, 0, stream>>>(src,
threshold,
states,
top_ps,
out_id,
out_val,
topk_ids,
topk_scores,
vocab_size,
count_iter,
count_iter_begin,
k,
need_batch_random));
default:
PD_THROW(
"the input data shape has error in the topp_beam_topk kernel.");
}
} else {
switch (BlockSize) {
FIXED_BLOCK_DIM(
KeMatrixTopPBeamTopK<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
<<<bs, kBlockDim, 0, stream>>>(
src, threshold, states, top_ps, out_id, out_val, topk_ids,
topk_scores, vocab_size, count_iter, count_iter_begin, k,
need_batch_random));
default:
PD_THROW("the input data shape has error in the topp_beam_topk kernel.");
<<<bs, kBlockDim, 0, stream>>>(src,
threshold,
states,
top_ps,
out_id,
out_val,
topk_ids,
topk_scores,
vocab_size,
count_iter,
count_iter_begin,
k,
need_batch_random));
default:
PD_THROW(
"the input data shape has error in the topp_beam_topk kernel.");
}
}
}
@@ -978,11 +1067,17 @@ struct BlockPrefixCallbackOp {
};
template <typename T, int BLOCK_SIZE>
__global__ void topp_sampling(T *sorted_probs, int64_t *sorted_id, T *out_val,
int64_t *out_id, const T *top_ps,
const T *threshold, curandState_t *states,
const int p_num, const int vocab_size,
const bool need_batch_random, int *count_iter,
__global__ void topp_sampling(T *sorted_probs,
int64_t *sorted_id,
T *out_val,
int64_t *out_id,
const T *top_ps,
const T *threshold,
curandState_t *states,
const int p_num,
const int vocab_size,
const bool need_batch_random,
int *count_iter,
int *count_iter_begin) {
__shared__ int stop_shared;
const int tid = threadIdx.x;
@@ -1063,11 +1158,17 @@ __global__ void topp_sampling(T *sorted_probs, int64_t *sorted_id, T *out_val,
}
template <typename T, int BLOCK_SIZE>
__global__ void topp_sampling_ft(T *sorted_probs, int64_t *sorted_id,
T *out_val, int64_t *out_id, const T *top_ps,
const T *threshold, curandState_t *states,
const int p_num, const int vocab_size,
const bool need_batch_random, int *count_iter,
__global__ void topp_sampling_ft(T *sorted_probs,
int64_t *sorted_id,
T *out_val,
int64_t *out_id,
const T *top_ps,
const T *threshold,
curandState_t *states,
const int p_num,
const int vocab_size,
const bool need_batch_random,
int *count_iter,
int *count_iter_begin) {
__shared__ int stop_shared;
__shared__ float rand_p;
@@ -1146,7 +1247,7 @@ __global__ void topp_sampling_ft(T *sorted_probs, int64_t *sorted_id,
}
}
if (!skip) {
int active_lane_id = 32 - __popc(selected_shared[warp_id]); // first not 0
int active_lane_id = 32 - __popc(selected_shared[warp_id]); // first not 0
if (lane_id == active_lane_id) {
float val = static_cast<float>(sorted_probs[offset + i_activate]);
if (val < threshold_now) {
@@ -1167,36 +1268,63 @@ __global__ void topp_sampling_ft(T *sorted_probs, int64_t *sorted_id,
}
template <typename T>
void DispatchTopPSampling(T *sorted_probs, int64_t *sorted_id, T *out_val,
int64_t *out_id, const T *top_ps, const T *threshold,
curandState_t *states, const int p_num,
const int vocab_size, const int bs,
const bool need_batch_random, int *count_iter,
int *count_iter_begin, const std::string &mode,
void DispatchTopPSampling(T *sorted_probs,
int64_t *sorted_id,
T *out_val,
int64_t *out_id,
const T *top_ps,
const T *threshold,
curandState_t *states,
const int p_num,
const int vocab_size,
const int bs,
const bool need_batch_random,
int *count_iter,
int *count_iter_begin,
const std::string &mode,
cudaStream_t stream) {
int BlockSize = GetBlockSize(vocab_size);
if (mode == "truncated") {
switch (BlockSize) {
FIXED_BLOCK_DIM(topp_sampling_ft<T, kBlockDim>
<<<bs, kBlockDim, 0, stream>>>(
sorted_probs, sorted_id, out_val, out_id, top_ps,
threshold, states, p_num, vocab_size,
need_batch_random, count_iter, count_iter_begin));
default:
PD_THROW("the input data shape has error in the topp_sampling kernel.");
<<<bs, kBlockDim, 0, stream>>>(sorted_probs,
sorted_id,
out_val,
out_id,
top_ps,
threshold,
states,
p_num,
vocab_size,
need_batch_random,
count_iter,
count_iter_begin));
default:
PD_THROW("the input data shape has error in the topp_sampling kernel.");
}
} else {
switch (BlockSize) {
FIXED_BLOCK_DIM(topp_sampling<T, kBlockDim><<<bs, kBlockDim, 0, stream>>>(
sorted_probs, sorted_id, out_val, out_id, top_ps, threshold, states,
p_num, vocab_size, need_batch_random, count_iter, count_iter_begin));
default:
PD_THROW("the input data shape has error in the topp_sampling kernel.");
FIXED_BLOCK_DIM(topp_sampling<T, kBlockDim>
<<<bs, kBlockDim, 0, stream>>>(sorted_probs,
sorted_id,
out_val,
out_id,
top_ps,
threshold,
states,
p_num,
vocab_size,
need_batch_random,
count_iter,
count_iter_begin));
default:
PD_THROW("the input data shape has error in the topp_sampling kernel.");
}
}
}
__global__ void air_topp_setup_kernel(curandState_t *state, int64_t *seed,
__global__ void air_topp_setup_kernel(curandState_t *state,
int64_t *seed,
const int bs) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
@@ -1204,8 +1332,10 @@ __global__ void air_topp_setup_kernel(curandState_t *state, int64_t *seed,
}
}
__global__ void air_topp_setup_kernel(curandState_t *state, const uint64_t seed,
const uint64_t offset, const int bs,
__global__ void air_topp_setup_kernel(curandState_t *state,
const uint64_t seed,
const uint64_t offset,
const int bs,
const bool need_batch_random) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
@@ -1217,7 +1347,8 @@ __global__ void air_topp_setup_kernel(curandState_t *state, const uint64_t seed,
}
}
template <typename T> __global__ void print_kernel(T *input, int size) {
template <typename T>
__global__ void print_kernel(T *input, int size) {
printf("[");
for (int i = 0; i < size; i++) {
if (i != size - 1) {
@@ -1229,11 +1360,14 @@ template <typename T> __global__ void print_kernel(T *input, int size) {
}
template <paddle::DataType D>
std::vector<paddle::Tensor>
LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
const paddle::optional<paddle::Tensor> &threshold,
const paddle::optional<paddle::Tensor> &topp_seed, int seed,
int k, const std::string &mode) {
std::vector<paddle::Tensor> LaunchTopPSampling(
const paddle::Tensor &x,
const paddle::Tensor &ps,
const paddle::optional<paddle::Tensor> &threshold,
const paddle::optional<paddle::Tensor> &topp_seed,
int seed,
int k,
const std::string &mode) {
typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -1259,8 +1393,8 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
switch (BlockSize) {
FIXED_BLOCK_DIM(FillIndex<int64_t><<<bs, kBlockDim, 0, stream>>>(
inds_input.data<int64_t>(), bs, vocab_size));
default:
PD_THROW("the input data shape has error in the FillIndex kernel.");
default:
PD_THROW("the input data shape has error in the FillIndex kernel.");
}
int64_t *infer_seed =
topp_seed ? const_cast<int64_t *>(topp_seed.get().data<int64_t>())
@@ -1270,7 +1404,8 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
phi::Allocator::AllocationPtr curand_states_buf{nullptr};
curand_states_buf = phi::memory_utils::Alloc(
x.place(), bs * sizeof(curandState_t),
x.place(),
bs * sizeof(curandState_t),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
states = reinterpret_cast<curandState_t *>(curand_states_buf->ptr());
@@ -1290,11 +1425,11 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
auto seed_offset = gen_cuda->IncrementOffset(increment);
seed_now = seed_offset.first;
offset = seed_offset.second;
air_topp_setup_kernel<<<1, 256, 0, stream>>>(states, seed_now, offset, bs,
need_batch_random);
air_topp_setup_kernel<<<1, 256, 0, stream>>>(
states, seed_now, offset, bs, need_batch_random);
} else {
air_topp_setup_kernel<<<1, 256, 0, stream>>>(states, seed_now, offset, bs,
need_batch_random);
air_topp_setup_kernel<<<1, 256, 0, stream>>>(
states, seed_now, offset, bs, need_batch_random);
}
}
@@ -1313,13 +1448,21 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
DispatchKeMatrixTopPBeamTopK<DataType_, TopKMaxLength, TopPBeamTopK>(
reinterpret_cast<const DataType_ *>(x.data<data_t>()),
reinterpret_cast<const DataType_ *>(threshold_data), states,
reinterpret_cast<DataType_ *>(ps_now.data<data_t>()), ids.data<int64_t>(),
reinterpret_cast<const DataType_ *>(threshold_data),
states,
reinterpret_cast<DataType_ *>(ps_now.data<data_t>()),
ids.data<int64_t>(),
reinterpret_cast<DataType_ *>(out.data<data_t>()),
topk_ids.data<int64_t>(),
reinterpret_cast<DataType_ *>(topk_scores.data<data_t>()), vocab_size,
count_iter.data<int>(), count_iter_begin.data<int>(), k, bs,
need_batch_random, mode, stream);
reinterpret_cast<DataType_ *>(topk_scores.data<data_t>()),
vocab_size,
count_iter.data<int>(),
count_iter_begin.data<int>(),
k,
bs,
need_batch_random,
mode,
stream);
static_assert(std::is_same<DataType_, float>::value,
"air_topp only supports float now!");
@@ -1328,7 +1471,8 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
constexpr int INIT_BLOCK_SIZE = 1024;
phi::Allocator::AllocationPtr counter_ptr{nullptr};
counter_ptr = phi::memory_utils::Alloc(
x.place(), bs * sizeof(Counter<DataType_>),
x.place(),
bs * sizeof(Counter<DataType_>),
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
Counter<DataType_> *counters =
reinterpret_cast<Counter<DataType_> *>(counter_ptr->ptr());
@@ -1346,67 +1490,93 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
paddle::empty({bs, buf_len}, paddle::DataType::INT32, x.place());
air_topp_init<float, BitsPerPass><<<bs, INIT_BLOCK_SIZE, 0, stream>>>(
counters, reinterpret_cast<float *>(histograms.data<data_t>()),
counters,
reinterpret_cast<float *>(histograms.data<data_t>()),
count_histograms.data<int32_t>(),
reinterpret_cast<const float *>(x.data<data_t>()),
reinterpret_cast<const float *>(ps.data<data_t>()), states, bs,
vocab_size, buf_len, numBuckets);
reinterpret_cast<const float *>(ps.data<data_t>()),
states,
bs,
vocab_size,
buf_len,
numBuckets);
constexpr int VecSize = 16 / sizeof(data_t);
// TODO: good block_num
const int max_block_num_vocab =
ceilDiv(vocab_size, SAMPLING_BLOCK_SIZE * VecSize);
auto kernel = air_topp_sampling<data_t, BitsPerPass, SAMPLING_BLOCK_SIZE,
numBuckets, 0>;
auto kernel = air_topp_sampling<data_t,
BitsPerPass,
SAMPLING_BLOCK_SIZE,
numBuckets,
0>;
const int dev_id = 0;
int sm_count;
int act_blocks_per_sm;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&act_blocks_per_sm, kernel,
SAMPLING_BLOCK_SIZE, 0);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&act_blocks_per_sm, kernel, SAMPLING_BLOCK_SIZE, 0);
assert(act_blocks_per_sm > 1);
const int block_per_wave = sm_count * act_blocks_per_sm;
const int block_num_vocab =
std::min(max_block_num_vocab, block_per_wave * 4 / bs); // !!!
std::min(max_block_num_vocab, block_per_wave * 4 / bs); // !!!
dim3 grid(block_num_vocab, bs);
constexpr int numPasses = calcNumPasses<data_t, BitsPerPass>();
for (int pass = 0; pass < numPasses; ++pass) {
if (pass == 0) {
air_topp_sampling<DataType_, BitsPerPass, SAMPLING_BLOCK_SIZE, numBuckets,
air_topp_sampling<DataType_,
BitsPerPass,
SAMPLING_BLOCK_SIZE,
numBuckets,
0><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(
counters, reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
counters,
reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
count_histograms.data<int>(),
reinterpret_cast<DataType_ *>(out.data<data_t>()),
ids.data<int64_t>(),
reinterpret_cast<DataType_ *>(buf1.data<data_t>()),
id_buf1.data<int>(),
reinterpret_cast<DataType_ *>(buf2.data<data_t>()),
id_buf2.data<int>(), count_iter.data<int>(),
count_iter_begin.data<int>(), buf_len);
id_buf2.data<int>(),
count_iter.data<int>(),
count_iter_begin.data<int>(),
buf_len);
} else if (pass == 1) {
air_topp_sampling<DataType_, BitsPerPass, SAMPLING_BLOCK_SIZE, numBuckets,
air_topp_sampling<DataType_,
BitsPerPass,
SAMPLING_BLOCK_SIZE,
numBuckets,
1><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(
counters, reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
counters,
reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
count_histograms.data<int>(),
reinterpret_cast<DataType_ *>(out.data<data_t>()),
ids.data<int64_t>(),
reinterpret_cast<DataType_ *>(buf1.data<data_t>()),
id_buf1.data<int>(),
reinterpret_cast<DataType_ *>(buf2.data<data_t>()),
id_buf2.data<int>(), count_iter.data<int>(),
count_iter_begin.data<int>(), buf_len);
id_buf2.data<int>(),
count_iter.data<int>(),
count_iter_begin.data<int>(),
buf_len);
} else if (pass == 2) {
air_topp_sampling<DataType_, BitsPerPass, SAMPLING_BLOCK_SIZE, numBuckets,
air_topp_sampling<DataType_,
BitsPerPass,
SAMPLING_BLOCK_SIZE,
numBuckets,
2><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(
counters, reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
counters,
reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
count_histograms.data<int>(),
reinterpret_cast<DataType_ *>(out.data<data_t>()),
ids.data<int64_t>(),
reinterpret_cast<DataType_ *>(buf1.data<data_t>()),
id_buf1.data<int>(),
reinterpret_cast<DataType_ *>(buf2.data<data_t>()),
id_buf2.data<int>(), count_iter.data<int>(),
count_iter_begin.data<int>(), buf_len);
id_buf2.data<int>(),
count_iter.data<int>(),
count_iter_begin.data<int>(),
buf_len);
} else {
PD_THROW("pass must be 0,1 or 2!");
}
@@ -1414,52 +1584,60 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
return {out, ids};
}
// Dispatches on the element type of `x` and forwards every argument to the
// matching LaunchTopPSampling instantiation.
//
// Only paddle::DataType::FLOAT32 is handled; the BFLOAT16 / FLOAT16 branches
// are kept below as commented-out code. Any other dtype reaches PD_THROW.
//
// Returns the tensors produced by LaunchTopPSampling.
std::vector<paddle::Tensor> TopPSampling(
    const paddle::Tensor &x,
    const paddle::Tensor &ps,
    const paddle::optional<paddle::Tensor> &threshold,
    const paddle::optional<paddle::Tensor> &topp_seed,
    int seed,
    int k,
    const std::string &mode) {
  switch (x.type()) {
    case paddle::DataType::FLOAT32: {
      return LaunchTopPSampling<paddle::DataType::FLOAT32>(
          x, ps, threshold, topp_seed, seed, k, mode);
    }
    // case paddle::DataType::BFLOAT16: {
    //   return LaunchTopPSampling<paddle::DataType::BFLOAT16>(x, ps,
    //   threshold, topp_seed, seed, k, mode);
    // }
    // case paddle::DataType::FLOAT16: {
    //   return LaunchTopPSampling<paddle::DataType::FLOAT16>(x, ps,
    //   threshold, topp_seed, seed, k, mode);
    // }
    default: {
      PD_THROW("NOT supported data type. Only support float. ");
      break;
    }
  }
}
// Shape-inference helper: reports two outputs, each of shape {bs, 1}, where
// bs is the leading dimension of `x_shape`.
//
// All other parameters are accepted but not read. The leading dimension is
// kept as int64_t so it is not truncated, and `x_shape[1]` is no longer read
// because its value was never used.
std::vector<std::vector<int64_t>> GetTopPSamplingShape(
    const std::vector<int64_t> &x_shape,
    const std::vector<int64_t> &ps_shape,
    const paddle::optional<std::vector<int64_t>> &threshold_shape,
    const paddle::optional<std::vector<int64_t>> &topp_seed_shape,
    int seed,
    int k) {
  const int64_t bs = x_shape[0];
  return {{bs, 1}, {bs, 1}};
}
// Dtype-inference helper: the first output takes the dtype of `x`, the
// second is always paddle::DataType::INT64.
//
// All parameters other than `x_dytpe` are accepted but not read.
std::vector<paddle::DataType> GetTopPSamplingDtype(
    const paddle::DataType &x_dytpe,
    const paddle::DataType &ps_dtype,
    const paddle::optional<paddle::DataType> &threshold_dtype,
    const paddle::optional<paddle::DataType> &topp_seed_dtype,
    int seed,
    int k) {
  return {x_dytpe, paddle::DataType::INT64};
}
PD_BUILD_STATIC_OP(air_top_p_sampling)
.Inputs({"x", "ps", paddle::Optional("threshold"),
.Inputs({"x",
"ps",
paddle::Optional("threshold"),
paddle::Optional("topp_seed")})
.Outputs({"out", "ids"})
.Attrs({"seed: int", "k: int", "mode: std::string"})