From df67379bc397e3b2a5c9357c0a3cb6f9f0d0ac64 Mon Sep 17 00:00:00 2001 From: xiaozude Date: Tue, 9 Dec 2025 17:44:02 +0800 Subject: [PATCH] [Metax] modify wrapSize to WARP_SIZE (#5442) --- .../get_block_shape_and_split_kv_block.cu | 8 +- custom_ops/gpu_ops/get_padding_offset.cu | 6 +- .../sample_kernels/air_top_p_sampling.cu | 618 +++++++++++------- .../speculate_decoding/top_p_candidates.cu | 2 +- 4 files changed, 406 insertions(+), 228 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu index 2b5c1fbc7..12aadec92 100644 --- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu +++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu @@ -179,11 +179,11 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q, const int num_rows_per_block, const int group_size) { // one block one warp - const int lane_id = threadIdx.x % warpSize; + const int lane_id = threadIdx.x % WARP_SIZE; int prev_offset = 0; // loop on warp tile:[base, base+32) - for (int base = 0; base < bsz; base += warpSize) { + for (int base = 0; base < bsz; base += WARP_SIZE) { const int bid = base + lane_id; // calculate loop_times for bid @@ -199,13 +199,13 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q, // prefix sum for each lane, get the start offset in this tile // inclusive scan int x = loop_times; - for (int offset = 1; offset < warpSize; offset <<= 1) { + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { int y = __shfl_up_sync(0xffffffff, x, offset); if (lane_id >= offset) x += y; } // exclusive prefix sum int bid_offset = x - loop_times; - int tile_sum = __shfl_sync(0xffffffff, x, warpSize - 1); + int tile_sum = __shfl_sync(0xffffffff, x, WARP_SIZE - 1); // write batch_ids and tile_ids_per_batch if (bid < bsz && loop_times > 0) { diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu index 6493941b7..60591d246 100644 --- a/custom_ops/gpu_ops/get_padding_offset.cu +++ b/custom_ops/gpu_ops/get_padding_offset.cu @@ -34,16 +34,16 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding, int cum_seq_len = 0; // compute sum of seq_lens[0,1,2,...,bi] - for (int i = lane_id; i < bi + 1; i += warpSize) { + for (int i = lane_id; i < bi + 1; i += WARP_SIZE) { cum_seq_len += seq_lens[i]; } - for (int offset = 1; offset < warpSize; offset <<= 1) { + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { const int tmp = __shfl_up_sync(0xffffffff, cum_seq_len, offset); if (lane_id >= offset) cum_seq_len += tmp; } - cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, warpSize - 1); + cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, WARP_SIZE - 1); if (tid == 0) { cu_seqlens_q[bi + 1] = cum_seq_len; diff --git a/custom_ops/gpu_ops/sample_kernels/air_top_p_sampling.cu b/custom_ops/gpu_ops/sample_kernels/air_top_p_sampling.cu index ade1d74b5..99838e70d 100644 --- a/custom_ops/gpu_ops/sample_kernels/air_top_p_sampling.cu +++ b/custom_ops/gpu_ops/sample_kernels/air_top_p_sampling.cu @@ -30,8 +30,8 @@ #include #include -#include #include +#include #include "helper.h" #include "paddle/phi/backends/context_pool.h" @@ -40,20 +40,21 @@ #define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") +#define WARP_SIZE 32 #define FINAL_MASK 0xFFFFFFFF -#define FIXED_BLOCK_DIM_BASE(dim, ...) \ - case (dim): { \ - constexpr auto kBlockDim = (dim); \ - __VA_ARGS__; \ +#define FIXED_BLOCK_DIM_BASE(dim, ...) \ + case (dim): { \ + constexpr auto kBlockDim = (dim); \ + __VA_ARGS__; \ } break -#define FIXED_BLOCK_DIM(...) \ - FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ +#define FIXED_BLOCK_DIM(...) \ + FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) template @@ -123,7 +124,8 @@ __device__ T twiddleOut(typename cub::Traits::UnsignedBits bits, return reinterpret_cast(bits); } -template __host__ __device__ constexpr int calcNumBuckets() { +template +__host__ __device__ constexpr int calcNumBuckets() { return 1 << BitsPerPass; } @@ -161,12 +163,12 @@ __device__ void scan(IdxT volatile *histogram, IdxT *histogramOut) { if constexpr (numBuckets >= BlockSize) { static_assert(numBuckets % BlockSize == 0); int constexpr itemsPerThread = numBuckets / BlockSize; - typedef cub::BlockLoad - BlockLoad; - typedef cub::BlockStore - BlockStore; + typedef cub:: + BlockLoad + BlockLoad; + typedef cub:: + BlockStore + BlockStore; typedef cub::BlockScan BlockScan; __shared__ union { @@ -203,12 +205,19 @@ __device__ void scan(IdxT volatile *histogram, IdxT *histogramOut) { } template -__device__ __forceinline__ void -filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer, - int *out_idx_buffer, T *out_scores, int64_t *out_ids, - int previous_len, Counter *counter, T *histogram, - int *count_histogram, T *histogram_shm, - int *count_histogram_shm, const bool early_stop) { +__device__ __forceinline__ void filterAndHistogram(const T *in_buffer, + const int *in_idx_buffer, + T *out_buffer, + int *out_idx_buffer, + T *out_scores, + int64_t *out_ids, + int previous_len, + Counter *counter, + T *histogram, + int *count_histogram, + T *histogram_shm, + int *count_histogram_shm, + const bool early_stop) { // scan and filter constexpr int start_bit = calcStartBit(); const uint32_t mask = calcMask(); @@ -220,7 +229,8 @@ filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer, T array[VecSize]; } vec; for (int i = (blockIdx.x * blockDim.x + threadIdx.x); - i < ceilDiv(previous_len, VecSize); i += blockDim.x * gridDim.x) { + i < ceilDiv(previous_len, VecSize); + i += blockDim.x * gridDim.x) { vec.v = reinterpret_cast(in_buffer)[i]; if constexpr (Pass == 0) { #pragma unroll @@ -254,8 +264,8 @@ filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer, out_buffer[pos] = vec.array[j]; out_idx_buffer[pos] = in_idx_buffer ? in_idx_buffer[idx] : idx; } - int bucket = calcBucket(vec.array[j], start_bit, - mask, false); + int bucket = calcBucket( + vec.array[j], start_bit, mask, false); atomicAdd(histogram_shm + bucket, vec.array[j]); atomicAdd(count_histogram_shm + bucket, 1); } @@ -276,12 +286,18 @@ filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer, } template -__global__ void air_topp_sampling(Counter *counters, T *histograms, - int *count_histograms, T *out, int64_t *ids, - T *buf1, int *idx_buf1, T *buf2, - int *idx_buf2, int *count_iter, - int *count_iter_begin, const int buf_len) { - +__global__ void air_topp_sampling(Counter *counters, + T *histograms, + int *count_histograms, + T *out, + int64_t *ids, + T *buf1, + int *idx_buf1, + T *buf2, + int *idx_buf2, + int *count_iter, + int *count_iter_begin, + const int buf_len) { /*** * calc - filter - scan -find * TODO: calc - scan - find - filter @@ -352,10 +368,19 @@ __global__ void air_topp_sampling(Counter *counters, T *histograms, } __syncthreads(); - filterAndHistogram( - in_buf, in_idx_buf, out_buf, out_idx_buf, out, ids, previous_len, counter, - histogram, count_histogram, histogram_shm, count_histogram_shm, - early_stop); + filterAndHistogram(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + ids, + previous_len, + counter, + histogram, + count_histogram, + histogram_shm, + count_histogram_shm, + early_stop); __syncthreads(); __threadfence(); @@ -391,16 +416,16 @@ __global__ void air_topp_sampling(Counter *counters, T *histograms, __syncthreads(); // Acquire the summation of each 32 buckets for (int i = threadIdx.x; i < NumBuckets; i += BlockSize) { - reduce_store_async(warp, warpSum + i / WARP_SIZE, histogram[i], - cg::plus{}); + reduce_store_async( + warp, warpSum + i / WARP_SIZE, histogram[i], cg::plus{}); } __syncthreads(); // Acquire the summation of all the 2048 buckets if (threadIdx.x < WARP_SIZE) { - reduce_store_async(warp, blockSum, warpSum[threadIdx.x], - cg::plus{}); - reduce_update_async(warp, blockSum, warpSum[threadIdx.x + WARP_SIZE], - cg::plus{}); + reduce_store_async( + warp, blockSum, warpSum[threadIdx.x], cg::plus{}); + reduce_update_async( + warp, blockSum, warpSum[threadIdx.x + WARP_SIZE], cg::plus{}); } __syncthreads(); @@ -435,9 +460,9 @@ __global__ void air_topp_sampling(Counter *counters, T *histograms, } } counter->sum = - current_sum - prev; // how many values still are there to find - counter->len = count_histogram[targetIdx]; // cur - prev; // number of - // values in next pass + current_sum - prev; // how many values still are there to find + counter->len = count_histogram[targetIdx]; // cur - prev; // number of + // values in next pass typename cub::Traits::UnsignedBits bucket = targetIdx; int startBit = calcStartBit(); counter->kthValueBits |= bucket << startBit; @@ -473,10 +498,15 @@ __global__ void air_topp_sampling(Counter *counters, T *histograms, } template -__global__ void air_topp_init(Counter *counters, T *histograms, - int *count_histograms, const T *in, const T *ps, - curandState_t *curandstate, const int bsz, - const int vocab_size, const int buf_len, +__global__ void air_topp_init(Counter *counters, + T *histograms, + int *count_histograms, + const T *in, + const T *ps, + curandState_t *curandstate, + const int bsz, + const int vocab_size, + const int buf_len, const int num_buckets) { const int bid = blockIdx.x; const int tid = threadIdx.x; @@ -517,7 +547,8 @@ struct SegmentOffsetIter { int num_cols_; }; -template struct Pair { +template +struct Pair { __device__ __forceinline__ Pair() {} __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {} @@ -557,7 +588,8 @@ template struct Pair { inline int div_up(int a, int n) { return (a + n - 1) / n; } template -__device__ __forceinline__ void AddTo(Pair topk[], const Pair &p, +__device__ __forceinline__ void AddTo(Pair topk[], + const Pair &p, int beam_size) { for (int k = beam_size - 2; k >= 0; k--) { if (topk[k] < p) { @@ -571,8 +603,8 @@ __device__ __forceinline__ void AddTo(Pair topk[], const Pair &p, } template -__device__ __forceinline__ void GetTopK(Pair topk[], const T *src, int idx, - int dim, int beam_size) { +__device__ __forceinline__ void GetTopK( + Pair topk[], const T *src, int idx, int dim, int beam_size) { while (idx < dim) { if (topk[beam_size - 1] < src[idx]) { Pair tmp(src[idx], idx); @@ -583,8 +615,11 @@ __device__ __forceinline__ void GetTopK(Pair topk[], const T *src, int idx, } template -__device__ __forceinline__ void GetTopK(Pair topk[], const T *src, int idx, - int dim, const Pair &max, +__device__ __forceinline__ void GetTopK(Pair topk[], + const T *src, + int idx, + int dim, + const Pair &max, int beam_size) { while (idx < dim) { if (topk[beam_size - 1] < src[idx]) { @@ -598,10 +633,15 @@ __device__ __forceinline__ void GetTopK(Pair topk[], const T *src, int idx, } template -__device__ __forceinline__ void -ThreadGetTopK(Pair topk[], int *beam, int beam_size, const T *src, - bool *firstStep, bool *is_empty, Pair *max, int dim, - const int tid) { +__device__ __forceinline__ void ThreadGetTopK(Pair topk[], + int *beam, + int beam_size, + const T *src, + bool *firstStep, + bool *is_empty, + Pair *max, + int dim, + const int tid) { if (*beam > 0) { int length = (*beam) < beam_size ? *beam : beam_size; if (*firstStep) { @@ -616,22 +656,20 @@ ThreadGetTopK(Pair topk[], int *beam, int beam_size, const T *src, } } if (!(*is_empty)) { - GetTopK(topk + MaxLength - *beam, src, tid, dim, *max, - length); + GetTopK( + topk + MaxLength - *beam, src, tid, dim, *max, length); } } *max = topk[MaxLength - 1]; - if ((*max).id == -1) - *is_empty = true; + if ((*max).id == -1) *is_empty = true; *beam = 0; } } template -__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val, - int delta, - int width = warpSize) { +__forceinline__ __device__ T +CudaShuffleDownSync(unsigned mask, T val, int delta, int width = WARP_SIZE) { return __shfl_down_sync(mask, val, static_cast(delta), width); } @@ -650,9 +688,15 @@ __forceinline__ __device__ Pair WarpReduce(Pair input) { } template -__device__ __forceinline__ void -BlockReduce(Pair shared_max[], Pair topk[], Pair beam_max[], int *beam, - int *k, int *count, const int tid, const int wid, const int lane) { +__device__ __forceinline__ void BlockReduce(Pair shared_max[], + Pair topk[], + Pair beam_max[], + int *beam, + int *k, + int *count, + const int tid, + const int wid, + const int lane) { while (true) { __syncthreads(); Pair input_now = topk[0]; @@ -667,8 +711,7 @@ BlockReduce(Pair shared_max[], Pair topk[], Pair beam_max[], int *beam, : Pair(std::numeric_limits::min(), -1); if (wid == 0) { input_now = WarpReduce(input_now); - if (lane == 0) - shared_max[0] = input_now; + if (lane == 0) shared_max[0] = input_now; } __syncthreads(); if (tid == 0) { @@ -679,8 +722,7 @@ BlockReduce(Pair shared_max[], Pair topk[], Pair beam_max[], int *beam, if (tid == tid_max) { (*beam)++; } - if (--(*k) == 0) - break; + if (--(*k) == 0) break; __syncthreads(); if (tid == tid_max) { @@ -690,8 +732,7 @@ BlockReduce(Pair shared_max[], Pair topk[], Pair beam_max[], int *beam, } if (MaxLength < 5) { - if (*beam >= MaxLength) - break; + if (*beam >= MaxLength) break; } else { unsigned mask = 0u; mask = __ballot_sync(FINAL_MASK, true); @@ -721,13 +762,18 @@ __device__ inline T exponential_transform(T val, T lambda) { } template -__global__ void KeMatrixTopPBeamTopK(const T *src, const T *threshold, - curandState_t *states, T *top_ps, - int64_t *out_id, // topk id - T *out_val, // topk val - int64_t *topk_ids, T *topk_scores, - int vocab_size, int *count_iter, - int *count_iter_begin, const int k, +__global__ void KeMatrixTopPBeamTopK(const T *src, + const T *threshold, + curandState_t *states, + T *top_ps, + int64_t *out_id, // topk id + T *out_val, // topk val + int64_t *topk_ids, + T *topk_scores, + int vocab_size, + int *count_iter, + int *count_iter_begin, + const int k, const bool need_batch_random) { const int tid = threadIdx.x; const int wid = tid / 32; @@ -761,11 +807,17 @@ __global__ void KeMatrixTopPBeamTopK(const T *src, const T *threshold, } while (top_num) { - ThreadGetTopK(topk, &beam, TopPBeamTopK, - src + offset, &firststep, &is_empty, - &max, vocab_size, tid); - BlockReduce(shared_max, topk, beam_max, &beam, - &top_num, &count, tid, wid, lane); + ThreadGetTopK(topk, + &beam, + TopPBeamTopK, + src + offset, + &firststep, + &is_empty, + &max, + vocab_size, + tid); + BlockReduce( + shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane); } if (tid == 0) { // printf("offset: %d\n", (int)seed_offset); @@ -817,13 +869,18 @@ __global__ void KeMatrixTopPBeamTopK(const T *src, const T *threshold, } template -__global__ void KeMatrixTopPBeamTopKFt(const T *src, const T *threshold, - curandState_t *states, T *top_ps, - int64_t *out_id, // topk id - T *out_val, // topk val - int64_t *topk_ids, T *topk_scores, - int vocab_size, int *count_iter, - int *count_iter_begin, const int k, +__global__ void KeMatrixTopPBeamTopKFt(const T *src, + const T *threshold, + curandState_t *states, + T *top_ps, + int64_t *out_id, // topk id + T *out_val, // topk val + int64_t *topk_ids, + T *topk_scores, + int vocab_size, + int *count_iter, + int *count_iter_begin, + const int k, const bool need_batch_random) { const int tid = threadIdx.x; const int wid = tid / 32; @@ -856,11 +913,17 @@ __global__ void KeMatrixTopPBeamTopKFt(const T *src, const T *threshold, } while (top_num) { - ThreadGetTopK(topk, &beam, TopPBeamTopK, - src + bid * vocab_size, &firststep, - &is_empty, &max, vocab_size, tid); - BlockReduce(shared_max, topk, beam_max, &beam, - &top_num, &count, tid, wid, lane); + ThreadGetTopK(topk, + &beam, + TopPBeamTopK, + src + bid * vocab_size, + &firststep, + &is_empty, + &max, + vocab_size, + tid); + BlockReduce( + shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane); } if (tid == 0) { count_iter_begin[bid] = count_iter[bid]; @@ -925,14 +988,20 @@ __global__ void FillIndex(T *indices, T num_rows, T num_cols) { } template -void DispatchKeMatrixTopPBeamTopK(const T *src, const T *threshold, - curandState_t *states, T *top_ps, - int64_t *out_id, // topk id - T *out_val, // topk val - int64_t *topk_ids, T *topk_scores, - int vocab_size, int *count_iter, - int *count_iter_begin, const int k, - const int bs, const bool need_batch_random, +void DispatchKeMatrixTopPBeamTopK(const T *src, + const T *threshold, + curandState_t *states, + T *top_ps, + int64_t *out_id, // topk id + T *out_val, // topk val + int64_t *topk_ids, + T *topk_scores, + int vocab_size, + int *count_iter, + int *count_iter_begin, + const int k, + const int bs, + const bool need_batch_random, const std::string &mode, cudaStream_t stream) { int BlockSize = GetBlockSize(vocab_size); @@ -940,23 +1009,43 @@ void DispatchKeMatrixTopPBeamTopK(const T *src, const T *threshold, switch (BlockSize) { FIXED_BLOCK_DIM( KeMatrixTopPBeamTopKFt - <<>>( - src, threshold, states, top_ps, out_id, out_val, topk_ids, - topk_scores, vocab_size, count_iter, count_iter_begin, k, - need_batch_random)); - default: - PD_THROW("the input data shape has error in the topp_beam_topk kernel."); + <<>>(src, + threshold, + states, + top_ps, + out_id, + out_val, + topk_ids, + topk_scores, + vocab_size, + count_iter, + count_iter_begin, + k, + need_batch_random)); + default: + PD_THROW( + "the input data shape has error in the topp_beam_topk kernel."); } } else { switch (BlockSize) { FIXED_BLOCK_DIM( KeMatrixTopPBeamTopK - <<>>( - src, threshold, states, top_ps, out_id, out_val, topk_ids, - topk_scores, vocab_size, count_iter, count_iter_begin, k, - need_batch_random)); - default: - PD_THROW("the input data shape has error in the topp_beam_topk kernel."); + <<>>(src, + threshold, + states, + top_ps, + out_id, + out_val, + topk_ids, + topk_scores, + vocab_size, + count_iter, + count_iter_begin, + k, + need_batch_random)); + default: + PD_THROW( + "the input data shape has error in the topp_beam_topk kernel."); } } } @@ -978,11 +1067,17 @@ struct BlockPrefixCallbackOp { }; template -__global__ void topp_sampling(T *sorted_probs, int64_t *sorted_id, T *out_val, - int64_t *out_id, const T *top_ps, - const T *threshold, curandState_t *states, - const int p_num, const int vocab_size, - const bool need_batch_random, int *count_iter, +__global__ void topp_sampling(T *sorted_probs, + int64_t *sorted_id, + T *out_val, + int64_t *out_id, + const T *top_ps, + const T *threshold, + curandState_t *states, + const int p_num, + const int vocab_size, + const bool need_batch_random, + int *count_iter, int *count_iter_begin) { __shared__ int stop_shared; const int tid = threadIdx.x; @@ -1063,11 +1158,17 @@ __global__ void topp_sampling(T *sorted_probs, int64_t *sorted_id, T *out_val, } template -__global__ void topp_sampling_ft(T *sorted_probs, int64_t *sorted_id, - T *out_val, int64_t *out_id, const T *top_ps, - const T *threshold, curandState_t *states, - const int p_num, const int vocab_size, - const bool need_batch_random, int *count_iter, +__global__ void topp_sampling_ft(T *sorted_probs, + int64_t *sorted_id, + T *out_val, + int64_t *out_id, + const T *top_ps, + const T *threshold, + curandState_t *states, + const int p_num, + const int vocab_size, + const bool need_batch_random, + int *count_iter, int *count_iter_begin) { __shared__ int stop_shared; __shared__ float rand_p; @@ -1146,7 +1247,7 @@ __global__ void topp_sampling_ft(T *sorted_probs, int64_t *sorted_id, } } if (!skip) { - int active_lane_id = 32 - __popc(selected_shared[warp_id]); // first not 0 + int active_lane_id = 32 - __popc(selected_shared[warp_id]); // first not 0 if (lane_id == active_lane_id) { float val = static_cast(sorted_probs[offset + i_activate]); if (val < threshold_now) { @@ -1167,36 +1268,63 @@ __global__ void topp_sampling_ft(T *sorted_probs, int64_t *sorted_id, } template -void DispatchTopPSampling(T *sorted_probs, int64_t *sorted_id, T *out_val, - int64_t *out_id, const T *top_ps, const T *threshold, - curandState_t *states, const int p_num, - const int vocab_size, const int bs, - const bool need_batch_random, int *count_iter, - int *count_iter_begin, const std::string &mode, +void DispatchTopPSampling(T *sorted_probs, + int64_t *sorted_id, + T *out_val, + int64_t *out_id, + const T *top_ps, + const T *threshold, + curandState_t *states, + const int p_num, + const int vocab_size, + const int bs, + const bool need_batch_random, + int *count_iter, + int *count_iter_begin, + const std::string &mode, cudaStream_t stream) { int BlockSize = GetBlockSize(vocab_size); if (mode == "truncated") { switch (BlockSize) { FIXED_BLOCK_DIM(topp_sampling_ft - <<>>( - sorted_probs, sorted_id, out_val, out_id, top_ps, - threshold, states, p_num, vocab_size, - need_batch_random, count_iter, count_iter_begin)); - default: - PD_THROW("the input data shape has error in the topp_sampling kernel."); + <<>>(sorted_probs, + sorted_id, + out_val, + out_id, + top_ps, + threshold, + states, + p_num, + vocab_size, + need_batch_random, + count_iter, + count_iter_begin)); + default: + PD_THROW("the input data shape has error in the topp_sampling kernel."); } } else { switch (BlockSize) { - FIXED_BLOCK_DIM(topp_sampling<<>>( - sorted_probs, sorted_id, out_val, out_id, top_ps, threshold, states, - p_num, vocab_size, need_batch_random, count_iter, count_iter_begin)); - default: - PD_THROW("the input data shape has error in the topp_sampling kernel."); + FIXED_BLOCK_DIM(topp_sampling + <<>>(sorted_probs, + sorted_id, + out_val, + out_id, + top_ps, + threshold, + states, + p_num, + vocab_size, + need_batch_random, + count_iter, + count_iter_begin)); + default: + PD_THROW("the input data shape has error in the topp_sampling kernel."); } } } -__global__ void air_topp_setup_kernel(curandState_t *state, int64_t *seed, +__global__ void air_topp_setup_kernel(curandState_t *state, + int64_t *seed, const int bs) { int idx = blockIdx.x * blockDim.x + threadIdx.x; for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { @@ -1204,8 +1332,10 @@ __global__ void air_topp_setup_kernel(curandState_t *state, int64_t *seed, } } -__global__ void air_topp_setup_kernel(curandState_t *state, const uint64_t seed, - const uint64_t offset, const int bs, +__global__ void air_topp_setup_kernel(curandState_t *state, + const uint64_t seed, + const uint64_t offset, + const int bs, const bool need_batch_random) { int idx = blockIdx.x * blockDim.x + threadIdx.x; for (int i = idx; i < bs; i += gridDim.x * blockDim.x) { @@ -1217,7 +1347,8 @@ __global__ void air_topp_setup_kernel(curandState_t *state, const uint64_t seed, } } -template __global__ void print_kernel(T *input, int size) { +template +__global__ void print_kernel(T *input, int size) { printf("["); for (int i = 0; i < size; i++) { if (i != size - 1) { @@ -1229,11 +1360,14 @@ template __global__ void print_kernel(T *input, int size) { } template -std::vector -LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps, - const paddle::optional &threshold, - const paddle::optional &topp_seed, int seed, - int k, const std::string &mode) { +std::vector LaunchTopPSampling( + const paddle::Tensor &x, + const paddle::Tensor &ps, + const paddle::optional &threshold, + const paddle::optional &topp_seed, + int seed, + int k, + const std::string &mode) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -1259,8 +1393,8 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps, switch (BlockSize) { FIXED_BLOCK_DIM(FillIndex<<>>( inds_input.data(), bs, vocab_size)); - default: - PD_THROW("the input data shape has error in the FillIndex kernel."); + default: + PD_THROW("the input data shape has error in the FillIndex kernel."); } int64_t *infer_seed = topp_seed ? const_cast(topp_seed.get().data()) @@ -1270,7 +1404,8 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps, phi::Allocator::AllocationPtr curand_states_buf{nullptr}; curand_states_buf = phi::memory_utils::Alloc( - x.place(), bs * sizeof(curandState_t), + x.place(), + bs * sizeof(curandState_t), phi::Stream(reinterpret_cast(stream))); states = reinterpret_cast(curand_states_buf->ptr()); @@ -1290,11 +1425,11 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps, auto seed_offset = gen_cuda->IncrementOffset(increment); seed_now = seed_offset.first; offset = seed_offset.second; - air_topp_setup_kernel<<<1, 256, 0, stream>>>(states, seed_now, offset, bs, - need_batch_random); + air_topp_setup_kernel<<<1, 256, 0, stream>>>( + states, seed_now, offset, bs, need_batch_random); } else { - air_topp_setup_kernel<<<1, 256, 0, stream>>>(states, seed_now, offset, bs, - need_batch_random); + air_topp_setup_kernel<<<1, 256, 0, stream>>>( + states, seed_now, offset, bs, need_batch_random); } } @@ -1313,13 +1448,21 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps, DispatchKeMatrixTopPBeamTopK( reinterpret_cast(x.data()), - reinterpret_cast(threshold_data), states, - reinterpret_cast(ps_now.data()), ids.data(), + reinterpret_cast(threshold_data), + states, + reinterpret_cast(ps_now.data()), + ids.data(), reinterpret_cast(out.data()), topk_ids.data(), - reinterpret_cast(topk_scores.data()), vocab_size, - count_iter.data(), count_iter_begin.data(), k, bs, - need_batch_random, mode, stream); + reinterpret_cast(topk_scores.data()), + vocab_size, + count_iter.data(), + count_iter_begin.data(), + k, + bs, + need_batch_random, + mode, + stream); static_assert(std::is_same::value, "air_topp only supports float now!"); @@ -1328,7 +1471,8 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps, constexpr int INIT_BLOCK_SIZE = 1024; phi::Allocator::AllocationPtr counter_ptr{nullptr}; counter_ptr = phi::memory_utils::Alloc( - x.place(), bs * sizeof(Counter), + x.place(), + bs * sizeof(Counter), phi::Stream(reinterpret_cast(stream))); Counter *counters = reinterpret_cast *>(counter_ptr->ptr()); @@ -1346,67 +1490,93 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps, paddle::empty({bs, buf_len}, paddle::DataType::INT32, x.place()); air_topp_init<<>>( - counters, reinterpret_cast(histograms.data()), + counters, + reinterpret_cast(histograms.data()), count_histograms.data(), reinterpret_cast(x.data()), - reinterpret_cast(ps.data()), states, bs, - vocab_size, buf_len, numBuckets); + reinterpret_cast(ps.data()), + states, + bs, + vocab_size, + buf_len, + numBuckets); constexpr int VecSize = 16 / sizeof(data_t); // TODO: good block_num const int max_block_num_vocab = ceilDiv(vocab_size, SAMPLING_BLOCK_SIZE * VecSize); - auto kernel = air_topp_sampling; + auto kernel = air_topp_sampling; const int dev_id = 0; int sm_count; int act_blocks_per_sm; cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&act_blocks_per_sm, kernel, - SAMPLING_BLOCK_SIZE, 0); + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &act_blocks_per_sm, kernel, SAMPLING_BLOCK_SIZE, 0); assert(act_blocks_per_sm > 1); const int block_per_wave = sm_count * act_blocks_per_sm; const int block_num_vocab = - std::min(max_block_num_vocab, block_per_wave * 4 / bs); // !!! + std::min(max_block_num_vocab, block_per_wave * 4 / bs); // !!! dim3 grid(block_num_vocab, bs); constexpr int numPasses = calcNumPasses(); for (int pass = 0; pass < numPasses; ++pass) { if (pass == 0) { - air_topp_sampling<<>>( - counters, reinterpret_cast(histograms.data()), + counters, + reinterpret_cast(histograms.data()), count_histograms.data(), reinterpret_cast(out.data()), ids.data(), reinterpret_cast(buf1.data()), id_buf1.data(), reinterpret_cast(buf2.data()), - id_buf2.data(), count_iter.data(), - count_iter_begin.data(), buf_len); + id_buf2.data(), + count_iter.data(), + count_iter_begin.data(), + buf_len); } else if (pass == 1) { - air_topp_sampling<<>>( - counters, reinterpret_cast(histograms.data()), + counters, + reinterpret_cast(histograms.data()), count_histograms.data(), reinterpret_cast(out.data()), ids.data(), reinterpret_cast(buf1.data()), id_buf1.data(), reinterpret_cast(buf2.data()), - id_buf2.data(), count_iter.data(), - count_iter_begin.data(), buf_len); + id_buf2.data(), + count_iter.data(), + count_iter_begin.data(), + buf_len); } else if (pass == 2) { - air_topp_sampling<<>>( - counters, reinterpret_cast(histograms.data()), + counters, + reinterpret_cast(histograms.data()), count_histograms.data(), reinterpret_cast(out.data()), ids.data(), reinterpret_cast(buf1.data()), id_buf1.data(), reinterpret_cast(buf2.data()), - id_buf2.data(), count_iter.data(), - count_iter_begin.data(), buf_len); + id_buf2.data(), + count_iter.data(), + count_iter_begin.data(), + buf_len); } else { PD_THROW("pass must be 0,1 or 2!"); } @@ -1414,52 +1584,60 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps, return {out, ids}; } -std::vector -TopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps, - const paddle::optional &threshold, - const paddle::optional &topp_seed, int seed, int k, - const std::string &mode) { +std::vector TopPSampling( + const paddle::Tensor &x, + const paddle::Tensor &ps, + const paddle::optional &threshold, + const paddle::optional &topp_seed, + int seed, + int k, + const std::string &mode) { switch (x.type()) { - case paddle::DataType::FLOAT32: { - return LaunchTopPSampling( - x, ps, threshold, topp_seed, seed, k, mode); - } - // case paddle::DataType::BFLOAT16: { - // return LaunchTopPSampling(x, ps, threshold, - // topp_seed, seed, k, mode); - // } - // case paddle::DataType::FLOAT16: { - // return LaunchTopPSampling(x, ps, threshold, - // topp_seed, seed, k, mode); - // } - default: { - PD_THROW("NOT supported data type. Only support float. "); - break; - } + case paddle::DataType::FLOAT32: { + return LaunchTopPSampling( + x, ps, threshold, topp_seed, seed, k, mode); + } + // case paddle::DataType::BFLOAT16: { + // return LaunchTopPSampling(x, ps, + // threshold, topp_seed, seed, k, mode); + // } + // case paddle::DataType::FLOAT16: { + // return LaunchTopPSampling(x, ps, + // threshold, topp_seed, seed, k, mode); + // } + default: { + PD_THROW("NOT supported data type. Only support float. "); + break; + } } } std::vector> GetTopPSamplingShape( - const std::vector &x_shape, const std::vector &ps_shape, + const std::vector &x_shape, + const std::vector &ps_shape, const paddle::optional> &threshold_shape, - const paddle::optional> &topp_seed_shape, int seed, + const paddle::optional> &topp_seed_shape, + int seed, int k) { int bs = x_shape[0]; int vocab_size = x_shape[1]; return {{bs, 1}, {bs, 1}}; } -std::vector -GetTopPSamplingDtype(const paddle::DataType &x_dytpe, - const paddle::DataType &ps_dtype, - const paddle::optional &threshold_dtype, - const paddle::optional &topp_seed_dtype, - int seed, int k) { +std::vector GetTopPSamplingDtype( + const paddle::DataType &x_dytpe, + const paddle::DataType &ps_dtype, + const paddle::optional &threshold_dtype, + const paddle::optional &topp_seed_dtype, + int seed, + int k) { return {x_dytpe, paddle::DataType::INT64}; } PD_BUILD_STATIC_OP(air_top_p_sampling) - .Inputs({"x", "ps", paddle::Optional("threshold"), + .Inputs({"x", + "ps", + paddle::Optional("threshold"), paddle::Optional("topp_seed")}) .Outputs({"out", "ids"}) .Attrs({"seed: int", "k: int", "mode: std::string"}) diff --git a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu index bcafd45d0..e5ba39a1f 100644 --- a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu +++ b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu @@ -18,7 +18,7 @@ template __forceinline__ __device__ T -CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) { +CudaShuffleDownSync(unsigned mask, T val, int delta, int width = WARP_SIZE) { return __shfl_down_sync(mask, val, static_cast(delta), width); }