// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cub/cub.cuh>
#include <limits>

#include "helper.h"  // NOLINT

#define WARP_SIZE 32

template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask,
                                                 T val,
                                                 int delta,
                                                 int width = warpSize) {
  return __shfl_down_sync(mask, val, static_cast<unsigned>(delta), width);
}

template <>
__forceinline__ __device__ phi::dtype::float16 CudaShuffleDownSync(
    unsigned mask, phi::dtype::float16 val, int delta, int width) {
  return paddle::float16(__shfl_down_sync(
      mask, val.to_half(), static_cast<unsigned>(delta), width));
}

template <>
__forceinline__ __device__ phi::dtype::bfloat16 CudaShuffleDownSync(
    unsigned mask, phi::dtype::bfloat16 val, int delta, int width) {
  return paddle::bfloat16(__shfl_down_sync(
      mask, val.to_nv_bfloat16(), static_cast<unsigned>(delta), width));
}

struct BlockPrefixCallbackOp {
  // Running prefix carried across tiles of the block-wide scan.
  float running_total;

  __device__ explicit BlockPrefixCallbackOp(float running_total)
      : running_total(running_total) {}

  // Callback operator to be entered by the first warp of threads in the
  // block. Thread-0 is responsible for returning a value for seeding the
  // block-wide scan.
  __device__ float operator()(float block_aggregate) {
    float old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;
  }
};

#define FINAL_MASK 0xFFFFFFFF

#define FIXED_BLOCK_DIM_BASE(dim, ...) \
  case (dim): {                        \
    constexpr auto kBlockDim = (dim);  \
    __VA_ARGS__;                       \
  } break

#define FIXED_BLOCK_DIM(...)                 \
  FIXED_BLOCK_DIM_BASE(1024, ##__VA_ARGS__); \
  FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)

#define FIXED_TOPK_BASE(topk, ...) \
  case (topk): {                   \
    constexpr auto kTopK = topk;   \
    __VA_ARGS__;                   \
  } break

#define FIXED_TOPK(...)              \
  FIXED_TOPK_BASE(2, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(3, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(4, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(5, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(8, ##__VA_ARGS__); \
  FIXED_TOPK_BASE(10, ##__VA_ARGS__)
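
// The FIXED_* macros above expand into the `case` arms of a switch statement
// so that runtime values (block dim, top-k width) can be mapped onto
// compile-time template arguments. Illustratively,
//
//   switch (BlockSize) { FIXED_BLOCK_DIM(Kernel<kBlockDim><<<...>>>(...)); }
//
// expands to one `case` per supported block dimension (1024 down to 32), each
// binding `constexpr auto kBlockDim` before instantiating the kernel, and
// FIXED_TOPK does the same for the supported top-k widths (2, 3, 4, 5, 8, 10).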

struct SegmentOffsetIter {
  explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}

  // Maps a segment index to the offset of its first element.
  __host__ __device__ __forceinline__ int operator()(int idx) const {
    return idx * num_cols_;
  }

  int num_cols_;
};

inline int div_up(int a, int n) { return (a + n - 1) / n; }

template <typename T>
__global__ void FillIndex(T* indices, T num_rows, T num_cols) {
  int col_id = threadIdx.x;
  int row_id = blockIdx.x;

  for (T j = row_id; j < num_rows; j += gridDim.x) {
    for (T i = col_id; i < num_cols; i += blockDim.x) {
      indices[j * num_cols + i] = i;
    }
  }
}

__global__ void SetCountIter(int* count_iter, int num) {
  int tid = threadIdx.x;
  int bid = blockIdx.x;
  int idx = bid * blockDim.x + tid;
  for (int i = idx; i < num; i += gridDim.x * blockDim.x) {
    count_iter[i] = i;
  }
}
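
// top_p_candidates_kernel: one block per row of `sorted_probs` (descending
// order). A block-wide inclusive prefix sum accumulates the sorted
// probabilities, and __ballot_sync flags the first position where the running
// sum reaches `topp`; that position determines how many of the copied-out
// candidates are actually kept.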
template <typename T, int BLOCK_SIZE>
__global__ void top_p_candidates_kernel(T* sorted_probs,
                                        int64_t* sorted_id,
                                        T* out_val,
                                        int64_t* out_id,
                                        int* actual_candidates_lens,
                                        const int vocab_size,
                                        const float topp,
                                        const int candidates_len) {
  __shared__ int stop_shared;
  const int tid = threadIdx.x;
  const int bid = blockIdx.x;
  constexpr int NUM_WARPS = BLOCK_SIZE / 32;
  const int lane_id = tid % 32;
  const int warp_id = tid / 32;
  typedef cub::BlockScan<float, BLOCK_SIZE> BlockScan;
  __shared__ typename BlockScan::TempStorage temp_storage;
  __shared__ uint32_t selected_shared[NUM_WARPS];

  if (tid == 0) {
    stop_shared = 0;  // Must be zeroed before any thread reads it below.
  }
  if (lane_id == 0) {
    selected_shared[warp_id] = 0;
  }

  // Initialize running total of the block-wide scan.
  BlockPrefixCallbackOp prefix_op(0);

  __syncthreads();

  const int offset = bid * vocab_size;
  const int end = ((vocab_size + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
  int i_activate = 0;
  float thread_offset = 0;
  for (int i = tid; i < end; i += BLOCK_SIZE) {
    float thread_count =
        (i < vocab_size) ? static_cast<float>(sorted_probs[offset + i]) : 0.f;
    // Running inclusive prefix sum of the sorted probabilities.
    BlockScan(temp_storage)
        .InclusiveSum(thread_count, thread_offset, prefix_op);

    if (i < candidates_len) {
      out_id[bid * candidates_len + i] = sorted_id[offset + i];
      out_val[bid * candidates_len + i] = sorted_probs[offset + i];
    }

    // Lanes whose cumulative sum has reached `topp` vote in the ballot.
    uint32_t activate_mask = __ballot_sync(FINAL_MASK, topp <= thread_offset);

    i_activate = i;
    if (activate_mask != 0 || i >= candidates_len) {
      if (lane_id == 0) {
        atomicAdd(&stop_shared, 1);
        selected_shared[warp_id] = activate_mask;
      }
    }
    __syncthreads();
    if (stop_shared > 0) {
      break;
    }
  }
  __syncthreads();

  bool skip = (selected_shared[warp_id] == 0);
  for (int i = 0; i < warp_id; i++) {
    if (selected_shared[i] != 0) {
      // An earlier warp has already stopped, so skip the current warp.
      skip = true;
    }
  }
  if (!skip) {
    // The prefix sum is monotone within a warp, so the ballot mask has the
    // form 0b11...100...0; the first active lane is 32 - popc(mask).
    int active_lane_id = WARP_SIZE - __popc(selected_shared[warp_id]);
    if (lane_id == active_lane_id) {
      actual_candidates_lens[bid] = i_activate + 1;
    }
  }
  __syncthreads();
  if (tid == 0) {
    if (actual_candidates_lens[bid] == 0) {
      actual_candidates_lens[bid] = candidates_len;
    }
  }
}

template <typename T>
struct Pair {
  __device__ __forceinline__ Pair() {}
  __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}

  __device__ __forceinline__ void set(T value, int id) {
    this->v = value;
    this->id = id;
  }

  __device__ __forceinline__ void operator=(const Pair<T>& in) {
    v = in.v;
    id = in.id;
  }

  __device__ __forceinline__ bool operator<(const T value) const {
    return (static_cast<float>(v) < static_cast<float>(value));
  }

  __device__ __forceinline__ bool operator>(const T value) const {
    return (static_cast<float>(v) > static_cast<float>(value));
  }

  // Ties on the value are broken by preferring the smaller id.
  __device__ __forceinline__ bool operator<(const Pair<T>& in) const {
    return (static_cast<float>(v) < static_cast<float>(in.v)) ||
           ((static_cast<float>(v) == static_cast<float>(in.v)) &&
            (id > in.id));
  }

  __device__ __forceinline__ bool operator>(const Pair<T>& in) const {
    return (static_cast<float>(v) > static_cast<float>(in.v)) ||
           ((static_cast<float>(v) == static_cast<float>(in.v)) &&
            (id < in.id));
  }

  T v;
  int id;
};
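
// AddTo/GetTopK/ThreadGetTopK maintain, per thread, a small descending-sorted
// array of the best `MaxLength` (value, id) pairs seen so far. AddTo is one
// insertion step into that array; GetTopK strides a thread across a row at
// BlockSize intervals; ThreadGetTopK refills the per-thread array after
// `*beam` entries were consumed by a block-level reduction round.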
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[],
                                      const Pair<T>& p,
                                      int beam_size) {
  for (int k = beam_size - 2; k >= 0; k--) {
    if (topk[k] < p) {
      topk[k + 1] = topk[k];
    } else {
      topk[k + 1] = p;
      return;
    }
  }
  topk[0] = p;
}

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(
    Pair<T> topk[], const T* src, int idx, int dim, int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      AddTo<T>(topk, tmp, beam_size);
    }
    idx += BlockSize;
  }
}

template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[],
                                        const T* src,
                                        int idx,
                                        int dim,
                                        const Pair<T>& max,
                                        int beam_size) {
  while (idx < dim) {
    if (topk[beam_size - 1] < src[idx]) {
      Pair<T> tmp(src[idx], idx);
      // Only pairs strictly smaller than the current block-wide max are
      // candidates; everything >= max was already emitted.
      if (tmp < max) {
        AddTo<T>(topk, tmp, beam_size);
      }
    }
    idx += BlockSize;
  }
}

template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[],
                                              int* beam,
                                              int beam_size,
                                              const T* src,
                                              bool* firstStep,
                                              bool* is_empty,
                                              Pair<T>* max,
                                              int dim,
                                              const int tid) {
  if (*beam > 0) {
    int length = (*beam) < beam_size ? *beam : beam_size;
    if (*firstStep) {
      *firstStep = false;
      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
    } else {
      for (int k = 0; k < MaxLength; k++) {
        if (k < MaxLength - (*beam)) {
          topk[k] = topk[k + *beam];
        } else {
          topk[k].set(std::numeric_limits<T>::min(), -1);
        }
      }
      if (!(*is_empty)) {
        GetTopK<T, BlockSize>(
            topk + MaxLength - *beam, src, tid, dim, *max, length);
      }
    }

    *max = topk[MaxLength - 1];
    if ((*max).id == -1) *is_empty = true;
    *beam = 0;
  }
}

template <typename T>
__forceinline__ __device__ Pair<T> WarpReduce(Pair<T> input) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    T tmp_val = CudaShuffleDownSync(FINAL_MASK, input.v, offset);
    int tmp_id = CudaShuffleDownSync(FINAL_MASK, input.id, offset);
    if (static_cast<float>(input.v) < static_cast<float>(tmp_val)) {
      input.v = tmp_val;
      input.id = tmp_id;
    }
  }
  return input;
}

template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T> shared_max[],
                                            Pair<T> topk[],
                                            Pair<T> beam_max[],
                                            int* beam,
                                            int* k,
                                            int* count,
                                            const int tid,
                                            const int wid,
                                            const int lane) {
  while (true) {
    __syncthreads();
    Pair<T> input_now = topk[0];
    input_now = WarpReduce(input_now);

    if (lane == 0) {
      shared_max[wid] = input_now;
    }
    __syncthreads();
    input_now = (tid < BlockSize / 32)
                    ? shared_max[lane]
                    : Pair<T>(std::numeric_limits<T>::min(), -1);
    if (wid == 0) {
      input_now = WarpReduce(input_now);
      if (lane == 0) shared_max[0] = input_now;
    }
    __syncthreads();
    if (tid == 0) {
      beam_max[*count] = shared_max[0];
      (*count)++;
    }
    int tid_max = shared_max[0].id % BlockSize;
    if (tid == tid_max) {
      (*beam)++;
    }
    if (--(*k) == 0) break;
    __syncthreads();

    if (tid == tid_max) {
      if (*beam < MaxLength) {
        topk[0] = topk[*beam];
      }
    }

    if (MaxLength < 5) {
      if (*beam >= MaxLength) break;
    } else {
      // The ballot result is unused; the call re-converges the warp before
      // the shuffle below.
      __ballot_sync(FINAL_MASK, true);
      if (tid_max / 32 == wid) {
        if (__shfl_down_sync(FINAL_MASK, *beam, tid_max % 32, 32) ==
            MaxLength)
          break;
      }
    }
  }
}

template <typename T, int MaxLength, int TopPBeamTopK, int BlockSize>
__global__ void KeMatrixTopPBeamTopKFt(
    const T* src,
    const T* top_ps,
    const int* output_padding_offset,
    int64_t* out_id,  // [token_num, max_candidate_len]
    T* out_val,       // [token_num, max_candidate_len]
    int* actual_candidates_lens,
    const int vocab_size,
    const int max_candidate_len,
    const int max_seq_len) {
  const int tid = threadIdx.x;
  const int wid = tid / 32;
  const int lane = tid % 32;
  const int token_id = blockIdx.x;
  // Undo the padding removal to recover which batch row this token belongs to.
  const int ori_token_id = token_id + output_padding_offset[token_id];
  const int bid = ori_token_id / max_seq_len;

  int top_num = TopPBeamTopK;
  float top_p_value = static_cast<float>(top_ps[bid]);

  __shared__ Pair<T> shared_max[BlockSize / 32];
  __shared__ Pair<T> beam_max[TopPBeamTopK];

  Pair<T> topk[MaxLength];
  int beam = MaxLength;
  Pair<T> max;
  bool is_empty = false;
  bool firststep = true;
  __shared__ int count;

  if (tid == 0) {
    count = 0;
  }

  for (int j = 0; j < MaxLength; j++) {
    topk[j].set(std::numeric_limits<T>::min(), -1);
  }

  while (top_num) {
    ThreadGetTopK<T, MaxLength, BlockSize>(topk,
                                           &beam,
                                           TopPBeamTopK,
                                           src + token_id * vocab_size,
                                           &firststep,
                                           &is_empty,
                                           &max,
                                           vocab_size,
                                           tid);
    BlockReduce<T, MaxLength, BlockSize>(
        shared_max, topk, beam_max, &beam, &top_num, &count, tid, wid, lane);
  }

  if (tid == 0) {
    float sum_prob = 0.0f;
    for (int i = 0; i < TopPBeamTopK; i++) {
      out_id[token_id * max_candidate_len + i] =
          static_cast<int64_t>(beam_max[i].id);
      out_val[token_id * max_candidate_len + i] = beam_max[i].v;
      sum_prob += static_cast<float>(beam_max[i].v);
      if (sum_prob >= top_p_value) {
        actual_candidates_lens[token_id] = i + 1;
        break;
      }
    }
  }
}
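
// DispatchTopK maps the runtime `candidate_len` and vocab-size-derived block
// size onto the compile-time template parameters expected by
// KeMatrixTopPBeamTopKFt, via the FIXED_TOPK / FIXED_BLOCK_DIM switch macros
// defined at the top of this file.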
template <typename T, int TopKMaxLength>
void DispatchTopK(const T* src,
                  const T* top_ps,
                  const int* output_padding_offset,
                  int64_t* out_id,  // top-k ids
                  T* out_val,       // top-k values
                  int* actual_candidates_lens_data,
                  const int vocab_size,
                  const int token_num,
                  const int candidate_len,
                  const int max_seq_len,
                  const cudaStream_t& stream) {
  int BlockSize = GetBlockSize(vocab_size);
  switch (candidate_len) {
    FIXED_TOPK(switch (BlockSize) {
      FIXED_BLOCK_DIM(
          KeMatrixTopPBeamTopKFt<T, TopKMaxLength, kTopK, kBlockDim>
          <<<token_num, kBlockDim, 0, stream>>>(src,
                                                top_ps,
                                                output_padding_offset,
                                                out_id,
                                                out_val,
                                                actual_candidates_lens_data,
                                                vocab_size,
                                                candidate_len,
                                                max_seq_len));
      default:
        PD_THROW(
            "Unsupported block size in the topp_beam_topk kernel; the "
            "vocab-size-derived block dim must be one of "
            "32/64/128/256/512/1024.");
    });
    default:
      PD_THROW(
          "Unsupported candidates_len in the topp_beam_topk kernel; "
          "supported values are 2, 3, 4, 5, 8 and 10.");
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> LaunchTopPCandidates(
    const paddle::Tensor& probs,  // [token_num, vocab_size]
    const paddle::Tensor& top_p,  // [token_num]
    const paddle::Tensor& output_padding_offset,
    const int candidates_len,
    const int max_seq_len) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  std::vector<int64_t> input_shape = probs.shape();
  const int token_num = input_shape[0];
  const int vocab_size = input_shape[1];

  auto verify_scores =
      paddle::full({token_num, candidates_len}, 0, D, probs.place());
  auto verify_tokens = paddle::full(
      {token_num, candidates_len}, 0, paddle::DataType::INT64, probs.place());
  auto actual_candidate_lens =
      paddle::full({token_num}, 0, paddle::DataType::INT32, probs.place());

  auto stream = probs.stream();

  constexpr int TopKMaxLength = 2;
  DispatchTopK<DataType_, TopKMaxLength>(
      reinterpret_cast<const DataType_*>(probs.data<data_t>()),
      reinterpret_cast<const DataType_*>(top_p.data<data_t>()),
      output_padding_offset.data<int>(),
      verify_tokens.data<int64_t>(),
      reinterpret_cast<DataType_*>(verify_scores.data<data_t>()),
      actual_candidate_lens.data<int>(),
      vocab_size,
      token_num,
      candidates_len,
      max_seq_len,
      stream);

  return {verify_scores, verify_tokens, actual_candidate_lens};
}

std::vector<paddle::Tensor> DispatchTopPCandidatesWithDtype(
    const paddle::Tensor& probs,
    const paddle::Tensor& top_p,
    const paddle::Tensor& output_padding_offset,
    int candidates_len,
    int max_seq_len) {
  switch (probs.type()) {
    case paddle::DataType::BFLOAT16:
      return LaunchTopPCandidates<paddle::DataType::BFLOAT16>(
          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
    case paddle::DataType::FLOAT16:
      return LaunchTopPCandidates<paddle::DataType::FLOAT16>(
          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
    case paddle::DataType::FLOAT32:
      return LaunchTopPCandidates<paddle::DataType::FLOAT32>(
          probs, top_p, output_padding_offset, candidates_len, max_seq_len);
    default:
      PD_THROW(
          "Unsupported data type: only bfloat16, float16 and float32 are "
          "supported.");
  }
}
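
// Custom-op entry point. For each token row of `probs`, emits the top
// `candidates_len` (score, token) pairs and the number of candidates whose
// cumulative probability first reaches that token's top-p threshold.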
"); break; } } std::vector TopPCandidates( const paddle::Tensor& probs, const paddle::Tensor& top_p, const paddle::Tensor& output_padding_offset, int candidates_len, int max_seq_len) { return DispatchTopPCandidatesWithDtype( probs, top_p, output_padding_offset, candidates_len, max_seq_len); } std::vector> TopPCandidatesInferShape( const std::vector& probs_shape, const std::vector& top_p_shape, const std::vector& output_padding_offset_shape, int max_candidates_len) { int token_num = probs_shape[0]; return {{token_num, max_candidates_len}, {token_num, max_candidates_len}, {token_num}}; } std::vector TopPCandidatesInferDtype( const paddle::DataType& probs_dtype, const paddle::DataType& top_p_dtype, const paddle::DataType& output_padding_offset_dtype) { return {probs_dtype, paddle::DataType::INT64, paddle::DataType::INT32}; } PD_BUILD_STATIC_OP(top_p_candidates) .Inputs({"probs", "top_p", "output_padding_offset"}) .Outputs({"verify_scores", "verify_tokens", "actual_candidate_lens"}) .Attrs({"candidates_len: int", "max_seq_len: int"}) .SetKernelFn(PD_KERNEL(TopPCandidates)) .SetInferShapeFn(PD_INFER_SHAPE(TopPCandidatesInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(TopPCandidatesInferDtype));