From 8545b705ed797823caa2e314751c5bc146fcf95d Mon Sep 17 00:00:00 2001
From: GoldPancake <56388518+Deleter-D@users.noreply.github.com>
Date: Fri, 5 Dec 2025 20:01:05 +0800
Subject: [PATCH] fix top_p_candidates (#5400)

Co-authored-by: freeliuzc
---
 .../speculate_decoding/top_p_candidates.cu    | 799 +++++++++---------
 1 file changed, 389 insertions(+), 410 deletions(-)

diff --git a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu
index 0fced697d..bcafd45d0 100644
--- a/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/top_p_candidates.cu
@@ -19,113 +19,113 @@
 template <typename T>
 __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask,
                                                  T val,
                                                  int delta,
                                                  int width = warpSize) {
-    return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
+  return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
 }
 
 template <>
 __forceinline__ __device__ phi::dtype::float16 CudaShuffleDownSync(
     unsigned mask, phi::dtype::float16 val, int delta, int width) {
-    return paddle::float16(__shfl_down_sync(
-        mask, val.to_half(), static_cast<unsigned>(delta), width));
+  return paddle::float16(__shfl_down_sync(
+      mask, val.to_half(), static_cast<unsigned>(delta), width));
 }
 
 template <>
 __forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync(
     unsigned mask, phi::dtype::bfloat16 val, int delta, int width) {
-    return paddle::bfloat16(__shfl_down_sync(
-        mask, val.to_nv_bfloat16(), static_cast<unsigned>(delta), width));
+  return paddle::bfloat16(__shfl_down_sync(
+      mask, val.to_nv_bfloat16(), static_cast<unsigned>(delta), width));
 }
 
 struct BlockPrefixCallbackOp {
-    // Running prefix
-    float running_total;
-    // Constructor
-    __device__ BlockPrefixCallbackOp(float running_total)
-        : running_total(running_total) {}
-    // Callback operator to be entered by the first warp of threads in the
-    // block. Thread-0 is responsible for returning a value for seeding the
-    // block-wide scan.
-    __device__ float operator()(float block_aggregate) {
-        float old_prefix = running_total;
-        running_total += block_aggregate;
-        return old_prefix;
-    }
+  // Running prefix
+  float running_total;
+  // Constructor
+  __device__ BlockPrefixCallbackOp(float running_total)
+      : running_total(running_total) {}
+  // Callback operator to be entered by the first warp of threads in the
+  // block. Thread-0 is responsible for returning a value for seeding the
+  // block-wide scan.
+  __device__ float operator()(float block_aggregate) {
+    float old_prefix = running_total;
+    running_total += block_aggregate;
+    return old_prefix;
+  }
 };
 
 #define FINAL_MASK 0xFFFFFFFF
 
-#define FIXED_BLOCK_DIM_BASE(dim, ...)    \
-    case (dim): {                         \
-        constexpr auto kBlockDim = (dim); \
-        __VA_ARGS__;                      \
-    } break
+#define FIXED_BLOCK_DIM_BASE(dim, ...) \
+  case (dim): {                        \
+    constexpr auto kBlockDim = (dim);  \
+    __VA_ARGS__;                       \
+  } break
 
-#define FIXED_BLOCK_DIM(...)                   \
-    FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
-    FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
-    FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
-    FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
-    FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
-    FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
+#define FIXED_BLOCK_DIM(...)                 \
+  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
+  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
 
-#define FIXED_TOPK_BASE(topk, ...)   \
-    case (topk): {                   \
-        constexpr auto kTopK = topk; \
-        __VA_ARGS__;                 \
-    } break
+#define FIXED_TOPK_BASE(topk, ...) \
+  case (topk): {                   \
+    constexpr auto kTopK = topk;   \
+    __VA_ARGS__;                   \
+  } break
 
-#define FIXED_TOPK(...)                  \
-    FIXED_TOPK_BASE(1, ##__VA_ARGS__);   \
-    FIXED_TOPK_BASE(2, ##__VA_ARGS__);   \
-    FIXED_TOPK_BASE(3, ##__VA_ARGS__);   \
-    FIXED_TOPK_BASE(4, ##__VA_ARGS__);   \
-    FIXED_TOPK_BASE(5, ##__VA_ARGS__);   \
-    FIXED_TOPK_BASE(6, ##__VA_ARGS__);   \
-    FIXED_TOPK_BASE(7, ##__VA_ARGS__);   \
-    FIXED_TOPK_BASE(8, ##__VA_ARGS__);   \
-    FIXED_TOPK_BASE(9, ##__VA_ARGS__);   \
-    FIXED_TOPK_BASE(10, ##__VA_ARGS__);  \
-    FIXED_TOPK_BASE(20, ##__VA_ARGS__);  \
-    FIXED_TOPK_BASE(30, ##__VA_ARGS__);  \
-    FIXED_TOPK_BASE(40, ##__VA_ARGS__);  \
-    FIXED_TOPK_BASE(50, ##__VA_ARGS__);  \
-    FIXED_TOPK_BASE(60, ##__VA_ARGS__);  \
-    FIXED_TOPK_BASE(70, ##__VA_ARGS__);  \
-    FIXED_TOPK_BASE(80, ##__VA_ARGS__);  \
-    FIXED_TOPK_BASE(90, ##__VA_ARGS__);  \
-    FIXED_TOPK_BASE(100, ##__VA_ARGS__);
+#define FIXED_TOPK(...)               \
+  FIXED_TOPK_BASE(1, ##__VA_ARGS__);  \
+  FIXED_TOPK_BASE(2, ##__VA_ARGS__);  \
+  FIXED_TOPK_BASE(3, ##__VA_ARGS__);  \
+  FIXED_TOPK_BASE(4, ##__VA_ARGS__);  \
+  FIXED_TOPK_BASE(5, ##__VA_ARGS__);  \
+  FIXED_TOPK_BASE(6, ##__VA_ARGS__);  \
+  FIXED_TOPK_BASE(7, ##__VA_ARGS__);  \
+  FIXED_TOPK_BASE(8, ##__VA_ARGS__);  \
+  FIXED_TOPK_BASE(9, ##__VA_ARGS__);  \
+  FIXED_TOPK_BASE(10, ##__VA_ARGS__); \
+  FIXED_TOPK_BASE(20, ##__VA_ARGS__); \
+  FIXED_TOPK_BASE(30, ##__VA_ARGS__); \
+  FIXED_TOPK_BASE(40, ##__VA_ARGS__); \
+  FIXED_TOPK_BASE(50, ##__VA_ARGS__); \
+  FIXED_TOPK_BASE(60, ##__VA_ARGS__); \
+  FIXED_TOPK_BASE(70, ##__VA_ARGS__); \
+  FIXED_TOPK_BASE(80, ##__VA_ARGS__); \
+  FIXED_TOPK_BASE(90, ##__VA_ARGS__); \
+  FIXED_TOPK_BASE(100, ##__VA_ARGS__);
 
 struct SegmentOffsetIter {
-    explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
+  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
 
-    __host__ __device__ __forceinline__ int operator()(int idx) const {
-        return idx * num_cols_;
-    }
+  __host__ __device__ __forceinline__ int operator()(int idx) const {
+    return idx * num_cols_;
+  }
 
-    int num_cols_;
+  int num_cols_;
 };
 
 inline int div_up(int a, int n) { return (a + n - 1) / n; }
 
 template <typename T>
 __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
-    int col_id = threadIdx.x;
-    int row_id = blockIdx.x;
+  int col_id = threadIdx.x;
+  int row_id = blockIdx.x;
 
-    for (T j = row_id; j < num_rows; j += gridDim.x) {
-        for (T i = col_id; i < num_cols; i += blockDim.x) {
-            indices[j * num_cols + i] = i;
-        }
+  for (T j = row_id; j < num_rows; j += gridDim.x) {
+    for (T i = col_id; i < num_cols; i += blockDim.x) {
+      indices[j * num_cols + i] = i;
     }
+  }
 }
 
 __global__ void SetCountIter(int* count_iter, int num) {
-    int tid = threadIdx.x;
-    int bid = blockIdx.x;
-    int idx = bid * blockDim.x + tid;
-    for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
-        count_iter[i] = i;
-    }
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+  int idx = bid * blockDim.x + tid;
+  for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
+    count_iter[i] = i;
+  }
 }
 
 template <typename T, int BLOCK_SIZE>
@@ -137,148 +137,146 @@ __global__ void top_p_candidates_kernel(T* sorted_probs,
                                         const int vocab_size,
                                         const float topp,
                                         const int candidates_len) {
-    __shared__ int stop_shared;
-    __shared__ float rand_p;
-    const int tid = threadIdx.x;
-    const int bid = blockIdx.x;
-    constexpr int NUM_WARPS = BLOCK_SIZE / 32;
-    const int lane_id = tid % 32;
-    const int warp_id = tid / 32;
+  __shared__ int stop_shared;
+  __shared__ float rand_p;
+  const int tid = threadIdx.x;
+  const int bid = blockIdx.x;
+  constexpr int NUM_WARPS = BLOCK_SIZE / 32;
+  const int lane_id = tid % 32;
+  const int warp_id = tid / 32;
 
-    typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
-    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
-    __shared__ typename BlockScan::TempStorage temp_storage;
-    __shared__ typename BlockReduce::TempStorage temp_storage_reduce;
-    __shared__ uint32_t selected_shared[NUM_WARPS];
+  typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
+  typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
+  __shared__ typename BlockScan::TempStorage temp_storage;
+  __shared__ typename BlockReduce::TempStorage temp_storage_reduce;
+  __shared__ uint32_t selected_shared[NUM_WARPS];
 
-    if (lane_id == 0) {
-        selected_shared[warp_id] = 0;
-    }
-
-    // Initialize running total
-    BlockPrefixCallbackOp prefix_op(0);
-
-    __syncthreads();
-
-    int offset = bid * vocab_size;
-    int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
-    int i_activate = 0;
-    float thread_offset = 0;
-    for (int i = tid; i < end; i += BLOCK_SIZE) {
-        float thread_count = (i < vocab_size)
-                                 ? static_cast<float>(sorted_probs[offset + i])
-                                 : 0.f;
-
-        BlockScan(temp_storage)
-            .InclusiveSum(thread_count, thread_offset, prefix_op);
-
-        if (i < candidates_len) {
-            out_id[bid * candidates_len + i] = sorted_id[offset + i];
-            out_val[bid * candidates_len + i] = sorted_probs[offset + i];
-        }
-
-        uint32_t activate_mask =
-            __ballot_sync(FINAL_MASK, topp <= thread_offset);
-        i_activate = i;
-        if (activate_mask != 0 || i >= candidates_len) {
-            if (lane_id == 0) {
-                atomicAdd(&stop_shared, 1);
-                selected_shared[warp_id] = activate_mask;
-            }
-        }
-        __syncthreads();
-        if (stop_shared > 0) {
-            break;
-        }
-    }
-    __syncthreads();
-    bool skip = (selected_shared[warp_id] > 0) ? false : true;
-    for (int i = 0; i < warp_id; i++) {
-        if (selected_shared[i] != 0) {
-            // If the previous has stopped, skip the current warp
-            skip = true;
-        }
-    }
-    if (!skip) {
-        int active_lane_id =
-            WARP_SIZE - __popc(selected_shared[warp_id]);  // first not 0
-        if (lane_id == active_lane_id) {
-            actual_candidates_lens[bid] = i_activate + 1;
-        }
-    }
-    __syncthreads();
-    if (tid == 0) {
-        // printf("actual_candidates_lens[%d] %d\n", bid,
-        // actual_candidates_lens[bid]);
-        if (actual_candidates_lens[bid] == 0) {
-            actual_candidates_lens[bid] = candidates_len;
-        }
-    }
+  if (lane_id == 0) {
+    selected_shared[warp_id] = 0;
+  }
+
+  // Initialize running total
+  BlockPrefixCallbackOp prefix_op(0);
+
+  __syncthreads();
+
+  int offset = bid * vocab_size;
+  int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
+  int i_activate = 0;
+  float thread_offset = 0;
+  for (int i = tid; i < end; i += BLOCK_SIZE) {
+    float thread_count =
+        (i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f;
+
+    BlockScan(temp_storage)
+        .InclusiveSum(thread_count, thread_offset, prefix_op);
+
+    if (i < candidates_len) {
+      out_id[bid * candidates_len + i] = sorted_id[offset + i];
+      out_val[bid * candidates_len + i] = sorted_probs[offset + i];
+    }
+
+    uint32_t activate_mask = __ballot_sync(FINAL_MASK, topp <= thread_offset);
+    i_activate = i;
+    if (activate_mask != 0 || i >= candidates_len) {
+      if (lane_id == 0) {
+        atomicAdd(&stop_shared, 1);
+        selected_shared[warp_id] = activate_mask;
+      }
+    }
+    __syncthreads();
+    if (stop_shared > 0) {
+      break;
+    }
+  }
+  __syncthreads();
+  bool skip = (selected_shared[warp_id] > 0) ? false : true;
+  for (int i = 0; i < warp_id; i++) {
+    if (selected_shared[i] != 0) {
+      // If the previous has stopped, skip the current warp
+      skip = true;
+    }
+  }
+  if (!skip) {
+    int active_lane_id =
+        WARP_SIZE - __popc(selected_shared[warp_id]);  // first not 0
+    if (lane_id == active_lane_id) {
+      actual_candidates_lens[bid] = i_activate + 1;
+    }
+  }
+  __syncthreads();
+  if (tid == 0) {
+    // printf("actual_candidates_lens[%d] %d\n", bid,
+    // actual_candidates_lens[bid]);
+    if (actual_candidates_lens[bid] == 0) {
+      actual_candidates_lens[bid] = candidates_len;
+    }
+  }
 }
 
 template <typename T>
 struct Pair {
-    __device__ __forceinline__ Pair() {}
-    __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
+  __device__ __forceinline__ Pair() {}
+  __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
 
-    __device__ __forceinline__ void set(T value, int id) {
-        this->v = value;
-        this->id = id;
-    }
+  __device__ __forceinline__ void set(T value, int id) {
+    this->v = value;
+    this->id = id;
+  }
 
-    __device__ __forceinline__ void operator=(const Pair& in) {
-        v = in.v;
-        id = in.id;
-    }
+  __device__ __forceinline__ void operator=(const Pair& in) {
+    v = in.v;
+    id = in.id;
+  }
 
-    __device__ __forceinline__ bool operator<(const T value) const {
-        return (static_cast<float>(v) < static_cast<float>(value));
-    }
+  __device__ __forceinline__ bool operator<(const T value) const {
+    return (static_cast<float>(v) < static_cast<float>(value));
+  }
 
-    __device__ __forceinline__ bool operator>(const T value) const {
-        return (static_cast<float>(v) > static_cast<float>(value));
-    }
-    __device__ __forceinline__ bool operator<(const Pair& in) const {
-        return (static_cast<float>(v) < static_cast<float>(in.v)) ||
-               ((static_cast<float>(v) == static_cast<float>(in.v)) &&
-                (id > in.id));
-    }
+  __device__ __forceinline__ bool operator>(const T value) const {
+    return (static_cast<float>(v) > static_cast<float>(value));
+  }
+  __device__ __forceinline__ bool operator<(const Pair& in) const {
+    return (static_cast<float>(v) < static_cast<float>(in.v)) ||
+           ((static_cast<float>(v) == static_cast<float>(in.v)) &&
+            (id > in.id));
+  }
 
-    __device__ __forceinline__ bool operator>(const Pair& in) const {
-        return (static_cast<float>(v) > static_cast<float>(in.v)) ||
-               ((static_cast<float>(v) == static_cast<float>(in.v)) &&
-                (id < in.id));
-    }
+  __device__ __forceinline__ bool operator>(const Pair& in) const {
+    return (static_cast<float>(v) > static_cast<float>(in.v)) ||
+           ((static_cast<float>(v) == static_cast<float>(in.v)) &&
+            (id < in.id));
+  }
 
-    T v;
-    int id;
+  T v;
+  int id;
 };
 
 template <typename T>
 __device__ __forceinline__ void AddTo(Pair<T> topk[],
                                       const Pair<T>& p,
                                       int beam_size) {
-    for (int k = beam_size - 2; k >= 0; k--) {
-        if (topk[k] < p) {
-            topk[k + 1] = topk[k];
-        } else {
-            topk[k + 1] = p;
-            return;
-        }
+  for (int k = beam_size - 2; k >= 0; k--) {
+    if (topk[k] < p) {
+      topk[k + 1] = topk[k];
+    } else {
+      topk[k + 1] = p;
+      return;
     }
-    topk[0] = p;
+  }
+  topk[0] = p;
 }
 
 template <typename T, int BlockSize>
 __device__ __forceinline__ void GetTopK(
     Pair<T> topk[], const T* src, int idx, int dim, int beam_size) {
-    while (idx < dim) {
-        if (topk[beam_size - 1] < src[idx]) {
-            Pair<T> tmp(src[idx], idx);
-            AddTo<T>(topk, tmp, beam_size);
-        }
-        idx += BlockSize;
+  while (idx < dim) {
+    if (topk[beam_size - 1] < src[idx]) {
+      Pair<T> tmp(src[idx], idx);
+      AddTo<T>(topk, tmp, beam_size);
     }
+    idx += BlockSize;
+  }
 }
 
 template <typename T, int BlockSize>
@@ -288,15 +286,15 @@ __device__ __forceinline__ void GetTopK(Pair<T> topk[],
                                         int dim,
                                         const Pair<T>& max,
                                         int beam_size) {
-    while (idx < dim) {
-        if (topk[beam_size - 1] < src[idx]) {
-            Pair<T> tmp(src[idx], idx);
-            if (tmp < max) {
-                AddTo<T>(topk, tmp, beam_size);
-            }
-        }
-        idx += BlockSize;
+  while (idx < dim) {
+    if (topk[beam_size - 1] < src[idx]) {
+      Pair<T> tmp(src[idx], idx);
+      if (tmp < max) {
+        AddTo<T>(topk, tmp, beam_size);
+      }
     }
+    idx += BlockSize;
+  }
 }
 
 template <typename T, int MaxLength, int BlockSize>
@@ -309,43 +307,43 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
                                               Pair<T>* max,
                                               int dim,
                                               const int tid) {
-    if (*beam > 0) {
-        int length = (*beam) < beam_size ? *beam : beam_size;
-        if (*firstStep) {
-            *firstStep = false;
-            GetTopK<T, BlockSize>(topk, src, tid, dim, length);
-        } else {
-            for (int k = 0; k < MaxLength; k++) {
-                if (k < MaxLength - (*beam)) {
-                    topk[k] = topk[k + *beam];
-                } else {
-                    topk[k].set(std::numeric_limits<T>::min(), -1);
-                }
-            }
-            if (!(*is_empty)) {
-                GetTopK<T, BlockSize>(
-                    topk + MaxLength - *beam, src, tid, dim, *max, length);
-            }
-        }
-
-        *max = topk[MaxLength - 1];
-        if ((*max).id == -1) *is_empty = true;
-        *beam = 0;
-    }
+  if (*beam > 0) {
+    int length = (*beam) < beam_size ? *beam : beam_size;
+    if (*firstStep) {
+      *firstStep = false;
+      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
+    } else {
+      for (int k = 0; k < MaxLength; k++) {
+        if (k < MaxLength - (*beam)) {
+          topk[k] = topk[k + *beam];
+        } else {
+          topk[k].set(std::numeric_limits<T>::min(), -1);
+        }
+      }
+      if (!(*is_empty)) {
+        GetTopK<T, BlockSize>(
+            topk + MaxLength - *beam, src, tid, dim, *max, length);
+      }
+    }
+
+    *max = topk[MaxLength - 1];
+    if ((*max).id == -1) *is_empty = true;
+    *beam = 0;
+  }
 }
 
 template <typename T>
 __forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        T tmp_val = CudaShuffleDownSync(FINAL_MASK, input.v, offset);
-        int tmp_id = CudaShuffleDownSync(FINAL_MASK, input.id, offset);
-        if (static_cast<float>(input.v) < static_cast<float>(tmp_val)) {
-            input.v = tmp_val;
-            input.id = tmp_id;
-        }
+  for (int offset = 16; offset > 0; offset >>= 1) {
+    T tmp_val = CudaShuffleDownSync(FINAL_MASK, input.v, offset);
+    int tmp_id = CudaShuffleDownSync(FINAL_MASK, input.id, offset);
+    if (static_cast<float>(input.v) < static_cast<float>(tmp_val)) {
+      input.v = tmp_val;
+      input.id = tmp_id;
     }
-    return input;
+  }
+  return input;
 }
 
 template <typename T, int MaxLength, int BlockSize>
@@ -358,52 +356,51 @@ __device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
                                             const int tid,
                                             const int wid,
                                             const int lane) {
-    while (true) {
-        __syncthreads();
-        Pair<T> input_now = topk[0];
-        input_now = WarpReduce(input_now);
-
-        if (lane == 0) {
-            shared_max[wid] = input_now;
-        }
-        __syncthreads();
-        input_now = (tid < BlockSize / 32)
-                        ? shared_max[lane]
-                        : Pair<T>(std::numeric_limits<T>::min(), -1);
-        if (wid == 0) {
-            input_now = WarpReduce(input_now);
-            if (lane == 0) shared_max[0] = input_now;
-        }
-        __syncthreads();
-        if (tid == 0) {
-            beam_max[*count] = shared_max[0];
-            (*count)++;
-        }
-        int tid_max = shared_max[0].id % BlockSize;
-        if (tid == tid_max) {
-            (*beam)++;
-        }
-        if (--(*k) == 0) break;
-        __syncthreads();
-
-        if (tid == tid_max) {
-            if (*beam < MaxLength) {
-                topk[0] = topk[*beam];
-            }
-        }
-
-        if (MaxLength < 5) {
-            if (*beam >= MaxLength) break;
-        } else {
-            unsigned mask = 0u;
-            mask = __ballot_sync(FINAL_MASK, true);
-            if (tid_max / 32 == wid) {
-                if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) ==
-                    MaxLength)
-                    break;
-            }
-        }
-    }
+  while (true) {
+    __syncthreads();
+    Pair<T> input_now = topk[0];
+    input_now = WarpReduce(input_now);
+
+    if (lane == 0) {
+      shared_max[wid] = input_now;
+    }
+    __syncthreads();
+    input_now = (tid < BlockSize / 32)
+                    ? shared_max[lane]
+                    : Pair<T>(std::numeric_limits<T>::min(), -1);
+    if (wid == 0) {
+      input_now = WarpReduce(input_now);
+      if (lane == 0) shared_max[0] = input_now;
+    }
+    __syncthreads();
+    if (tid == 0) {
+      beam_max[*count] = shared_max[0];
+      (*count)++;
+    }
+    int tid_max = shared_max[0].id % BlockSize;
+    if (tid == tid_max) {
+      (*beam)++;
+    }
+    if (--(*k) == 0) break;
+    __syncthreads();
+
+    if (tid == tid_max) {
+      if (*beam < MaxLength) {
+        topk[0] = topk[*beam];
+      }
+    }
+
+    if (MaxLength < 5) {
+      if (*beam >= MaxLength) break;
+    } else {
+      unsigned mask = 0u;
+      mask = __ballot_sync(FINAL_MASK, true);
+      if (tid_max / 32 == wid) {
+        if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) == MaxLength)
+          break;
+      }
+    }
+  }
 }
 
 template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
@@ -417,70 +414,66 @@ __global__ void KeMatrixTopPBeamTopKFt(
     int vocab_size,
     const int max_cadidate_len,
     const int max_seq_len) {
-    const int tid = threadIdx.x;
-    const int wid = tid / 32;
-    const int lane = tid % 32;
-    const int token_id = blockIdx.x;
-    const int ori_token_id = token_id + output_padding_offset[token_id];
-    const int bid = ori_token_id / max_seq_len;
+  const int tid = threadIdx.x;
+  const int wid = tid / 32;
+  const int lane = tid % 32;
+  const int token_id = blockIdx.x;
+  const int ori_token_id = token_id + output_padding_offset[token_id];
+  const int bid = ori_token_id / max_seq_len;
 
-    int top_num = TopPBeamTopK;
-    float top_p_value = static_cast<float>(top_ps[bid]);
+  int top_num = TopPBeamTopK;
+  float top_p_value = static_cast<float>(top_ps[bid]);
 
-    __shared__ Pair<T> shared_max[BlockSize / 32];
-    __shared__ Pair<T> beam_max[TopPBeamTopK];
+  __shared__ Pair<T> shared_max[BlockSize / 32];
+  __shared__ Pair<T> beam_max[TopPBeamTopK];
 
-    Pair<T> topk[MaxLength];
-    int beam = MaxLength;
-    Pair<T> max;
-    bool is_empty = false;
-    bool firststep = true;
-    __shared__ int count;
+  Pair<T> topk[MaxLength];
+  int beam = MaxLength;
+  Pair<T> max;
+  bool is_empty = false;
+  bool firststep = true;
+  __shared__ int count;
 
-    if (tid == 0) {
-        count = 0;
-    }
-
-    for (int j = 0; j < MaxLength; j++) {
-        topk[j].set(std::numeric_limits<T>::min(), -1);
-    }
-
-    while (top_num) {
-        ThreadGetTopK<T, MaxLength, BlockSize>(topk,
-                                               &beam,
-                                               TopPBeamTopK,
-                                               src + token_id * vocab_size,
-                                               &firststep,
-                                               &is_empty,
-                                               &max,
-                                               vocab_size,
-                                               tid);
-        BlockReduce<T, MaxLength, BlockSize>(shared_max,
-                                             topk,
-                                             beam_max,
-                                             &beam,
-                                             &top_num,
-                                             &count,
-                                             tid,
-                                             wid,
-                                             lane);
-    }
-    if (tid == 0) {
-        float sum_prob = 0.0f;
-        bool flag = false;
-        for (int i = 0; i < TopPBeamTopK; i++) {
-            out_id[token_id * max_cadidate_len + i] =
-                static_cast<int64_t>(beam_max[i].id);
-            out_val[token_id * max_cadidate_len + i] = beam_max[i].v;
-            float val = static_cast<float>(beam_max[i].v);
-            sum_prob += val;
-
-            if (sum_prob >= top_p_value) {
-                actual_candidates_lens[token_id] = i + 1;
-                break;
-            }
-        }
-    }
+  if (tid == 0) {
+    count = 0;
+  }
+
+  for (int j = 0; j < MaxLength; j++) {
+    topk[j].set(std::numeric_limits<T>::min(), -1);
+  }
+
+  while (top_num) {
+    ThreadGetTopK<T, MaxLength, BlockSize>(topk,
+                                           &beam,
+                                           TopPBeamTopK,
+                                           src + token_id * vocab_size,
+                                           &firststep,
+                                           &is_empty,
+                                           &max,
+                                           vocab_size,
+                                           tid);
+    BlockReduce<T, MaxLength, BlockSize>(
+        shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
+  }
+  if (tid == 0) {
+    float sum_prob = 0.0f;
+    bool flag = false;
+    for (int i = 0; i < TopPBeamTopK; i++) {
+      out_id[token_id * max_cadidate_len + i] =
+          static_cast<int64_t>(beam_max[i].id);
+      out_val[token_id * max_cadidate_len + i] = beam_max[i].v;
+      float val = static_cast<float>(beam_max[i].v);
+      sum_prob += val;
+
+      if (sum_prob >= top_p_value) {
+        actual_candidates_lens[token_id] = i + 1;
+        break;
+      }
+    }
+  }
+  if (top_p_value == 1.0 && actual_candidates_lens[token_id] == 0) {
+    actual_candidates_lens[token_id] = max_cadidate_len;
+  }
 }
 
 template <typename T, int TopKMaxLength>
@@ -495,30 +488,28 @@ void DispatchTopK(const T* src,
                   const int cadidate_len,
                   const int max_seq_len,
                   const cudaStream_t& stream) {
-    int BlockSize = GetBlockSize(vocab_size);
-    switch (cadidate_len) {
-        FIXED_TOPK(switch (BlockSize) {
-            FIXED_BLOCK_DIM(
-                KeMatrixTopPBeamTopKFt<T, TopKMaxLength, kTopK, kBlockDim>
-                <<<token_num, kBlockDim, 0, stream>>>(
-                    src,
-                    top_ps,
-                    output_padding_offset,
-                    out_id,
-                    out_val,
-                    actual_candidates_lens_data,
-                    vocab_size,
-                    cadidate_len,
-                    max_seq_len));
-            default:
-                PD_THROW(
-                    "Invalid max_candidate_len. Please set a value in [1,10] (step=1) "
-                    "or [10,100] (step=10)."
-                );
-        });
-        default:
-            PD_THROW("the input topk is not implemented.");
-    }
+  int BlockSize = GetBlockSize(vocab_size);
+  switch (cadidate_len) {
+    FIXED_TOPK(switch (BlockSize) {
+      FIXED_BLOCK_DIM(
+          KeMatrixTopPBeamTopKFt<T, TopKMaxLength, kTopK, kBlockDim>
+          <<<token_num, kBlockDim, 0, stream>>>(src,
+                                                top_ps,
+                                                output_padding_offset,
+                                                out_id,
+                                                out_val,
+                                                actual_candidates_lens_data,
+                                                vocab_size,
+                                                cadidate_len,
+                                                max_seq_len));
+      default:
+        PD_THROW(
+            "Invalid max_candidate_len. Please set a value in [1,10] (step=1) "
+            "or [10,100] (step=10).");
+    });
+    default:
+      PD_THROW("the input topk is not implemented.");
+  }
 }
 
 template <paddle::DataType D>
@@ -528,38 +519,38 @@ std::vector<paddle::Tensor> LaunchTopPCandidates(
     const paddle::Tensor& output_padding_offset,
     const int candidates_len,
    const int max_seq_len) {
-    typedef PDTraits<D> traits_;
-    typedef typename traits_::DataType DataType_;
-    typedef typename traits_::data_t data_t;
+  typedef PDTraits<D> traits_;
+  typedef typename traits_::DataType DataType_;
+  typedef typename traits_::data_t data_t;
 
-    std::vector<int64_t> input_shape = probs.shape();
-    const int token_num = input_shape[0];
-    const int vocab_size = input_shape[1];
+  std::vector<int64_t> input_shape = probs.shape();
+  const int token_num = input_shape[0];
+  const int vocab_size = input_shape[1];
 
-    auto verify_scores =
-        paddle::full({token_num, candidates_len}, 0, D, probs.place());
-    auto verify_tokens = paddle::full(
-        {token_num, candidates_len}, 0, paddle::DataType::INT64, probs.place());
-    auto actual_candidate_lens =
-        paddle::full({token_num}, 0, paddle::DataType::INT32, probs.place());
+  auto verify_scores =
+      paddle::full({token_num, candidates_len}, 0, D, probs.place());
+  auto verify_tokens = paddle::full(
+      {token_num, candidates_len}, 0, paddle::DataType::INT64, probs.place());
+  auto actual_candidate_lens =
+      paddle::full({token_num}, 0, paddle::DataType::INT32, probs.place());
 
-    auto stream = probs.stream();
+  auto stream = probs.stream();
 
-    constexpr int TopKMaxLength = 2;
-    DispatchTopK<DataType_, TopKMaxLength>(
-        reinterpret_cast<const DataType_*>(probs.data<data_t>()),
-        reinterpret_cast<const DataType_*>(top_p.data<data_t>()),
-        output_padding_offset.data<int>(),
-        verify_tokens.data<int64_t>(),
-        reinterpret_cast<DataType_*>(verify_scores.data<data_t>()),
-        actual_candidate_lens.data<int>(),
-        vocab_size,
-        token_num,
-        candidates_len,
-        max_seq_len,
-        stream);
+  constexpr int TopKMaxLength = 2;
+  DispatchTopK<DataType_, TopKMaxLength>(
+      reinterpret_cast<const DataType_*>(probs.data<data_t>()),
+      reinterpret_cast<const DataType_*>(top_p.data<data_t>()),
+      output_padding_offset.data<int>(),
+      verify_tokens.data<int64_t>(),
+      reinterpret_cast<DataType_*>(verify_scores.data<data_t>()),
+      actual_candidate_lens.data<int>(),
+      vocab_size,
+      token_num,
+      candidates_len,
+      max_seq_len,
+      stream);
 
-    return {verify_scores, verify_tokens, actual_candidate_lens};
+  return {verify_scores, verify_tokens, actual_candidate_lens};
 }
 
 std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
@@ -568,37 +559,25 @@ std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
     const paddle::Tensor& output_padding_offset,
     int candidates_len,
     int max_seq_len) {
-    switch (probs.type()) {
-        case paddle::DataType::BFLOAT16:
-            return LaunchTopPCandidates<paddle::DataType::BFLOAT16>(
-                probs,
-                top_p,
-                output_padding_offset,
-                candidates_len,
-                max_seq_len);
-            break;
-        case paddle::DataType::FLOAT16:
-            return LaunchTopPCandidates<paddle::DataType::FLOAT16>(
-                probs,
-                top_p,
-                output_padding_offset,
-                candidates_len,
-                max_seq_len);
-            break;
-        case paddle::DataType::FLOAT32:
-            return LaunchTopPCandidates<paddle::DataType::FLOAT32>(
-                probs,
-                top_p,
-                output_padding_offset,
-                candidates_len,
-                max_seq_len);
-            break;
-        default:
-            PD_THROW(
-                "NOT supported data type. "
-                "Only bfloat16, float16 and float32 are supported. ");
-            break;
-    }
+  switch (probs.type()) {
+    case paddle::DataType::BFLOAT16:
+      return LaunchTopPCandidates<paddle::DataType::BFLOAT16>(
+          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
+      break;
+    case paddle::DataType::FLOAT16:
+      return LaunchTopPCandidates<paddle::DataType::FLOAT16>(
+          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
+      break;
+    case paddle::DataType::FLOAT32:
+      return LaunchTopPCandidates<paddle::DataType::FLOAT32>(
+          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
+      break;
+    default:
+      PD_THROW(
+          "NOT supported data type. "
+          "Only bfloat16, float16 and float32 are supported. ");
+      break;
+  }
 }
 
 std::vector<paddle::Tensor> TopPCandidates(
@@ -607,8 +586,8 @@ std::vector<paddle::Tensor> TopPCandidates(
     const paddle::Tensor& output_padding_offset,
     int candidates_len,
    int max_seq_len) {
-    return DispatchTopPCandidatesWithDtype(
-        probs, top_p, output_padding_offset, candidates_len, max_seq_len);
+  return DispatchTopPCandidatesWithDtype(
+      probs, top_p, output_padding_offset, candidates_len, max_seq_len);
 }
 
 std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
@@ -616,17 +595,17 @@ std::vector<std::vector<int64_t>> TopPCandidatesInferShape(
     const std::vector<int64_t>& top_p_shape,
     const std::vector<int64_t>& output_padding_offset_shape,
    int max_candidates_len) {
-    int token_num = probs_shape[0];
-    return {{token_num, max_candidates_len},
-            {token_num, max_candidates_len},
-            {token_num}};
+  int token_num = probs_shape[0];
+  return {{token_num, max_candidates_len},
+          {token_num, max_candidates_len},
+          {token_num}};
 }
 
 std::vector<paddle::DataType> TopPCandidatesInferDtype(
     const paddle::DataType& probs_dtype,
     const paddle::DataType& top_p_dtype,
    const paddle::DataType& output_padding_offset_dtype) {
-    return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
+  return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32};
 }
 
 PD_BUILD_STATIC_OP(top_p_candidates)
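Note (not part of the patch): the per-token rule that KeMatrixTopPBeamTopKFt applies after this change can be mirrored on the host for reference. The sketch below is a minimal, standalone illustration under stated assumptions; the function and variable names (CandidateLen, probs) are invented for the example and do not exist in the CUDA file. It shows how the candidate count is derived from the sorted top-k probabilities, including the fallback this patch adds for top_p == 1.0 when the partial sum over the kept candidates never reaches the threshold.

// Host-side sketch of the candidate-length selection, assuming `probs`
// holds one token's top-k probabilities in descending order.
#include <cstdio>
#include <vector>

static int CandidateLen(const std::vector<float>& probs, float top_p) {
  float sum_prob = 0.0f;
  for (size_t i = 0; i < probs.size(); ++i) {
    sum_prob += probs[i];
    if (sum_prob >= top_p) {
      return static_cast<int>(i) + 1;  // smallest prefix covering top_p
    }
  }
  // Fallback mirroring the fix: with top_p == 1.0 the partial sum over the
  // kept top-k candidates may never reach the threshold, so keep the whole
  // candidate list instead of leaving the length at 0.
  return top_p == 1.0f ? static_cast<int>(probs.size()) : 0;
}

int main() {
  std::vector<float> probs = {0.4f, 0.3f, 0.2f, 0.05f, 0.04f};  // sums to 0.99
  std::printf("top_p=0.8 -> %d candidates\n", CandidateLen(probs, 0.8f));  // 3
  std::printf("top_p=1.0 -> %d candidates\n", CandidateLen(probs, 1.0f));  // 5
  return 0;
}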