fix top_p_candidates (#5400)
Co-authored-by: freeliuzc <lzc842650834@gmail.com>
@@ -19,113 +19,113 @@
template <typename T>
__forceinline__ __device__ T
CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) {
  return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
}

template <>
__forceinline__ __device__ phi::dtype::float16 CudaShuffleDownSync(
    unsigned mask, phi::dtype::float16 val, int delta, int width) {
  return paddle::float16(__shfl_down_sync(
      mask, val.to_half(), static_cast<unsigned>(delta), width));
}

template <>
__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync(
    unsigned mask, phi::dtype::bfloat16 val, int delta, int width) {
  return paddle::bfloat16(__shfl_down_sync(
      mask, val.to_nv_bfloat16(), static_cast<unsigned>(delta), width));
}

struct BlockPrefixCallbackOp {
  // Running prefix
  float running_total;
  // Constructor
  __device__ BlockPrefixCallbackOp(float running_total)
      : running_total(running_total) {}
  // Callback operator to be entered by the first warp of threads in the
  // block. Thread-0 is responsible for returning a value for seeding the
  // block-wide scan.
  __device__ float operator()(float block_aggregate) {
    float old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};
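For readers unfamiliar with the cub prefix-callback protocol, here is a minimal sketch (illustrative only, not part of this commit; it assumes `#include <cub/cub.cuh>` and the `BlockPrefixCallbackOp` above) of how one callback instance carries a running total across successive tiles of a row, which is how `top_p_candidates_kernel` accumulates probability mass further down.

// A minimal sketch: one block computes a running inclusive sum over n
// elements, BLOCK_SIZE at a time. cub calls prefix_op once per tile with
// that tile's aggregate and adds the returned prefix to every result.
#include <cub/cub.cuh>

template <int BLOCK_SIZE>
__global__ void RunningSumSketch(const float* in, float* out, int n) {
  typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
  __shared__ typename BlockScan::TempStorage temp_storage;
  BlockPrefixCallbackOp prefix_op(0);  // running total starts at zero
  // Round n up so every thread stays in the loop; BlockScan is collective.
  int end = ((n + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
  for (int i = threadIdx.x; i < end; i += BLOCK_SIZE) {
    float v = (i < n) ? in[i] : 0.f;
    float inclusive;
    BlockScan(temp_storage).InclusiveSum(v, inclusive, prefix_op);
    if (i < n) out[i] = inclusive;
    __syncthreads();  // temp_storage is reused by the next tile
  }
}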

#define FINAL_MASK 0xFFFFFFFF

#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

#define FIXED_BLOCK_DIM(...)                 \
  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

#define FIXED_TOPK_BASE(topk, ...) \
  case (topk): {                   \
    constexpr auto kTopK = topk;   \
    __VA_ARGS__;                   \
  } break

#define FIXED_TOPK(...)               \
  FIXED_TOPK_BASE(1, ##__VA_ARGS__);  \
  FIXED_TOPK_BASE(2, ##__VA_ARGS__);  \
  FIXED_TOPK_BASE(3, ##__VA_ARGS__);  \
  FIXED_TOPK_BASE(4, ##__VA_ARGS__);  \
  FIXED_TOPK_BASE(5, ##__VA_ARGS__);  \
  FIXED_TOPK_BASE(6, ##__VA_ARGS__);  \
  FIXED_TOPK_BASE(7, ##__VA_ARGS__);  \
  FIXED_TOPK_BASE(8, ##__VA_ARGS__);  \
  FIXED_TOPK_BASE(9, ##__VA_ARGS__);  \
  FIXED_TOPK_BASE(10, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(20, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(30, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(40, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(50, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(60, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(70, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(80, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(90, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(100, ##__VA_ARGS__);
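The two macro families compose into a nested switch: the outer switch picks a compile-time top-k from the runtime candidate length, the inner one picks a compile-time block size. A self-contained, single-level sketch of the same idiom (hypothetical names, not from this file) as used by DispatchTopK below:

// Hypothetical single-level version of the dispatch idiom: a runtime
// value selects a compile-time constant through switch cases.
#include <cstdio>

#define DISPATCH_DIM(dim, ...)  \
  case (dim): {                 \
    constexpr int kDim = (dim); \
    __VA_ARGS__;                \
  } break

template <int kDim>
void RunKernel() { std::printf("instantiated for dim=%d\n", kDim); }

void Dispatch(int dim) {
  switch (dim) {
    DISPATCH_DIM(32, RunKernel<kDim>());
    DISPATCH_DIM(64, RunKernel<kDim>());
    default:
      std::printf("unsupported dim %d\n", dim);
  }
}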

struct SegmentOffsetIter {
  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

  __host__ __device__ __forceinline__ int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

inline int div_up(int a, int n) { return (a + n - 1) / n; }

template <typename T>
__global__ void FillIndex(T* indices, T num_rows, T num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

  for (T j = row_id; j < num_rows; j += gridDim.x) {
    for (T i = col_id; i < num_cols; i += blockDim.x) {
      indices[j * num_cols + i] = i;
    }
  }
}

__global__ void SetCountIter(int* count_iter, int num) {
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int idx = bid * blockDim.x + tid;
  for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
    count_iter[i] = i;
  }
}

template <typename T, int BLOCK_SIZE>
@@ -137,148 +137,146 @@ __global__ void top_p_candidates_kernel(T* sorted_probs,
                                        const int vocab_size,
                                        const float topp,
                                        const int candidates_len) {
  __shared__ int stop_shared;
  __shared__ float rand_p;
  const int tid = threadIdx.x;
  const int bid = blockIdx.x;
  constexpr int NUM_WARPS = BLOCK_SIZE / 32;
  const int lane_id = tid % 32;
  const int warp_id = tid / 32;

  typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
  typedef cub::BlockReduce<int, BLOCK_SIZE> BlockReduce;
  __shared__ typename BlockScan::TempStorage temp_storage;
  __shared__ typename BlockReduce::TempStorage temp_storage_reduce;
  __shared__ uint32_t selected_shared[NUM_WARPS];

  if (lane_id == 0) {
    selected_shared[warp_id] = 0;
  }

  // Initialize running total
  BlockPrefixCallbackOp prefix_op(0);

  __syncthreads();

  int offset = bid * vocab_size;
  int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
  int i_activate = 0;
  float thread_offset = 0;
  for (int i = tid; i < end; i += BLOCK_SIZE) {
    float thread_count =
        (i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f;

    BlockScan(temp_storage)
        .InclusiveSum(thread_count, thread_offset, prefix_op);

    if (i < candidates_len) {
      out_id[bid * candidates_len + i] = sorted_id[offset + i];
      out_val[bid * candidates_len + i] = sorted_probs[offset + i];
    }

    uint32_t activate_mask = __ballot_sync(FINAL_MASK, topp <= thread_offset);
    i_activate = i;
    if (activate_mask != 0 || i >= candidates_len) {
      if (lane_id == 0) {
        atomicAdd(&stop_shared, 1);
        selected_shared[warp_id] = activate_mask;
      }
    }
    __syncthreads();
    if (stop_shared > 0) {
      break;
    }
  }
  __syncthreads();
  bool skip = (selected_shared[warp_id] > 0) ? false : true;
  for (int i = 0; i < warp_id; i++) {
    if (selected_shared[i] != 0) {
      // If the previous has stopped, skip the current warp
      skip = true;
    }
  }
  if (!skip) {
    int active_lane_id =
        WARP_SIZE - __popc(selected_shared[warp_id]);  // first not 0
    if (lane_id == active_lane_id) {
      actual_candidates_lens[bid] = i_activate + 1;
    }
  }
  __syncthreads();
  if (tid == 0) {
    // printf("actual_candidates_lens[%d] %d\n", bid,
    // actual_candidates_lens[bid]);
    if (actual_candidates_lens[bid] == 0) {
      actual_candidates_lens[bid] = candidates_len;
    }
  }
}
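As a plain-C++ reference for what the block just computed (a sketch under the assumption that `sorted_probs` holds one row of descending probabilities; this helper is not part of the kernel):

// Host-side reference: the candidate count is the smallest prefix whose
// cumulative probability reaches topp, capped at candidates_len, with
// candidates_len as the fallback when topp is never reached.
#include <vector>

int TopPCandidateLenRef(const std::vector<float>& sorted_probs,
                        float topp,
                        int candidates_len) {
  float cum = 0.f;
  int n = static_cast<int>(sorted_probs.size());
  for (int i = 0; i < n && i < candidates_len; ++i) {
    cum += sorted_probs[i];
    if (topp <= cum) return i + 1;  // same test as the __ballot_sync above
  }
  return candidates_len;  // mirrors the tid == 0 fallback in the kernel
}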

template <typename T>
struct Pair {
  __device__ __forceinline__ Pair() {}
  __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}

  __device__ __forceinline__ void set(T value, int id) {
    this->v = value;
    this->id = id;
  }

  __device__ __forceinline__ void operator=(const Pair<T>& in) {
    v = in.v;
    id = in.id;
  }

  __device__ __forceinline__ bool operator<(const T value) const {
    return (static_cast<float>(v) < static_cast<float>(value));
  }

  __device__ __forceinline__ bool operator>(const T value) const {
    return (static_cast<float>(v) > static_cast<float>(value));
  }

  __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
    return (static_cast<float>(v) < static_cast<float>(in.v)) ||
           ((static_cast<float>(v) == static_cast<float>(in.v)) &&
            (id > in.id));
  }

  __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
    return (static_cast<float>(v) > static_cast<float>(in.v)) ||
           ((static_cast<float>(v) == static_cast<float>(in.v)) &&
            (id < in.id));
  }

  T v;
  int id;
};

template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[],
                                      const Pair<T>& p,
                                      int beam_size) {
  for (int k = beam_size - 2; k >= 0; k--) {
    if (topk[k] < p) {
      topk[k + 1] = topk[k];
    } else {
      topk[k + 1] = p;
      return;
    }
  }
  topk[0] = p;
}
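`AddTo` assumes `topk[]` is kept sorted in descending order: the new pair is slotted in from the bottom, everything smaller shifts down one place, and the old minimum falls off the end. A host-side analogue (a sketch with `std::pair` standing in for `Pair<T>`):

// Host-side analogue of AddTo for testing the insertion logic.
#include <utility>

void AddToRef(std::pair<float, int> topk[],
              const std::pair<float, int>& p,
              int beam_size) {
  // Same ordering as Pair::operator<: value first, larger id loses ties.
  auto less = [](const std::pair<float, int>& a,
                 const std::pair<float, int>& b) {
    return a.first < b.first || (a.first == b.first && a.second > b.second);
  };
  for (int k = beam_size - 2; k >= 0; --k) {
    if (less(topk[k], p)) {
      topk[k + 1] = topk[k];  // shift the smaller entry down
    } else {
      topk[k + 1] = p;        // p belongs just below topk[k]
      return;
    }
  }
  topk[0] = p;  // p is the new maximum
}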

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(
    Pair<T> topk[], const T* src, int idx, int dim, int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      AddTo<T>(topk, tmp, beam_size);
    }
    idx += BlockSize;
  }
}

template <typename T, int BlockSize>
@@ -288,15 +286,15 @@ __device__ __forceinline__ void GetTopK(Pair<T> topk[],
                                        int dim,
                                        const Pair<T>& max,
                                        int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      if (tmp < max) {
        AddTo<T>(topk, tmp, beam_size);
      }
    }
    idx += BlockSize;
  }
}

template <typename T, int MaxLength, int BlockSize>
@@ -309,43 +307,43 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
                                              Pair<T>* max,
                                              int dim,
                                              const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
    } else {
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(std::numeric_limits<T>::min(), -1);
        }
      }
      if (!(*is_empty)) {
        GetTopK<T, BlockSize>(
            topk + MaxLength - *beam, src, tid, dim, *max, length);
      }
    }

    *max = topk[MaxLength - 1];
    if ((*max).id == -1) *is_empty = true;
    *beam = 0;
  }
}

template <typename T>
__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    T tmp_val = CudaShuffleDownSync(FINAL_MASK, input.v, offset);
    int tmp_id = CudaShuffleDownSync(FINAL_MASK, input.id, offset);
    if (static_cast<float>(input.v) < static_cast<float>(tmp_val)) {
      input.v = tmp_val;
      input.id = tmp_id;
    }
  }
  return input;
}
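The offset-halving loop is the standard warp tree reduction: after log2(32) = 5 shuffle rounds, lane 0 holds the warp-wide maximum (here, the max-by-value pair). The same idea on a bare float, as a minimal sketch:

// Minimal warp-max over 32 lanes; only lane 0's return value is meaningful.
__forceinline__ __device__ float WarpMaxSketch(float v) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    float other = __shfl_down_sync(0xFFFFFFFF, v, offset);
    v = (v < other) ? other : v;
  }
  return v;
}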

template <typename T, int MaxLength, int BlockSize>
@@ -358,52 +356,51 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
                                            const int tid,
                                            const int wid,
                                            const int lane) {
  while (true) {
    __syncthreads();
    Pair<T> input_now = topk[0];
    input_now = WarpReduce(input_now);

    if (lane == 0) {
      shared_max[wid] = input_now;
    }
    __syncthreads();
    input_now = (tid < BlockSize / 32)
                    ? shared_max[lane]
                    : Pair<T>(std::numeric_limits<T>::min(), -1);
    if (wid == 0) {
      input_now = WarpReduce(input_now);
      if (lane == 0) shared_max[0] = input_now;
    }
    __syncthreads();
    if (tid == 0) {
      beam_max[*count] = shared_max[0];
      (*count)++;
    }
    int tid_max = shared_max[0].id % BlockSize;
    if (tid == tid_max) {
      (*beam)++;
    }
    if (--(*k) == 0) break;
    __syncthreads();

    if (tid == tid_max) {
      if (*beam < MaxLength) {
        topk[0] = topk[*beam];
      }
    }

    if (MaxLength < 5) {
      if (*beam >= MaxLength) break;
    } else {
      unsigned mask = 0u;
      mask = __ballot_sync(FINAL_MASK, true);
      if (tid_max / 32 == wid) {
        if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) == MaxLength)
          break;
      }
    }
  }
}

template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
@@ -417,70 +414,66 @@ __global__ void KeMatrixTopPBeamTopKFt(
    int vocab_size,
    const int max_cadidate_len,
    const int max_seq_len) {
  const int tid = threadIdx.x;
  const int wid = tid / 32;
  const int lane = tid % 32;
  const int token_id = blockIdx.x;
  const int ori_token_id = token_id + output_padding_offset[token_id];
  const int bid = ori_token_id / max_seq_len;

  int top_num = TopPBeamTopK;
  float top_p_value = static_cast<float>(top_ps[bid]);

  __shared__ Pair<T> shared_max[BlockSize / 32];
  __shared__ Pair<T> beam_max[TopPBeamTopK];

  Pair<T> topk[MaxLength];
  int beam = MaxLength;
  Pair<T> max;
  bool is_empty = false;
  bool firststep = true;
  __shared__ int count;

  if (tid == 0) {
    count = 0;
  }

  for (int j = 0; j < MaxLength; j++) {
    topk[j].set(std::numeric_limits<T>::min(), -1);
  }

  while (top_num) {
    ThreadGetTopK<T, MaxLength, BlockSize>(topk,
                                           &beam,
                                           TopPBeamTopK,
                                           src + token_id * vocab_size,
                                           &firststep,
                                           &is_empty,
                                           &max,
                                           vocab_size,
                                           tid);
    BlockReduce<T, MaxLength, BlockSize>(
        shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
  }
  if (tid == 0) {
    float sum_prob = 0.0f;
    bool flag = false;
    for (int i = 0; i < TopPBeamTopK; i++) {
      out_id[token_id * max_cadidate_len + i] =
          static_cast<int64_t>(beam_max[i].id);
      out_val[token_id * max_cadidate_len + i] = beam_max[i].v;
      float val = static_cast<float>(beam_max[i].v);
      sum_prob += val;

      if (sum_prob >= top_p_value) {
        actual_candidates_lens[token_id] = i + 1;
        break;
      }
    }
  }
  if (top_p_value == 1.0 && actual_candidates_lens[token_id] == 0) {
    actual_candidates_lens[token_id] = max_cadidate_len;
  }
}
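The `top_p_value == 1.0` fallback added here appears to be the substance of this fix: probabilities are accumulated in float, so even a fully normalized row can sum to just under 1.0 and the `sum_prob >= top_p_value` test may never fire, leaving the candidate length at 0. A small host-side demonstration of the rounding effect (illustrative values, not from the commit):

// Summing 100 tokens of probability 0.01 in float lands below 1.0f
// (0.01f itself is slightly below 0.01), so a `sum >= 1.0f` cutoff never
// triggers and the candidate length would stay 0 without the fallback.
#include <cstdio>

int main() {
  float sum = 0.f;
  for (int i = 0; i < 100; ++i) sum += 0.01f;
  std::printf("sum = %.8f, reaches 1.0? %d\n", sum, sum >= 1.0f);
  return 0;
}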

template <typename T, int TopKMaxLength>
@@ -495,30 +488,28 @@ void DispatchTopK(const T* src,
                  const int cadidate_len,
                  const int max_seq_len,
                  const cudaStream_t& stream) {
  int BlockSize = GetBlockSize(vocab_size);
  switch (cadidate_len) {
    FIXED_TOPK(switch (BlockSize) {
      FIXED_BLOCK_DIM(
          KeMatrixTopPBeamTopKFt<T, TopKMaxLength, kTopK, kBlockDim>
          <<<token_num, kBlockDim, 0, stream>>>(src,
                                                top_ps,
                                                output_padding_offset,
                                                out_id,
                                                out_val,
                                                actual_candidates_lens_data,
                                                vocab_size,
                                                cadidate_len,
                                                max_seq_len));
      default:
        PD_THROW(
            "Invalid max_candidate_len. Please set a value in [1,10] (step=1) "
            "or [10,100] (step=10).");
    });
    default:
      PD_THROW("the input topk is not implemented.");
  }
}

template <paddle::DataType D>
@@ -528,38 +519,38 @@ std::vector<paddle::Tensor> LaunchTopPCandidates(
    const paddle::Tensor& output_padding_offset,
    const int candidates_len,
    const int max_seq_len) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  std::vector<int64_t> input_shape = probs.shape();
  const int token_num = input_shape[0];
  const int vocab_size = input_shape[1];

  auto verify_scores =
      paddle::full({token_num, candidates_len}, 0, D, probs.place());
  auto verify_tokens = paddle::full(
      {token_num, candidates_len}, 0, paddle::DataType::INT64, probs.place());
  auto actual_candidate_lens =
      paddle::full({token_num}, 0, paddle::DataType::INT32, probs.place());

  auto stream = probs.stream();

  constexpr int TopKMaxLength = 2;
  DispatchTopK<DataType_, TopKMaxLength>(
      reinterpret_cast<const DataType_*>(probs.data<data_t>()),
      reinterpret_cast<const DataType_*>(top_p.data<data_t>()),
      output_padding_offset.data<int>(),
      verify_tokens.data<int64_t>(),
      reinterpret_cast<DataType_*>(verify_scores.data<data_t>()),
      actual_candidate_lens.data<int>(),
      vocab_size,
      token_num,
      candidates_len,
      max_seq_len,
      stream);

  return {verify_scores, verify_tokens, actual_candidate_lens};
}

std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
@@ -568,37 +559,25 @@ std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
    const paddle::Tensor& output_padding_offset,
    int candidates_len,
    int max_seq_len) {
  switch (probs.type()) {
    case paddle::DataType::BFLOAT16:
      return LaunchTopPCandidates<paddle::DataType::BFLOAT16>(
          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
      break;
    case paddle::DataType::FLOAT16:
      return LaunchTopPCandidates<paddle::DataType::FLOAT16>(
          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
      break;
    case paddle::DataType::FLOAT32:
      return LaunchTopPCandidates<paddle::DataType::FLOAT32>(
          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
      break;
    default:
      PD_THROW(
          "NOT supported data type. "
          "Only bfloat16, float16 and float32 are supported. ");
      break;
  }
}

std::vector<paddle::Tensor> TopPCandidates(
@@ -607,8 +586,8 @@ std::vector<paddle::Tensor> TopPCandidates(
    const paddle::Tensor& output_padding_offset,
    int candidates_len,
    int max_seq_len) {
  return DispatchTopPCandidatesWithDtype(
      probs, top_p, output_padding_offset, candidates_len, max_seq_len);
}

std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
@@ -616,17 +595,17 @@ std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
    const std::vector<int64_t>& top_p_shape,
    const std::vector<int64_t>& output_padding_offset_shape,
    int max_candidates_len) {
  int token_num = probs_shape[0];
  return {{token_num, max_candidates_len},
          {token_num, max_candidates_len},
          {token_num}};
}

std::vector<paddle::DataType> TopPCandidatesInferDtype(
    const paddle::DataType& probs_dtype,
    const paddle::DataType& top_p_dtype,
    const paddle::DataType& output_padding_offset_dtype) {
  return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
}

PD_BUILD_STATIC_OP(top_p_candidates)