mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Metax] modify warpSize to WARP_SIZE (#5442)
This commit is contained in:
@@ -179,11 +179,11 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
|
||||
const int num_rows_per_block,
|
||||
const int group_size) {
|
||||
// one block one warp
|
||||
const int lane_id = threadIdx.x % warpSize;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
int prev_offset = 0;
|
||||
|
||||
// loop on warp tile:[base, base+32)
|
||||
for (int base = 0; base < bsz; base += warpSize) {
|
||||
for (int base = 0; base < bsz; base += WARP_SIZE) {
|
||||
const int bid = base + lane_id;
|
||||
|
||||
// calculate loop_times for bid
|
||||
@@ -199,13 +199,13 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
|
||||
// prefix sum for each lane, get the start offset in this tile
|
||||
// inclusive scan
|
||||
int x = loop_times;
|
||||
for (int offset = 1; offset < warpSize; offset <<= 1) {
|
||||
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
|
||||
int y = __shfl_up_sync(0xffffffff, x, offset);
|
||||
if (lane_id >= offset) x += y;
|
||||
}
|
||||
// exclusive prefix sum
|
||||
int bid_offset = x - loop_times;
|
||||
int tile_sum = __shfl_sync(0xffffffff, x, warpSize - 1);
|
||||
int tile_sum = __shfl_sync(0xffffffff, x, WARP_SIZE - 1);
|
||||
|
||||
// write batch_ids and tile_ids_per_batch
|
||||
if (bid < bsz && loop_times > 0) {
|
||||
|
||||
@@ -34,16 +34,16 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
|
||||
int cum_seq_len = 0;
|
||||
|
||||
// compute sum of seq_lens[0,1,2,...,bi]
|
||||
for (int i = lane_id; i < bi + 1; i += warpSize) {
|
||||
for (int i = lane_id; i < bi + 1; i += WARP_SIZE) {
|
||||
cum_seq_len += seq_lens[i];
|
||||
}
|
||||
|
||||
for (int offset = 1; offset < warpSize; offset <<= 1) {
|
||||
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
|
||||
const int tmp = __shfl_up_sync(0xffffffff, cum_seq_len, offset);
|
||||
if (lane_id >= offset) cum_seq_len += tmp;
|
||||
}
|
||||
|
||||
cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, warpSize - 1);
|
||||
cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, WARP_SIZE - 1);
|
||||
|
||||
if (tid == 0) {
|
||||
cu_seqlens_q[bi + 1] = cum_seq_len;
|
||||
|
||||
@@ -30,8 +30,8 @@
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cuda/atomic>
|
||||
#include <curand_kernel.h>
|
||||
#include <cuda/atomic>
|
||||
|
||||
#include "helper.h"
|
||||
#include "paddle/phi/backends/context_pool.h"
|
||||
@@ -40,20 +40,21 @@
|
||||
|
||||
#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
|
||||
|
||||
#define WARP_SIZE 32
|
||||
#define FINAL_MASK 0xFFFFFFFF
|
||||
|
||||
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
|
||||
case (dim): { \
|
||||
constexpr auto kBlockDim = (dim); \
|
||||
__VA_ARGS__; \
|
||||
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
|
||||
case (dim): { \
|
||||
constexpr auto kBlockDim = (dim); \
|
||||
__VA_ARGS__; \
|
||||
} break
|
||||
|
||||
#define FIXED_BLOCK_DIM(...) \
|
||||
FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
|
||||
FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
|
||||
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
|
||||
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
|
||||
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
|
||||
#define FIXED_BLOCK_DIM(...) \
|
||||
FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
|
||||
FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
|
||||
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
|
||||
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
|
||||
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
|
||||
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
|
||||
|
||||
template <typename T, typename IdxT = int, typename AccT = T>
|
||||
@@ -123,7 +124,8 @@ __device__ T twiddleOut(typename cub::Traits<T>::UnsignedBits bits,
|
||||
return reinterpret_cast<T &>(bits);
|
||||
}
|
||||
|
||||
template <int BitsPerPass> __host__ __device__ constexpr int calcNumBuckets() {
|
||||
template <int BitsPerPass>
|
||||
__host__ __device__ constexpr int calcNumBuckets() {
|
||||
return 1 << BitsPerPass;
|
||||
}
|
||||
|
||||
@@ -161,12 +163,12 @@ __device__ void scan(IdxT volatile *histogram, IdxT *histogramOut) {
|
||||
if constexpr (numBuckets >= BlockSize) {
|
||||
static_assert(numBuckets % BlockSize == 0);
|
||||
int constexpr itemsPerThread = numBuckets / BlockSize;
|
||||
typedef cub::BlockLoad<IdxT, BlockSize, itemsPerThread,
|
||||
cub::BLOCK_LOAD_TRANSPOSE>
|
||||
BlockLoad;
|
||||
typedef cub::BlockStore<IdxT, BlockSize, itemsPerThread,
|
||||
cub::BLOCK_STORE_TRANSPOSE>
|
||||
BlockStore;
|
||||
typedef cub::
|
||||
BlockLoad<IdxT, BlockSize, itemsPerThread, cub::BLOCK_LOAD_TRANSPOSE>
|
||||
BlockLoad;
|
||||
typedef cub::
|
||||
BlockStore<IdxT, BlockSize, itemsPerThread, cub::BLOCK_STORE_TRANSPOSE>
|
||||
BlockStore;
|
||||
typedef cub::BlockScan<IdxT, BlockSize> BlockScan;
|
||||
|
||||
__shared__ union {
|
||||
@@ -203,12 +205,19 @@ __device__ void scan(IdxT volatile *histogram, IdxT *histogramOut) {
|
||||
}
|
||||
|
||||
template <typename T, int BitsPerPass, int NumBuckets, int Pass>
|
||||
__device__ __forceinline__ void
|
||||
filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer,
|
||||
int *out_idx_buffer, T *out_scores, int64_t *out_ids,
|
||||
int previous_len, Counter<T> *counter, T *histogram,
|
||||
int *count_histogram, T *histogram_shm,
|
||||
int *count_histogram_shm, const bool early_stop) {
|
||||
__device__ __forceinline__ void filterAndHistogram(const T *in_buffer,
|
||||
const int *in_idx_buffer,
|
||||
T *out_buffer,
|
||||
int *out_idx_buffer,
|
||||
T *out_scores,
|
||||
int64_t *out_ids,
|
||||
int previous_len,
|
||||
Counter<T> *counter,
|
||||
T *histogram,
|
||||
int *count_histogram,
|
||||
T *histogram_shm,
|
||||
int *count_histogram_shm,
|
||||
const bool early_stop) {
|
||||
// scan and filter
|
||||
constexpr int start_bit = calcStartBit<T, BitsPerPass, Pass>();
|
||||
const uint32_t mask = calcMask<T, BitsPerPass, Pass>();
|
||||
@@ -220,7 +229,8 @@ filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer,
|
||||
T array[VecSize];
|
||||
} vec;
|
||||
for (int i = (blockIdx.x * blockDim.x + threadIdx.x);
|
||||
i < ceilDiv(previous_len, VecSize); i += blockDim.x * gridDim.x) {
|
||||
i < ceilDiv(previous_len, VecSize);
|
||||
i += blockDim.x * gridDim.x) {
|
||||
vec.v = reinterpret_cast<const VecT *>(in_buffer)[i];
|
||||
if constexpr (Pass == 0) {
|
||||
#pragma unroll
|
||||
@@ -254,8 +264,8 @@ filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer,
|
||||
out_buffer[pos] = vec.array[j];
|
||||
out_idx_buffer[pos] = in_idx_buffer ? in_idx_buffer[idx] : idx;
|
||||
}
|
||||
int bucket = calcBucket<T, BitsPerPass>(vec.array[j], start_bit,
|
||||
mask, false);
|
||||
int bucket = calcBucket<T, BitsPerPass>(
|
||||
vec.array[j], start_bit, mask, false);
|
||||
atomicAdd(histogram_shm + bucket, vec.array[j]);
|
||||
atomicAdd(count_histogram_shm + bucket, 1);
|
||||
}
|
||||
@@ -276,12 +286,18 @@ filterAndHistogram(const T *in_buffer, const int *in_idx_buffer, T *out_buffer,
|
||||
}
|
||||
|
||||
template <typename T, int BitsPerPass, int BlockSize, int NumBuckets, int Pass>
|
||||
__global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
|
||||
int *count_histograms, T *out, int64_t *ids,
|
||||
T *buf1, int *idx_buf1, T *buf2,
|
||||
int *idx_buf2, int *count_iter,
|
||||
int *count_iter_begin, const int buf_len) {
|
||||
|
||||
__global__ void air_topp_sampling(Counter<T> *counters,
|
||||
T *histograms,
|
||||
int *count_histograms,
|
||||
T *out,
|
||||
int64_t *ids,
|
||||
T *buf1,
|
||||
int *idx_buf1,
|
||||
T *buf2,
|
||||
int *idx_buf2,
|
||||
int *count_iter,
|
||||
int *count_iter_begin,
|
||||
const int buf_len) {
|
||||
/***
|
||||
* calc - filter - scan -find
|
||||
* TODO: calc - scan - find - filter
|
||||
@@ -352,10 +368,19 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
filterAndHistogram<T, BitsPerPass, NumBuckets, Pass>(
|
||||
in_buf, in_idx_buf, out_buf, out_idx_buf, out, ids, previous_len, counter,
|
||||
histogram, count_histogram, histogram_shm, count_histogram_shm,
|
||||
early_stop);
|
||||
filterAndHistogram<T, BitsPerPass, NumBuckets, Pass>(in_buf,
|
||||
in_idx_buf,
|
||||
out_buf,
|
||||
out_idx_buf,
|
||||
out,
|
||||
ids,
|
||||
previous_len,
|
||||
counter,
|
||||
histogram,
|
||||
count_histogram,
|
||||
histogram_shm,
|
||||
count_histogram_shm,
|
||||
early_stop);
|
||||
__syncthreads();
|
||||
__threadfence();
|
||||
|
||||
@@ -391,16 +416,16 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
|
||||
__syncthreads();
|
||||
// Acquire the summation of each 32 buckets
|
||||
for (int i = threadIdx.x; i < NumBuckets; i += BlockSize) {
|
||||
reduce_store_async(warp, warpSum + i / WARP_SIZE, histogram[i],
|
||||
cg::plus<float>{});
|
||||
reduce_store_async(
|
||||
warp, warpSum + i / WARP_SIZE, histogram[i], cg::plus<float>{});
|
||||
}
|
||||
__syncthreads();
|
||||
// Acquire the summation of all the 2048 buckets
|
||||
if (threadIdx.x < WARP_SIZE) {
|
||||
reduce_store_async(warp, blockSum, warpSum[threadIdx.x],
|
||||
cg::plus<float>{});
|
||||
reduce_update_async(warp, blockSum, warpSum[threadIdx.x + WARP_SIZE],
|
||||
cg::plus<float>{});
|
||||
reduce_store_async(
|
||||
warp, blockSum, warpSum[threadIdx.x], cg::plus<float>{});
|
||||
reduce_update_async(
|
||||
warp, blockSum, warpSum[threadIdx.x + WARP_SIZE], cg::plus<float>{});
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -435,9 +460,9 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
|
||||
}
|
||||
}
|
||||
counter->sum =
|
||||
current_sum - prev; // how many values still are there to find
|
||||
counter->len = count_histogram[targetIdx]; // cur - prev; // number of
|
||||
// values in next pass
|
||||
current_sum - prev; // how many values still are there to find
|
||||
counter->len = count_histogram[targetIdx]; // cur - prev; // number of
|
||||
// values in next pass
|
||||
typename cub::Traits<T>::UnsignedBits bucket = targetIdx;
|
||||
int startBit = calcStartBit<T, BitsPerPass, Pass>();
|
||||
counter->kthValueBits |= bucket << startBit;
|
||||
@@ -473,10 +498,15 @@ __global__ void air_topp_sampling(Counter<T> *counters, T *histograms,
|
||||
}
|
||||
|
||||
template <typename T, int BitsPerPass>
|
||||
__global__ void air_topp_init(Counter<T> *counters, T *histograms,
|
||||
int *count_histograms, const T *in, const T *ps,
|
||||
curandState_t *curandstate, const int bsz,
|
||||
const int vocab_size, const int buf_len,
|
||||
__global__ void air_topp_init(Counter<T> *counters,
|
||||
T *histograms,
|
||||
int *count_histograms,
|
||||
const T *in,
|
||||
const T *ps,
|
||||
curandState_t *curandstate,
|
||||
const int bsz,
|
||||
const int vocab_size,
|
||||
const int buf_len,
|
||||
const int num_buckets) {
|
||||
const int bid = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
@@ -517,7 +547,8 @@ struct SegmentOffsetIter {
|
||||
int num_cols_;
|
||||
};
|
||||
|
||||
template <typename T> struct Pair {
|
||||
template <typename T>
|
||||
struct Pair {
|
||||
__device__ __forceinline__ Pair() {}
|
||||
__device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
|
||||
|
||||
@@ -557,7 +588,8 @@ template <typename T> struct Pair {
|
||||
inline int div_up(int a, int n) { return (a + n - 1) / n; }
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T> &p,
|
||||
__device__ __forceinline__ void AddTo(Pair<T> topk[],
|
||||
const Pair<T> &p,
|
||||
int beam_size) {
|
||||
for (int k = beam_size - 2; k >= 0; k--) {
|
||||
if (topk[k] < p) {
|
||||
@@ -571,8 +603,8 @@ __device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T> &p,
|
||||
}
|
||||
|
||||
template <typename T, int BlockSize>
|
||||
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T *src, int idx,
|
||||
int dim, int beam_size) {
|
||||
__device__ __forceinline__ void GetTopK(
|
||||
Pair<T> topk[], const T *src, int idx, int dim, int beam_size) {
|
||||
while (idx < dim) {
|
||||
if (topk[beam_size - 1] < src[idx]) {
|
||||
Pair<T> tmp(src[idx], idx);
|
||||
@@ -583,8 +615,11 @@ __device__ __forceinline__ void GetTopK(Pair<T> topk[], const T *src, int idx,
|
||||
}
|
||||
|
||||
template <typename T, int BlockSize>
|
||||
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T *src, int idx,
|
||||
int dim, const Pair<T> &max,
|
||||
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
|
||||
const T *src,
|
||||
int idx,
|
||||
int dim,
|
||||
const Pair<T> &max,
|
||||
int beam_size) {
|
||||
while (idx < dim) {
|
||||
if (topk[beam_size - 1] < src[idx]) {
|
||||
@@ -598,10 +633,15 @@ __device__ __forceinline__ void GetTopK(Pair<T> topk[], const T *src, int idx,
|
||||
}
|
||||
|
||||
template <typename T, int MaxLength, int BlockSize>
|
||||
__device__ __forceinline__ void
|
||||
ThreadGetTopK(Pair<T> topk[], int *beam, int beam_size, const T *src,
|
||||
bool *firstStep, bool *is_empty, Pair<T> *max, int dim,
|
||||
const int tid) {
|
||||
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
|
||||
int *beam,
|
||||
int beam_size,
|
||||
const T *src,
|
||||
bool *firstStep,
|
||||
bool *is_empty,
|
||||
Pair<T> *max,
|
||||
int dim,
|
||||
const int tid) {
|
||||
if (*beam > 0) {
|
||||
int length = (*beam) < beam_size ? *beam : beam_size;
|
||||
if (*firstStep) {
|
||||
@@ -616,22 +656,20 @@ ThreadGetTopK(Pair<T> topk[], int *beam, int beam_size, const T *src,
|
||||
}
|
||||
}
|
||||
if (!(*is_empty)) {
|
||||
GetTopK<T, BlockSize>(topk + MaxLength - *beam, src, tid, dim, *max,
|
||||
length);
|
||||
GetTopK<T, BlockSize>(
|
||||
topk + MaxLength - *beam, src, tid, dim, *max, length);
|
||||
}
|
||||
}
|
||||
|
||||
*max = topk[MaxLength - 1];
|
||||
if ((*max).id == -1)
|
||||
*is_empty = true;
|
||||
if ((*max).id == -1) *is_empty = true;
|
||||
*beam = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
|
||||
int delta,
|
||||
int width = warpSize) {
|
||||
__forceinline__ __device__ T
|
||||
CudaShuffleDownSync(unsigned mask, T val, int delta, int width = WARP_SIZE) {
|
||||
return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
|
||||
}
|
||||
|
||||
@@ -650,9 +688,15 @@ __forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
|
||||
}
|
||||
|
||||
template <typename T, int MaxLength, int BlockSize>
|
||||
__device__ __forceinline__ void
|
||||
BlockReduce(Pair<T> shared_max[], Pair<T> topk[], Pair<T> beam_max[], int *beam,
|
||||
int *k, int *count, const int tid, const int wid, const int lane) {
|
||||
__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
|
||||
Pair<T> topk[],
|
||||
Pair<T> beam_max[],
|
||||
int *beam,
|
||||
int *k,
|
||||
int *count,
|
||||
const int tid,
|
||||
const int wid,
|
||||
const int lane) {
|
||||
while (true) {
|
||||
__syncthreads();
|
||||
Pair<T> input_now = topk[0];
|
||||
@@ -667,8 +711,7 @@ BlockReduce(Pair<T> shared_max[], Pair<T> topk[], Pair<T> beam_max[], int *beam,
|
||||
: Pair<T>(std::numeric_limits<T>::min(), -1);
|
||||
if (wid == 0) {
|
||||
input_now = WarpReduce(input_now);
|
||||
if (lane == 0)
|
||||
shared_max[0] = input_now;
|
||||
if (lane == 0) shared_max[0] = input_now;
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) {
|
||||
@@ -679,8 +722,7 @@ BlockReduce(Pair<T> shared_max[], Pair<T> topk[], Pair<T> beam_max[], int *beam,
|
||||
if (tid == tid_max) {
|
||||
(*beam)++;
|
||||
}
|
||||
if (--(*k) == 0)
|
||||
break;
|
||||
if (--(*k) == 0) break;
|
||||
__syncthreads();
|
||||
|
||||
if (tid == tid_max) {
|
||||
@@ -690,8 +732,7 @@ BlockReduce(Pair<T> shared_max[], Pair<T> topk[], Pair<T> beam_max[], int *beam,
|
||||
}
|
||||
|
||||
if (MaxLength < 5) {
|
||||
if (*beam >= MaxLength)
|
||||
break;
|
||||
if (*beam >= MaxLength) break;
|
||||
} else {
|
||||
unsigned mask = 0u;
|
||||
mask = __ballot_sync(FINAL_MASK, true);
|
||||
@@ -721,13 +762,18 @@ __device__ inline T exponential_transform(T val, T lambda) {
|
||||
}
|
||||
|
||||
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
|
||||
__global__ void KeMatrixTopPBeamTopK(const T *src, const T *threshold,
|
||||
curandState_t *states, T *top_ps,
|
||||
int64_t *out_id, // topk id
|
||||
T *out_val, // topk val
|
||||
int64_t *topk_ids, T *topk_scores,
|
||||
int vocab_size, int *count_iter,
|
||||
int *count_iter_begin, const int k,
|
||||
__global__ void KeMatrixTopPBeamTopK(const T *src,
|
||||
const T *threshold,
|
||||
curandState_t *states,
|
||||
T *top_ps,
|
||||
int64_t *out_id, // topk id
|
||||
T *out_val, // topk val
|
||||
int64_t *topk_ids,
|
||||
T *topk_scores,
|
||||
int vocab_size,
|
||||
int *count_iter,
|
||||
int *count_iter_begin,
|
||||
const int k,
|
||||
const bool need_batch_random) {
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
@@ -761,11 +807,17 @@ __global__ void KeMatrixTopPBeamTopK(const T *src, const T *threshold,
|
||||
}
|
||||
|
||||
while (top_num) {
|
||||
ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, TopPBeamTopK,
|
||||
src + offset, &firststep, &is_empty,
|
||||
&max, vocab_size, tid);
|
||||
BlockReduce<T, MaxLength, BlockSize>(shared_max, topk, beam_max, &beam,
|
||||
&top_num, &count, tid, wid, lane);
|
||||
ThreadGetTopK<T, MaxLength, BlockSize>(topk,
|
||||
&beam,
|
||||
TopPBeamTopK,
|
||||
src + offset,
|
||||
&firststep,
|
||||
&is_empty,
|
||||
&max,
|
||||
vocab_size,
|
||||
tid);
|
||||
BlockReduce<T, MaxLength, BlockSize>(
|
||||
shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
|
||||
}
|
||||
if (tid == 0) {
|
||||
// printf("offset: %d\n", (int)seed_offset);
|
||||
@@ -817,13 +869,18 @@ __global__ void KeMatrixTopPBeamTopK(const T *src, const T *threshold,
|
||||
}
|
||||
|
||||
template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
|
||||
__global__ void KeMatrixTopPBeamTopKFt(const T *src, const T *threshold,
|
||||
curandState_t *states, T *top_ps,
|
||||
int64_t *out_id, // topk id
|
||||
T *out_val, // topk val
|
||||
int64_t *topk_ids, T *topk_scores,
|
||||
int vocab_size, int *count_iter,
|
||||
int *count_iter_begin, const int k,
|
||||
__global__ void KeMatrixTopPBeamTopKFt(const T *src,
|
||||
const T *threshold,
|
||||
curandState_t *states,
|
||||
T *top_ps,
|
||||
int64_t *out_id, // topk id
|
||||
T *out_val, // topk val
|
||||
int64_t *topk_ids,
|
||||
T *topk_scores,
|
||||
int vocab_size,
|
||||
int *count_iter,
|
||||
int *count_iter_begin,
|
||||
const int k,
|
||||
const bool need_batch_random) {
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
@@ -856,11 +913,17 @@ __global__ void KeMatrixTopPBeamTopKFt(const T *src, const T *threshold,
|
||||
}
|
||||
|
||||
while (top_num) {
|
||||
ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, TopPBeamTopK,
|
||||
src + bid * vocab_size, &firststep,
|
||||
&is_empty, &max, vocab_size, tid);
|
||||
BlockReduce<T, MaxLength, BlockSize>(shared_max, topk, beam_max, &beam,
|
||||
&top_num, &count, tid, wid, lane);
|
||||
ThreadGetTopK<T, MaxLength, BlockSize>(topk,
|
||||
&beam,
|
||||
TopPBeamTopK,
|
||||
src + bid * vocab_size,
|
||||
&firststep,
|
||||
&is_empty,
|
||||
&max,
|
||||
vocab_size,
|
||||
tid);
|
||||
BlockReduce<T, MaxLength, BlockSize>(
|
||||
shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
|
||||
}
|
||||
if (tid == 0) {
|
||||
count_iter_begin[bid] = count_iter[bid];
|
||||
@@ -925,14 +988,20 @@ __global__ void FillIndex(T *indices, T num_rows, T num_cols) {
|
||||
}
|
||||
|
||||
template <typename T, int TopKMaxLength, int TopPBeamTopK>
|
||||
void DispatchKeMatrixTopPBeamTopK(const T *src, const T *threshold,
|
||||
curandState_t *states, T *top_ps,
|
||||
int64_t *out_id, // topk id
|
||||
T *out_val, // topk val
|
||||
int64_t *topk_ids, T *topk_scores,
|
||||
int vocab_size, int *count_iter,
|
||||
int *count_iter_begin, const int k,
|
||||
const int bs, const bool need_batch_random,
|
||||
void DispatchKeMatrixTopPBeamTopK(const T *src,
|
||||
const T *threshold,
|
||||
curandState_t *states,
|
||||
T *top_ps,
|
||||
int64_t *out_id, // topk id
|
||||
T *out_val, // topk val
|
||||
int64_t *topk_ids,
|
||||
T *topk_scores,
|
||||
int vocab_size,
|
||||
int *count_iter,
|
||||
int *count_iter_begin,
|
||||
const int k,
|
||||
const int bs,
|
||||
const bool need_batch_random,
|
||||
const std::string &mode,
|
||||
cudaStream_t stream) {
|
||||
int BlockSize = GetBlockSize(vocab_size);
|
||||
@@ -940,23 +1009,43 @@ void DispatchKeMatrixTopPBeamTopK(const T *src, const T *threshold,
|
||||
switch (BlockSize) {
|
||||
FIXED_BLOCK_DIM(
|
||||
KeMatrixTopPBeamTopKFt<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
|
||||
<<<bs, kBlockDim, 0, stream>>>(
|
||||
src, threshold, states, top_ps, out_id, out_val, topk_ids,
|
||||
topk_scores, vocab_size, count_iter, count_iter_begin, k,
|
||||
need_batch_random));
|
||||
default:
|
||||
PD_THROW("the input data shape has error in the topp_beam_topk kernel.");
|
||||
<<<bs, kBlockDim, 0, stream>>>(src,
|
||||
threshold,
|
||||
states,
|
||||
top_ps,
|
||||
out_id,
|
||||
out_val,
|
||||
topk_ids,
|
||||
topk_scores,
|
||||
vocab_size,
|
||||
count_iter,
|
||||
count_iter_begin,
|
||||
k,
|
||||
need_batch_random));
|
||||
default:
|
||||
PD_THROW(
|
||||
"the input data shape has error in the topp_beam_topk kernel.");
|
||||
}
|
||||
} else {
|
||||
switch (BlockSize) {
|
||||
FIXED_BLOCK_DIM(
|
||||
KeMatrixTopPBeamTopK<T, TopKMaxLength, TopPBeamTopK, kBlockDim>
|
||||
<<<bs, kBlockDim, 0, stream>>>(
|
||||
src, threshold, states, top_ps, out_id, out_val, topk_ids,
|
||||
topk_scores, vocab_size, count_iter, count_iter_begin, k,
|
||||
need_batch_random));
|
||||
default:
|
||||
PD_THROW("the input data shape has error in the topp_beam_topk kernel.");
|
||||
<<<bs, kBlockDim, 0, stream>>>(src,
|
||||
threshold,
|
||||
states,
|
||||
top_ps,
|
||||
out_id,
|
||||
out_val,
|
||||
topk_ids,
|
||||
topk_scores,
|
||||
vocab_size,
|
||||
count_iter,
|
||||
count_iter_begin,
|
||||
k,
|
||||
need_batch_random));
|
||||
default:
|
||||
PD_THROW(
|
||||
"the input data shape has error in the topp_beam_topk kernel.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -978,11 +1067,17 @@ struct BlockPrefixCallbackOp {
|
||||
};
|
||||
|
||||
template <typename T, int BLOCK_SIZE>
|
||||
__global__ void topp_sampling(T *sorted_probs, int64_t *sorted_id, T *out_val,
|
||||
int64_t *out_id, const T *top_ps,
|
||||
const T *threshold, curandState_t *states,
|
||||
const int p_num, const int vocab_size,
|
||||
const bool need_batch_random, int *count_iter,
|
||||
__global__ void topp_sampling(T *sorted_probs,
|
||||
int64_t *sorted_id,
|
||||
T *out_val,
|
||||
int64_t *out_id,
|
||||
const T *top_ps,
|
||||
const T *threshold,
|
||||
curandState_t *states,
|
||||
const int p_num,
|
||||
const int vocab_size,
|
||||
const bool need_batch_random,
|
||||
int *count_iter,
|
||||
int *count_iter_begin) {
|
||||
__shared__ int stop_shared;
|
||||
const int tid = threadIdx.x;
|
||||
@@ -1063,11 +1158,17 @@ __global__ void topp_sampling(T *sorted_probs, int64_t *sorted_id, T *out_val,
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_SIZE>
|
||||
__global__ void topp_sampling_ft(T *sorted_probs, int64_t *sorted_id,
|
||||
T *out_val, int64_t *out_id, const T *top_ps,
|
||||
const T *threshold, curandState_t *states,
|
||||
const int p_num, const int vocab_size,
|
||||
const bool need_batch_random, int *count_iter,
|
||||
__global__ void topp_sampling_ft(T *sorted_probs,
|
||||
int64_t *sorted_id,
|
||||
T *out_val,
|
||||
int64_t *out_id,
|
||||
const T *top_ps,
|
||||
const T *threshold,
|
||||
curandState_t *states,
|
||||
const int p_num,
|
||||
const int vocab_size,
|
||||
const bool need_batch_random,
|
||||
int *count_iter,
|
||||
int *count_iter_begin) {
|
||||
__shared__ int stop_shared;
|
||||
__shared__ float rand_p;
|
||||
@@ -1146,7 +1247,7 @@ __global__ void topp_sampling_ft(T *sorted_probs, int64_t *sorted_id,
|
||||
}
|
||||
}
|
||||
if (!skip) {
|
||||
int active_lane_id = 32 - __popc(selected_shared[warp_id]); // first not 0
|
||||
int active_lane_id = 32 - __popc(selected_shared[warp_id]); // first not 0
|
||||
if (lane_id == active_lane_id) {
|
||||
float val = static_cast<float>(sorted_probs[offset + i_activate]);
|
||||
if (val < threshold_now) {
|
||||
@@ -1167,36 +1268,63 @@ __global__ void topp_sampling_ft(T *sorted_probs, int64_t *sorted_id,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DispatchTopPSampling(T *sorted_probs, int64_t *sorted_id, T *out_val,
|
||||
int64_t *out_id, const T *top_ps, const T *threshold,
|
||||
curandState_t *states, const int p_num,
|
||||
const int vocab_size, const int bs,
|
||||
const bool need_batch_random, int *count_iter,
|
||||
int *count_iter_begin, const std::string &mode,
|
||||
void DispatchTopPSampling(T *sorted_probs,
|
||||
int64_t *sorted_id,
|
||||
T *out_val,
|
||||
int64_t *out_id,
|
||||
const T *top_ps,
|
||||
const T *threshold,
|
||||
curandState_t *states,
|
||||
const int p_num,
|
||||
const int vocab_size,
|
||||
const int bs,
|
||||
const bool need_batch_random,
|
||||
int *count_iter,
|
||||
int *count_iter_begin,
|
||||
const std::string &mode,
|
||||
cudaStream_t stream) {
|
||||
int BlockSize = GetBlockSize(vocab_size);
|
||||
if (mode == "truncated") {
|
||||
switch (BlockSize) {
|
||||
FIXED_BLOCK_DIM(topp_sampling_ft<T, kBlockDim>
|
||||
<<<bs, kBlockDim, 0, stream>>>(
|
||||
sorted_probs, sorted_id, out_val, out_id, top_ps,
|
||||
threshold, states, p_num, vocab_size,
|
||||
need_batch_random, count_iter, count_iter_begin));
|
||||
default:
|
||||
PD_THROW("the input data shape has error in the topp_sampling kernel.");
|
||||
<<<bs, kBlockDim, 0, stream>>>(sorted_probs,
|
||||
sorted_id,
|
||||
out_val,
|
||||
out_id,
|
||||
top_ps,
|
||||
threshold,
|
||||
states,
|
||||
p_num,
|
||||
vocab_size,
|
||||
need_batch_random,
|
||||
count_iter,
|
||||
count_iter_begin));
|
||||
default:
|
||||
PD_THROW("the input data shape has error in the topp_sampling kernel.");
|
||||
}
|
||||
} else {
|
||||
switch (BlockSize) {
|
||||
FIXED_BLOCK_DIM(topp_sampling<T, kBlockDim><<<bs, kBlockDim, 0, stream>>>(
|
||||
sorted_probs, sorted_id, out_val, out_id, top_ps, threshold, states,
|
||||
p_num, vocab_size, need_batch_random, count_iter, count_iter_begin));
|
||||
default:
|
||||
PD_THROW("the input data shape has error in the topp_sampling kernel.");
|
||||
FIXED_BLOCK_DIM(topp_sampling<T, kBlockDim>
|
||||
<<<bs, kBlockDim, 0, stream>>>(sorted_probs,
|
||||
sorted_id,
|
||||
out_val,
|
||||
out_id,
|
||||
top_ps,
|
||||
threshold,
|
||||
states,
|
||||
p_num,
|
||||
vocab_size,
|
||||
need_batch_random,
|
||||
count_iter,
|
||||
count_iter_begin));
|
||||
default:
|
||||
PD_THROW("the input data shape has error in the topp_sampling kernel.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void air_topp_setup_kernel(curandState_t *state, int64_t *seed,
|
||||
__global__ void air_topp_setup_kernel(curandState_t *state,
|
||||
int64_t *seed,
|
||||
const int bs) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
|
||||
@@ -1204,8 +1332,10 @@ __global__ void air_topp_setup_kernel(curandState_t *state, int64_t *seed,
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void air_topp_setup_kernel(curandState_t *state, const uint64_t seed,
|
||||
const uint64_t offset, const int bs,
|
||||
__global__ void air_topp_setup_kernel(curandState_t *state,
|
||||
const uint64_t seed,
|
||||
const uint64_t offset,
|
||||
const int bs,
|
||||
const bool need_batch_random) {
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int i = idx; i < bs; i += gridDim.x * blockDim.x) {
|
||||
@@ -1217,7 +1347,8 @@ __global__ void air_topp_setup_kernel(curandState_t *state, const uint64_t seed,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T> __global__ void print_kernel(T *input, int size) {
|
||||
template <typename T>
|
||||
__global__ void print_kernel(T *input, int size) {
|
||||
printf("[");
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (i != size - 1) {
|
||||
@@ -1229,11 +1360,14 @@ template <typename T> __global__ void print_kernel(T *input, int size) {
|
||||
}
|
||||
|
||||
template <paddle::DataType D>
|
||||
std::vector<paddle::Tensor>
|
||||
LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
|
||||
const paddle::optional<paddle::Tensor> &threshold,
|
||||
const paddle::optional<paddle::Tensor> &topp_seed, int seed,
|
||||
int k, const std::string &mode) {
|
||||
std::vector<paddle::Tensor> LaunchTopPSampling(
|
||||
const paddle::Tensor &x,
|
||||
const paddle::Tensor &ps,
|
||||
const paddle::optional<paddle::Tensor> &threshold,
|
||||
const paddle::optional<paddle::Tensor> &topp_seed,
|
||||
int seed,
|
||||
int k,
|
||||
const std::string &mode) {
|
||||
typedef PDTraits<D> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -1259,8 +1393,8 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
|
||||
switch (BlockSize) {
|
||||
FIXED_BLOCK_DIM(FillIndex<int64_t><<<bs, kBlockDim, 0, stream>>>(
|
||||
inds_input.data<int64_t>(), bs, vocab_size));
|
||||
default:
|
||||
PD_THROW("the input data shape has error in the FillIndex kernel.");
|
||||
default:
|
||||
PD_THROW("the input data shape has error in the FillIndex kernel.");
|
||||
}
|
||||
int64_t *infer_seed =
|
||||
topp_seed ? const_cast<int64_t *>(topp_seed.get().data<int64_t>())
|
||||
@@ -1270,7 +1404,8 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
|
||||
|
||||
phi::Allocator::AllocationPtr curand_states_buf{nullptr};
|
||||
curand_states_buf = phi::memory_utils::Alloc(
|
||||
x.place(), bs * sizeof(curandState_t),
|
||||
x.place(),
|
||||
bs * sizeof(curandState_t),
|
||||
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
|
||||
states = reinterpret_cast<curandState_t *>(curand_states_buf->ptr());
|
||||
|
||||
@@ -1290,11 +1425,11 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
|
||||
auto seed_offset = gen_cuda->IncrementOffset(increment);
|
||||
seed_now = seed_offset.first;
|
||||
offset = seed_offset.second;
|
||||
air_topp_setup_kernel<<<1, 256, 0, stream>>>(states, seed_now, offset, bs,
|
||||
need_batch_random);
|
||||
air_topp_setup_kernel<<<1, 256, 0, stream>>>(
|
||||
states, seed_now, offset, bs, need_batch_random);
|
||||
} else {
|
||||
air_topp_setup_kernel<<<1, 256, 0, stream>>>(states, seed_now, offset, bs,
|
||||
need_batch_random);
|
||||
air_topp_setup_kernel<<<1, 256, 0, stream>>>(
|
||||
states, seed_now, offset, bs, need_batch_random);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1313,13 +1448,21 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
|
||||
|
||||
DispatchKeMatrixTopPBeamTopK<DataType_, TopKMaxLength, TopPBeamTopK>(
|
||||
reinterpret_cast<const DataType_ *>(x.data<data_t>()),
|
||||
reinterpret_cast<const DataType_ *>(threshold_data), states,
|
||||
reinterpret_cast<DataType_ *>(ps_now.data<data_t>()), ids.data<int64_t>(),
|
||||
reinterpret_cast<const DataType_ *>(threshold_data),
|
||||
states,
|
||||
reinterpret_cast<DataType_ *>(ps_now.data<data_t>()),
|
||||
ids.data<int64_t>(),
|
||||
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
||||
topk_ids.data<int64_t>(),
|
||||
reinterpret_cast<DataType_ *>(topk_scores.data<data_t>()), vocab_size,
|
||||
count_iter.data<int>(), count_iter_begin.data<int>(), k, bs,
|
||||
need_batch_random, mode, stream);
|
||||
reinterpret_cast<DataType_ *>(topk_scores.data<data_t>()),
|
||||
vocab_size,
|
||||
count_iter.data<int>(),
|
||||
count_iter_begin.data<int>(),
|
||||
k,
|
||||
bs,
|
||||
need_batch_random,
|
||||
mode,
|
||||
stream);
|
||||
|
||||
static_assert(std::is_same<DataType_, float>::value,
|
||||
"air_topp only supports float now!");
|
||||
@@ -1328,7 +1471,8 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
|
||||
constexpr int INIT_BLOCK_SIZE = 1024;
|
||||
phi::Allocator::AllocationPtr counter_ptr{nullptr};
|
||||
counter_ptr = phi::memory_utils::Alloc(
|
||||
x.place(), bs * sizeof(Counter<DataType_>),
|
||||
x.place(),
|
||||
bs * sizeof(Counter<DataType_>),
|
||||
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
|
||||
Counter<DataType_> *counters =
|
||||
reinterpret_cast<Counter<DataType_> *>(counter_ptr->ptr());
|
||||
@@ -1346,67 +1490,93 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
|
||||
paddle::empty({bs, buf_len}, paddle::DataType::INT32, x.place());
|
||||
|
||||
air_topp_init<float, BitsPerPass><<<bs, INIT_BLOCK_SIZE, 0, stream>>>(
|
||||
counters, reinterpret_cast<float *>(histograms.data<data_t>()),
|
||||
counters,
|
||||
reinterpret_cast<float *>(histograms.data<data_t>()),
|
||||
count_histograms.data<int32_t>(),
|
||||
reinterpret_cast<const float *>(x.data<data_t>()),
|
||||
reinterpret_cast<const float *>(ps.data<data_t>()), states, bs,
|
||||
vocab_size, buf_len, numBuckets);
|
||||
reinterpret_cast<const float *>(ps.data<data_t>()),
|
||||
states,
|
||||
bs,
|
||||
vocab_size,
|
||||
buf_len,
|
||||
numBuckets);
|
||||
|
||||
constexpr int VecSize = 16 / sizeof(data_t);
|
||||
// TODO: good block_num
|
||||
const int max_block_num_vocab =
|
||||
ceilDiv(vocab_size, SAMPLING_BLOCK_SIZE * VecSize);
|
||||
auto kernel = air_topp_sampling<data_t, BitsPerPass, SAMPLING_BLOCK_SIZE,
|
||||
numBuckets, 0>;
|
||||
auto kernel = air_topp_sampling<data_t,
|
||||
BitsPerPass,
|
||||
SAMPLING_BLOCK_SIZE,
|
||||
numBuckets,
|
||||
0>;
|
||||
const int dev_id = 0;
|
||||
int sm_count;
|
||||
int act_blocks_per_sm;
|
||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&act_blocks_per_sm, kernel,
|
||||
SAMPLING_BLOCK_SIZE, 0);
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&act_blocks_per_sm, kernel, SAMPLING_BLOCK_SIZE, 0);
|
||||
assert(act_blocks_per_sm > 1);
|
||||
const int block_per_wave = sm_count * act_blocks_per_sm;
|
||||
const int block_num_vocab =
|
||||
std::min(max_block_num_vocab, block_per_wave * 4 / bs); // !!!
|
||||
std::min(max_block_num_vocab, block_per_wave * 4 / bs); // !!!
|
||||
dim3 grid(block_num_vocab, bs);
|
||||
constexpr int numPasses = calcNumPasses<data_t, BitsPerPass>();
|
||||
for (int pass = 0; pass < numPasses; ++pass) {
|
||||
if (pass == 0) {
|
||||
air_topp_sampling<DataType_, BitsPerPass, SAMPLING_BLOCK_SIZE, numBuckets,
|
||||
air_topp_sampling<DataType_,
|
||||
BitsPerPass,
|
||||
SAMPLING_BLOCK_SIZE,
|
||||
numBuckets,
|
||||
0><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(
|
||||
counters, reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
|
||||
counters,
|
||||
reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
|
||||
count_histograms.data<int>(),
|
||||
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
||||
ids.data<int64_t>(),
|
||||
reinterpret_cast<DataType_ *>(buf1.data<data_t>()),
|
||||
id_buf1.data<int>(),
|
||||
reinterpret_cast<DataType_ *>(buf2.data<data_t>()),
|
||||
id_buf2.data<int>(), count_iter.data<int>(),
|
||||
count_iter_begin.data<int>(), buf_len);
|
||||
id_buf2.data<int>(),
|
||||
count_iter.data<int>(),
|
||||
count_iter_begin.data<int>(),
|
||||
buf_len);
|
||||
} else if (pass == 1) {
|
||||
air_topp_sampling<DataType_, BitsPerPass, SAMPLING_BLOCK_SIZE, numBuckets,
|
||||
air_topp_sampling<DataType_,
|
||||
BitsPerPass,
|
||||
SAMPLING_BLOCK_SIZE,
|
||||
numBuckets,
|
||||
1><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(
|
||||
counters, reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
|
||||
counters,
|
||||
reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
|
||||
count_histograms.data<int>(),
|
||||
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
||||
ids.data<int64_t>(),
|
||||
reinterpret_cast<DataType_ *>(buf1.data<data_t>()),
|
||||
id_buf1.data<int>(),
|
||||
reinterpret_cast<DataType_ *>(buf2.data<data_t>()),
|
||||
id_buf2.data<int>(), count_iter.data<int>(),
|
||||
count_iter_begin.data<int>(), buf_len);
|
||||
id_buf2.data<int>(),
|
||||
count_iter.data<int>(),
|
||||
count_iter_begin.data<int>(),
|
||||
buf_len);
|
||||
} else if (pass == 2) {
|
||||
air_topp_sampling<DataType_, BitsPerPass, SAMPLING_BLOCK_SIZE, numBuckets,
|
||||
air_topp_sampling<DataType_,
|
||||
BitsPerPass,
|
||||
SAMPLING_BLOCK_SIZE,
|
||||
numBuckets,
|
||||
2><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(
|
||||
counters, reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
|
||||
counters,
|
||||
reinterpret_cast<DataType_ *>(histograms.data<data_t>()),
|
||||
count_histograms.data<int>(),
|
||||
reinterpret_cast<DataType_ *>(out.data<data_t>()),
|
||||
ids.data<int64_t>(),
|
||||
reinterpret_cast<DataType_ *>(buf1.data<data_t>()),
|
||||
id_buf1.data<int>(),
|
||||
reinterpret_cast<DataType_ *>(buf2.data<data_t>()),
|
||||
id_buf2.data<int>(), count_iter.data<int>(),
|
||||
count_iter_begin.data<int>(), buf_len);
|
||||
id_buf2.data<int>(),
|
||||
count_iter.data<int>(),
|
||||
count_iter_begin.data<int>(),
|
||||
buf_len);
|
||||
} else {
|
||||
PD_THROW("pass must be 0,1 or 2!");
|
||||
}
|
||||
@@ -1414,52 +1584,60 @@ LaunchTopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
|
||||
return {out, ids};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor>
|
||||
TopPSampling(const paddle::Tensor &x, const paddle::Tensor &ps,
|
||||
const paddle::optional<paddle::Tensor> &threshold,
|
||||
const paddle::optional<paddle::Tensor> &topp_seed, int seed, int k,
|
||||
const std::string &mode) {
|
||||
std::vector<paddle::Tensor> TopPSampling(
|
||||
const paddle::Tensor &x,
|
||||
const paddle::Tensor &ps,
|
||||
const paddle::optional<paddle::Tensor> &threshold,
|
||||
const paddle::optional<paddle::Tensor> &topp_seed,
|
||||
int seed,
|
||||
int k,
|
||||
const std::string &mode) {
|
||||
switch (x.type()) {
|
||||
case paddle::DataType::FLOAT32: {
|
||||
return LaunchTopPSampling<paddle::DataType::FLOAT32>(
|
||||
x, ps, threshold, topp_seed, seed, k, mode);
|
||||
}
|
||||
// case paddle::DataType::BFLOAT16: {
|
||||
// return LaunchTopPSampling<paddle::DataType::BFLOAT16>(x, ps, threshold,
|
||||
// topp_seed, seed, k, mode);
|
||||
// }
|
||||
// case paddle::DataType::FLOAT16: {
|
||||
// return LaunchTopPSampling<paddle::DataType::FLOAT16>(x, ps, threshold,
|
||||
// topp_seed, seed, k, mode);
|
||||
// }
|
||||
default: {
|
||||
PD_THROW("NOT supported data type. Only support float. ");
|
||||
break;
|
||||
}
|
||||
case paddle::DataType::FLOAT32: {
|
||||
return LaunchTopPSampling<paddle::DataType::FLOAT32>(
|
||||
x, ps, threshold, topp_seed, seed, k, mode);
|
||||
}
|
||||
// case paddle::DataType::BFLOAT16: {
|
||||
// return LaunchTopPSampling<paddle::DataType::BFLOAT16>(x, ps,
|
||||
// threshold, topp_seed, seed, k, mode);
|
||||
// }
|
||||
// case paddle::DataType::FLOAT16: {
|
||||
// return LaunchTopPSampling<paddle::DataType::FLOAT16>(x, ps,
|
||||
// threshold, topp_seed, seed, k, mode);
|
||||
// }
|
||||
default: {
|
||||
PD_THROW("NOT supported data type. Only support float. ");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> GetTopPSamplingShape(
|
||||
const std::vector<int64_t> &x_shape, const std::vector<int64_t> &ps_shape,
|
||||
const std::vector<int64_t> &x_shape,
|
||||
const std::vector<int64_t> &ps_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &threshold_shape,
|
||||
const paddle::optional<std::vector<int64_t>> &topp_seed_shape, int seed,
|
||||
const paddle::optional<std::vector<int64_t>> &topp_seed_shape,
|
||||
int seed,
|
||||
int k) {
|
||||
int bs = x_shape[0];
|
||||
int vocab_size = x_shape[1];
|
||||
return {{bs, 1}, {bs, 1}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType>
|
||||
GetTopPSamplingDtype(const paddle::DataType &x_dytpe,
|
||||
const paddle::DataType &ps_dtype,
|
||||
const paddle::optional<paddle::DataType> &threshold_dtype,
|
||||
const paddle::optional<paddle::DataType> &topp_seed_dtype,
|
||||
int seed, int k) {
|
||||
std::vector<paddle::DataType> GetTopPSamplingDtype(
|
||||
const paddle::DataType &x_dytpe,
|
||||
const paddle::DataType &ps_dtype,
|
||||
const paddle::optional<paddle::DataType> &threshold_dtype,
|
||||
const paddle::optional<paddle::DataType> &topp_seed_dtype,
|
||||
int seed,
|
||||
int k) {
|
||||
return {x_dytpe, paddle::DataType::INT64};
|
||||
}
|
||||
|
||||
PD_BUILD_STATIC_OP(air_top_p_sampling)
|
||||
.Inputs({"x", "ps", paddle::Optional("threshold"),
|
||||
.Inputs({"x",
|
||||
"ps",
|
||||
paddle::Optional("threshold"),
|
||||
paddle::Optional("topp_seed")})
|
||||
.Outputs({"out", "ids"})
|
||||
.Attrs({"seed: int", "k: int", "mode: std::string"})
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ T
|
||||
CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) {
|
||||
CudaShuffleDownSync(unsigned mask, T val, int delta, int width = WARP_SIZE) {
|
||||
return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user