// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef __NVCC__
#include
#endif
#ifdef __HIPCC__
#include
namespace cub = hipcub;
#endif
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "stdint.h"
#include "helper.h"

#define FLT_MAX 1e38

static constexpr int kBlockSizeForSmallBeamWidth = 256;
static constexpr int kMaxVocabPartForStage1FastKernel = 128;

// Switch-case helper: dispatch the full top-k/softmax pipeline for one
// compile-time beam-group size.
#define CASE_K(K)                        \
  case K:                                \
    invokeTopKSoftMaxLauncher(           \
        params, beam_group_idx, stream); \
    break

// Switch-case helper: dispatch ComputeVocParts for one compile-time size.
#define DISPATCH_COMPUTE_PARTS_K(K) \
  case K:                           \
    ComputeVocParts(params);        \
    break

// All inputs, outputs and workspace pointers shared by the beam-search
// kernels. Shape annotations are [rows, cols]; BS = batch size,
// BM = beam width.
template
struct BeamSearchParams {
  // Scalar values
  int batch_size{0};
  int beam_width{0};
  int beam_group_size{0};
  int beam_group_idx{0};
  int vocab_size{0};
  int dec_stride{0};
  int max_seq_len{0};
  int end_ids_len{0};
  bool fuse_softmax{true};
  bool early_stop{false};
  int voc_parts{0};
  bool use_fast_kernel{true};
  int max_smem_per_block{0};

  T *logits{nullptr};
  const int *step_ids{nullptr};  // [BS * BM, 1]
  const int *seq_lens{nullptr};  // [BS * BM, 1]
  const int *max_dec_lens{nullptr};
  const int *end_ids{nullptr};
  const T *cum_scores{nullptr};
  const int *block_tables{nullptr};
  const int *beam_cache_ids{nullptr};
  const float *length_penalty{nullptr};     // [BS, 1]
  const float *diversity_penalty{nullptr};  // [BS, 1]
  bool *stop_flags{nullptr};                // [BS, 1]
  int *cache_ids_out{nullptr};              // [BS * BM, max_dec_len]
  bool *beam_finished{nullptr};             // [BS * BM, 1]
  int *block_tables_out{nullptr};           // [BS * BM, max_seq_len]
  T *cum_scores_out{nullptr};               // [BS * BM, 1]
  int *beam_hyps_out{nullptr};              // [BS * BM, max_dec_len]
  T *beam_hyps_score_out{nullptr};          // [BS * BM, 1]

  // func out
  int *next_tokens{nullptr};
  int *parent_ids{nullptr};

  // workspace
  int *tmp_ids{nullptr};
  T *tmp_vals{nullptr};
  T *tmp_buffer{nullptr};
};

// Integer ceiling division: smallest q with q * denominator >= numerator.
// Intended for non-negative numerator and positive denominator.
template ::value>,
          typename = std::enable_if_t::value>>
auto constexpr ceilDiv(T numerator, U denominator) {
  return (numerator + denominator - 1) / denominator;
}

// Returns true if `id` matches any of the `length` entries in `end_ids`.
__device__ bool is_in_end(const int id, const int *end_ids, int length) {
  for (int i = 0; i < length; i++) {
    if (id == end_ids[i]) {
      return true;
    }
  }
  return false;
}

// Length-normalised score: log_prob / length^length_penalty.
// The division is skipped when the penalty is 0 or the length is 1, because
// the divisor would be 1 anyway.
template
__device__ __forceinline__ T apply_length_penalty(T log_prob,
                                                  int length,
                                                  float length_penalty) {
  // score = log(prob) / (length)^length_penalty.
  if (length_penalty == 0.0f || length == 1) {
    return log_prob;
  }
  return log_prob / static_cast(powf(length, length_penalty));
}

// Group (diverse) beam search: for every beam of the current group,
// subtract diversity_penalty[batch] from the logit of each token that an
// unfinished beam of a *previous* group already chose (next_tokens).
// Launch layout (from the indexing below): blockIdx.x = batch index,
// threadIdx.x = beam index inside the current group.
template
__global__ void apply_group_diversity_penalty(BeamSearchParams params,
                                              const int batch_size,
                                              const int beam_width,
                                              const int beam_group_idx,
                                              const int vocab_size) {
  const int beam_group_size = K / 2;
  const int batch_idx = blockIdx.x;
  const int beam_group_sub_idx = threadIdx.x;
  const bool *beam_finished = params.beam_finished + batch_idx * beam_width;
  // Logits row of this thread's beam.
  T *logits = params.logits + batch_idx * beam_width * vocab_size +
              beam_group_idx * beam_group_size * vocab_size +
              beam_group_sub_idx * vocab_size;
  int *next_tokens = params.next_tokens + batch_idx * beam_width;
  // apply previous group token ids penalty
#pragma unroll
  for (int token_idx = 0; token_idx < beam_group_idx * beam_group_size;
       ++token_idx) {
    const bool finished = beam_finished[token_idx];
    if (!finished) {
      const int token_id = next_tokens[token_idx];
      logits[token_id] -= params.diversity_penalty[batch_idx];
    }
  }
}

// Running online-softmax state: `logit` is the max seen so far and `score`
// is sum(exp(x - logit)) over the elements seen so far.
struct DySoftMaxStruct {
  float logit;
  float score;
};

// Merge two online-softmax states by rescaling the smaller-max partial sum
// to the larger max.
__device__ __forceinline__ DySoftMaxStruct
reduce_softmax_op(DySoftMaxStruct a, DySoftMaxStruct b) {
  bool a_bigger = (a.logit > b.logit);
  DySoftMaxStruct bigger_m = a_bigger ? a : b;
  DySoftMaxStruct smaller_m = a_bigger ? b : a;
  DySoftMaxStruct res;
  res.score = bigger_m.score +
              smaller_m.score * expf(smaller_m.logit - bigger_m.logit);
  res.logit = bigger_m.logit;
  return res;
}

// One finished hypothesis: its score and a view onto its token row.
template
struct BeamHypothesis {
  T score;
  int *seq;
  int seq_len;
  __device__ __forceinline__ void init(int *_seq,
                                       T _score,
                                       const int _max_seq_len) {
    seq = _seq;
    score = _score;
    seq_len = _max_seq_len;
  }
};

// Fixed-size list of the best K finished hypotheses, kept sorted by score in
// descending order (hyps[K - 1] is the worst). Scores live in registers;
// token sequences are written in place through each hypothesis' `seq`
// pointer.
template
struct BeamHypothesesTopK {
  BeamHypothesis hyps[K];
  int max_dec_len;

  __device__ __forceinline__ void init(int *_beam_hyps,
                                       T *_beam_hyps_score,
                                       const int _max_dec_len) {
    max_dec_len = _max_dec_len;
    for (int i = 0; i < K; i++) {
      // Bind hypothesis i to its row of the output buffer and load its
      // current score.
      hyps[i].init(
          _beam_hyps + i * _max_dec_len, _beam_hyps_score[i], _max_dec_len);
    }
  }

  // Insert the sequence token_ids[0..step) + cur_token_id with `score`, if
  // it beats the current worst entry, then bubble it into sorted position.
  __device__ void insert(const int *token_ids,
                         int step,
                         int cur_token_id,
                         T score) {
    if (score > get_worst_score()) {
      for (int i = 0; i < step; i++) {
        hyps[K - 1].seq[i] = token_ids[i];
      }
      hyps[K - 1].seq[step] = cur_token_id;
      hyps[K - 1].score = score;
      for (int k = K - 2; k >= 0; --k) {
        if (hyps[k + 1].score > hyps[k].score) {
          T tmp_score = hyps[k].score;
          hyps[k].score = hyps[k + 1].score;
          hyps[k + 1].score = tmp_score;
          int tmp_val;
          // NOTE(review): the copy stops at the first position where both
          // rows hold a token <= 0, so a sequence that contains token id 0
          // is only partially swapped. Confirm that 0 can never appear
          // inside a real hypothesis.
          for (int i = 0; i <= step &&
                          (hyps[k + 1].seq[i] > 0 || hyps[k].seq[i] > 0);
               i++) {
            tmp_val = hyps[k + 1].seq[i];
            hyps[k + 1].seq[i] = hyps[k].seq[i];
            hyps[k].seq[i] = tmp_val;
          }
        }
      }
    }
  }

  __device__ __forceinline__ T get_worst_score() { return hyps[K - 1].score; }
};

// Register-resident sorted top-K (descending value; ties broken by smaller
// id; an id of -1 marks an empty slot).
template
struct TopK {
  int ids[K];
  T vals[K];
  int parent_ids[K];

  __device__ __forceinline__ void insert(T elem, int elem_id) {
    if (elem > vals[K - 1] || (ids[K - 1] == -1) ||
        ((elem == vals[K - 1]) && (elem_id < ids[K - 1]))) {
      vals[K - 1] = elem;
      ids[K - 1] = elem_id;
    }
    for (int k = K - 2; k >= 0; --k) {
      if ((vals[k + 1] > vals[k]) || (ids[k] == -1) ||
          ((vals[k + 1] == vals[k]) && (ids[k + 1] < ids[k]))) {
        T tmp_val = vals[k];
        int tmp_id = ids[k];
        vals[k] = vals[k + 1];
        ids[k] = ids[k + 1];
        vals[k + 1] = tmp_val;
        ids[k + 1] = tmp_id;
      }
    }
  }

  // Same as above, but also carries the index of the source beam.
  __device__ __forceinline__ void insert(T elem, int elem_id, int parent_id) {
    if (elem > vals[K - 1] || (ids[K - 1] == -1) ||
        ((elem == vals[K - 1]) && (elem_id < ids[K - 1]))) {
      vals[K - 1] = elem;
      ids[K - 1] = elem_id;
      parent_ids[K - 1] = parent_id;
    }
    for (int k = K - 2; k >= 0; --k) {
      if ((vals[k + 1] > vals[k]) || (ids[k] == -1) ||
          ((vals[k + 1] == vals[k]) && (ids[k + 1] < ids[k]))) {
        T tmp_val = vals[k];
        int tmp_id = ids[k];
        int parent_id2 = parent_ids[k];
        vals[k] = vals[k + 1];
        ids[k] = ids[k + 1];
        parent_ids[k] = parent_ids[k + 1];
        vals[k + 1] = tmp_val;
        ids[k + 1] = tmp_id;
        parent_ids[k + 1] = parent_id2;
      }
    }
  }
};

// Merge two top-K lists (reduction operator for cub::BlockReduce).
template
__device__ __forceinline__ TopK reduce_topk_op(const TopK &a,
                                               const TopK &b) {
  TopK res = a;
  for (int i = 0; i < K; ++i) res.insert(b.vals[i], b.ids[i]);
  return res;
}

// Top-K list paired with an online-softmax state.
template
struct TopKSoftMax {
  DySoftMaxStruct softmax_md;
  TopK topk;
};

// Merge two TopKSoftMax partials (reduction operator for cub::BlockReduce).
template
__device__ __forceinline__ TopKSoftMax reduce_topk_softmax_op(
    const TopKSoftMax &a, const TopKSoftMax &b) {
  TopKSoftMax res;
  // max_logit in block
  res.softmax_md = reduce_softmax_op(a.softmax_md, b.softmax_md);
  res.topk = reduce_topk_op(a.topk, b.topk);
  return res;
}

// Online-softmax state: m = running max, d = sum(exp(x - m)).
struct __align__(8) MD {
  float m;
  float d;
};

// Merge two MD partials by rescaling the smaller-max sum to the larger max.
__device__ __forceinline__ MD reduce_md_op(MD a, MD b) {
  bool const isABigger = a.m > b.m;
  MD const bigger = isABigger ? a : b;
  MD const smaller = isABigger ? b : a;
  MD res{bigger.m, bigger.d + smaller.d * __expf(smaller.m - bigger.m)};
  return res;
}

// Stage 1 (fast path): for one beam and one vocab chunk, find the top
// 2 * beam_group_size logits and the chunk's softmax (max, sum-exp) state.
// Launch layout: blockIdx.x = batch * beam_group_size + beam-in-group,
// blockIdx.y = vocab chunk. The launcher passes sizeof(T) * vocab_chunk_size
// bytes of dynamic shared memory, used to stage the chunk's logits.
// Output per (blockIdx.x, blockIdx.y): PACKED_TOP_KMD_SIZE floats in
// tmp_buffer laid out as [ids (int bits) | values | d | m].
// Beams with seq_len < 0 or already finished are skipped.
template
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__
    void beam_search_softmax_topk_stage1_fast(const T *logits,
                                              float *tmp_buffer,
                                              const int *end_ids,
                                              const bool *beam_finished,
                                              const int *seq_lens,
                                              int beam_width,
                                              int beam_group_idx,
                                              int vocab_size,
                                              int vocab_chunk_size) {
  constexpr int PACKED_TOP_KMD_SIZE = 2 * K + 2;
  const int beam_group_size = K / 2;
  const int tid = threadIdx.x;
  const int group_beam_batch_id = blockIdx.x;
  const int batch_id = group_beam_batch_id / beam_group_size;
  const int beam_group_sub_id = group_beam_batch_id % beam_group_size;
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_id;

  const int seq_len = seq_lens[beam_batch_id];
  const bool finished = beam_finished[beam_batch_id];

  if (seq_len < 0 || finished) {
    return;
  }

  const int section_start = vocab_chunk_size * blockIdx.y;
  const int section_end =
      std::min(section_start + vocab_chunk_size, vocab_size);
  const int valid_smem_length = section_end - section_start;
  T const MAX_T_VAL = 1e38;

  // Copy this chunk's logits into shared memory. While doing so, each
  // thread accumulates its own softmax (max, sum-exp) partial and its own
  // argmax over the elements it touches (a stride of THREADBLOCK_SIZE).
  extern __shared__ char smem[];
  T *smemLogProbs = reinterpret_cast(smem);

  MD partial_md{-MAX_T_VAL, 0.0f};

  using KVPair = cub::KeyValuePair;
  KVPair topKVPairPartial{vocab_size - 1, -MAX_T_VAL};
  cub::ArgMax argmax;

  T const *local_logits = logits + beam_batch_id * vocab_size;
#pragma unroll 1
  for (int i = section_start + tid; i < section_end; i += THREADBLOCK_SIZE) {
    T const val = local_logits[i];
    const int smem_index = i - section_start;
    smemLogProbs[smem_index] = val;
    MD new_elem_md{val, 1.0F};
    partial_md = reduce_md_op(partial_md, new_elem_md);
    KVPair new_elem_topk{smem_index, val};
    topKVPairPartial = argmax(topKVPairPartial, new_elem_topk);
  }
  __syncthreads();

  // Search the top 2K elements among `vocab_chunk_size` elements of this
  // ThreadBlock and write into smemOutput
  __shared__ float smemOutput[PACKED_TOP_KMD_SIZE];
  __shared__ int threadToUpdate;
  using BlockReduceMD = cub::BlockReduce;
  using BlockReduceTopK = cub::BlockReduce;
  __shared__ union {
    typename BlockReduceTopK::TempStorage topk;
    typename BlockReduceMD::TempStorage md;
  } smemReduceBuffer;

  for (int i = 0; i < 2 * beam_group_size; ++i) {
    // Pop the element with largest value to "smemOutput" per iteration
    KVPair topKVPair = BlockReduceTopK(smemReduceBuffer.topk)
                           .Reduce(topKVPairPartial, argmax);
    if (tid == 0) {
      // Store the global vocab id of the popped element.
      const int index = section_start + topKVPair.key;
      reinterpret_cast(smemOutput)[i] = index;
      smemOutput[K + i] = topKVPair.value;
      // Pollute the popped element so later passes cannot select it again.
      smemLogProbs[topKVPair.key] = -MAX_T_VAL;
      threadToUpdate = topKVPair.key % THREADBLOCK_SIZE;
    }
    __syncthreads();

    if (tid == threadToUpdate && i < 2 * beam_group_size - 1) {
      // Only the thread that owned the popped element must rescan its
      // strided elements; no rescan is needed after the last pop.
      topKVPairPartial.key = vocab_size - 1;
      topKVPairPartial.value = -MAX_T_VAL;
      for (int index = tid; index < valid_smem_length;
           index += THREADBLOCK_SIZE) {
        topKVPairPartial =
            argmax(topKVPairPartial, {index, smemLogProbs[index]});
      }
    }
  }

  // Reduce the per-thread (max, sum-exp) partials over the whole chunk and
  // store them in the tail of smemOutput.
  auto reduce_md_func = [](const MD &a, const MD &b) {
    return reduce_md_op(a, b);
  };
  MD total_md =
      BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_func);
  if (tid == 0) {
    smemOutput[2 * K] = total_md.d;
    smemOutput[2 * K + 1] = total_md.m;
  }
  __syncthreads();

  // Write the smemOutput into tmp_buffer
  float *local_temp_buffer =
      tmp_buffer + group_beam_batch_id * PACKED_TOP_KMD_SIZE * gridDim.y +
      blockIdx.y * PACKED_TOP_KMD_SIZE;
#pragma unroll
  for (int i = tid; i < PACKED_TOP_KMD_SIZE; i += THREADBLOCK_SIZE) {
    local_temp_buffer[i] = smemOutput[i];
  }
}

// Stage 1 (generic path): per (beam, vocab part), compute a register top-K
// and, when fuse_softmax is set, the softmax (max, sum-exp) state. The
// packed results go to params.tmp_buffer at
// [blockIdx.x][blockIdx.y][PACKED_TOP_KMD_SIZE].
// <<<(batch_size * beam_group_size, voc_parts), 128>>>
template
__global__ void beam_search_softmax_topk_stage1(BeamSearchParams params,
                                                const int beam_width,
                                                const int beam_group_idx,
                                                const int vocab_size,
                                                const bool fuse_softmax) {
  const int thread_id = threadIdx.x;
  const int beam_group_size = K / 2;
  const int batch_id = blockIdx.x / beam_group_size;
  const int beam_group_sub_idx = blockIdx.x % beam_group_size;
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_idx;

  const bool finish = params.beam_finished[beam_batch_id];
  const int seq_len = params.seq_lens[beam_batch_id];

  // for dybatch
  if (seq_len < 0 || finish) {
    return;
  }

  // 2 * K + 2 when fuse_softmax (ids | values | d | m)
  __shared__ float buf_s[PACKED_TOP_KMD_SIZE];
  const T MAX_T_VAL = FLT_MAX;

  // Vocab range [section_start, section_end) handled by this block.
  const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y;
  const int section_start = v_local * blockIdx.y;
  int section_end = section_start + v_local;
  section_end = (section_end > vocab_size) ? vocab_size : section_end;

  T *logits = params.logits + beam_batch_id * vocab_size;
  if (fuse_softmax) {
    typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    TopKSoftMax partial;
    for (int i = 0; i < K; ++i) {
      partial.topk.ids[i] = -1;
      partial.topk.vals[i] = -MAX_T_VAL;
    }
    partial.softmax_md.logit = -MAX_T_VAL;
    partial.softmax_md.score = 0.0F;

    // process voc_parts
#pragma unroll 1
    for (int elem_id = section_start + thread_id; elem_id < section_end;
         elem_id += THREADBLOCK_SIZE) {
      T elem = logits[elem_id];
      DySoftMaxStruct new_elem{elem, 1.0F};
      partial.softmax_md = reduce_softmax_op(partial.softmax_md, new_elem);
      partial.topk.insert(elem, elem_id);
    }

    // reduce voc_parts
    TopKSoftMax total = BlockReduce(temp_storage)
                            .Reduce(partial, reduce_topk_softmax_op);

    if (thread_id == 0) {
      for (int i = 0; i < K; i++) {
        reinterpret_cast(buf_s)[i] = total.topk.ids[i];
        buf_s[K + i] = total.topk.vals[i];
      }
      buf_s[2 * K] = total.softmax_md.score;
      buf_s[2 * K + 1] = total.softmax_md.logit;
    }
  } else {
    typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    TopK partial;
    for (int i = 0; i < K; ++i) {
      partial.ids[i] = -1;
      partial.vals[i] = -MAX_T_VAL;
    }

#pragma unroll 1
    for (int elem_id = section_start + thread_id; elem_id < section_end;
         elem_id += THREADBLOCK_SIZE) {
      T elem = logits[elem_id];
      partial.insert(elem, elem_id);
    }

    TopK total =
        BlockReduce(temp_storage).Reduce(partial, reduce_topk_op);

    if (thread_id == 0) {
      for (int i = 0; i < K; i++) {
        reinterpret_cast(buf_s)[i] = total.ids[i];
        buf_s[K + i] = total.vals[i];
      }
    }
  }
  __syncthreads();
  // write all the voc_parts results to tmp_buffer
  for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE;
       elem_id += THREADBLOCK_SIZE) {
    params.tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y +
                      blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] =
        buf_s[elem_id];
  }
}

// Stage 2 (fast path): merge the voc_parts packed stage-1 results of one
// beam. Writes the K best ids to tmp_ids and their log-softmax values plus
// cum_scores[beam] to tmp_vals, both at [blockIdx.x * K, +K).
// Each thread tid < voc_parts handles one part, so THREADBLOCK_SIZE must be
// >= voc_parts. When IS_FAST_KERNEL is set, the packed results are first
// copied into dynamic shared memory (PACKED_TOP_KMD_SIZE * voc_parts
// floats); otherwise they are read from, and invalidated in, global
// tmp_buffer.
template
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void beam_search_softmax_topk_stage2_fast(
        int *__restrict tmp_ids,
        T *__restrict tmp_vals,
        float *__restrict tmp_buffer,
        const float *__restrict cum_scores,
        const bool *__restrict beam_finished,
        const int *__restrict seq_lens,
        const int beam_width,
        const int beam_group_idx,
        const int vocab_size,
        const int voc_parts) {
  constexpr int PACKED_TOP_KMD_SIZE = 2 * K + 2;
  constexpr int beam_group_size = K / 2;
  const int group_beam_batch_id = blockIdx.x;
  const int beam_group_sub_id = blockIdx.x % beam_group_size;
  const int batch_id = group_beam_batch_id / beam_group_size;
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_id;
  if (seq_lens[beam_batch_id] < 0 || beam_finished[beam_batch_id]) {
    return;
  }
  const int tid = threadIdx.x;
  T const MAX_T_VAL = FLT_MAX;

  using KVPair = cub::KeyValuePair;
  using BlockReduceTopK = cub::BlockReduce;
  using BlockReduceMD = cub::BlockReduce;

  __shared__ KVPair buf_smem_kv[K];

  __shared__ union {
    typename BlockReduceTopK::TempStorage topk;
    typename BlockReduceMD::TempStorage md;
  } smemReduceBuffer;

  cub::ArgMax argmax;
  MD partial_md{-MAX_T_VAL, 0.0f};

  auto reduce_md_func = [](const MD &a, const MD &b) {
    return reduce_md_op(a, b);
  };

  // Load and unpack into registers through smem
  float *localTempBuffer =
      tmp_buffer + PACKED_TOP_KMD_SIZE * group_beam_batch_id * voc_parts;
  if constexpr (IS_FAST_KERNEL) {
    // Use shared memory instead of global memory
    extern __shared__ char smem[];
    float *smemVal = reinterpret_cast(smem);
    for (int idx = tid; idx < PACKED_TOP_KMD_SIZE * voc_parts;
         idx += THREADBLOCK_SIZE) {
      smemVal[idx] = localTempBuffer[idx];
    }
    localTempBuffer = smemVal;
    __syncthreads();
  }

  // Pop the K best candidates across all voc_parts, one per iteration.
  for (int k = 0; k < K; ++k) {
    KVPair topKVPairPartial{vocab_size - 1, -MAX_T_VAL};
    // Only threads responsible for a chunk will do the computation
    if (tid < voc_parts) {
      for (int i = 0; i < K; ++i) {
        const int current_index = tid * PACKED_TOP_KMD_SIZE + i;
        T topValue = localTempBuffer[current_index + K];
        topKVPairPartial =
            argmax(topKVPairPartial, {current_index, topValue});
      }
    }

    KVPair topKVPair = BlockReduceTopK(smemReduceBuffer.topk)
                           .Reduce(topKVPairPartial, argmax);
    __syncthreads();

    if (tid == 0) {
      // Translate the packed-buffer offset into the vocab id stored there
      // and keep the pair.
      int temp_offset = topKVPair.key;
      int global_offset = reinterpret_cast(localTempBuffer)[temp_offset];
      topKVPair.key = global_offset;
      buf_smem_kv[k] = topKVPair;

      // Invalidate the popped entry so it is not selected again.
      reinterpret_cast(localTempBuffer)[temp_offset] =
          vocab_size - 1;                               // id slot
      localTempBuffer[temp_offset + K] = -MAX_T_VAL;  // value slot
    }
    __syncthreads();
  }

  // Extract and reduce MD values across the chunks
  if (tid < voc_parts) {
    partial_md.d = localTempBuffer[tid * PACKED_TOP_KMD_SIZE + 2 * K];
    partial_md.m = localTempBuffer[tid * PACKED_TOP_KMD_SIZE + 2 * K + 1];
  }
  __syncthreads();

  MD total_md =
      BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_func);

  if (tid == 0) {
    // log_softmax(x) = x - m - log(d)
    float d_total_log = logf(total_md.d);

    for (int i = 0; i < K; ++i) {
      float val = static_cast(buf_smem_kv[i].value) - total_md.m -
                  d_total_log;
      tmp_ids[group_beam_batch_id * K + i] = buf_smem_kv[i].key;
      tmp_vals[group_beam_batch_id * K + i] = val + cum_scores[beam_batch_id];
    }
  }
}

// Launches one stage-2 instantiation and returns from the enclosing
// function. Requires nShareMemory, params, batch_size, beam_group_size,
// beam_width, beam_group_idx, vocab_size, voc_parts and stream in scope.
#define BEAM_STAGE2_KERNEL(N_VOCAB_PART, IS_FAST_KERNEL)               \
  do {                                                                 \
    if (IS_FAST_KERNEL && nShareMemory >= (48 << 10)) {                \
      cudaFuncSetAttribute(beam_search_softmax_topk_stage2_fast,       \
                           cudaFuncAttributeMaxDynamicSharedMemorySize, \
                           nShareMemory);                              \
    }                                                                  \
    beam_search_softmax_topk_stage2_fast<<>>(params->tmp_ids,          \
                                             params->tmp_vals,         \
                                             params->tmp_buffer,       \
                                             params->cum_scores,       \
                                             params->beam_finished,    \
                                             params->seq_lens,         \
                                             beam_width,               \
                                             beam_group_idx,           \
                                             vocab_size,               \
                                             voc_parts);               \
  } while (0);                                                         \
  return;

// Picks the smallest stage-2 block size (32/64/128 threads) that covers
// voc_parts. It uses the shared-memory variant when the packed stage-1
// results fit below max_smem_per_block, and the global-memory variant
// otherwise.
template
__inline__ void beamSearchSoftmaxTopkStage2FastKernelLauncher(
    BeamSearchParams *params,
    const int batch_size,
    const int beam_width,
    const int beam_group_idx,
    const int vocab_size,
    const int voc_parts,
    const int max_smem_per_block,
    cudaStream_t stream) {
  constexpr int beam_group_size = K / 2;
  size_t const nShareMemory = sizeof(float) * voc_parts * (2 * K + 2) +
                              sizeof(cub::KeyValuePair) * K;
  // Guard against an unset/failed smem query (<= 0): comparing a negative
  // int with size_t would otherwise wrap to a huge limit.
  if (max_smem_per_block > 0 &&
      nShareMemory < static_cast<size_t>(max_smem_per_block)) {
    // IS_FAST_KERNEL must be a compile-time constant, hence one branch per
    // instantiation.
    if (voc_parts <= 32) {
      BEAM_STAGE2_KERNEL(32, true)
    }
    if (voc_parts <= 64) {
      BEAM_STAGE2_KERNEL(64, true)
    }
    BEAM_STAGE2_KERNEL(128, true)
    // No larger branch since voc_parts <= nMaxVocabPartForStage1FastKernel
  }
  BEAM_STAGE2_KERNEL(128, false)
}

// Stage 2 (generic path): merge the packed stage-1 results of one beam
// through dynamic shared memory (voc_parts * packed_top_kmd_size floats).
// Writes K ids to params.tmp_ids and K values to params.tmp_vals at offset
// blockIdx.x * K. The values are log-softmax + cum_score when fuse_softmax
// is set, and raw value + cum_score otherwise.
template
__global__ void beam_search_softmax_topk_stage2(BeamSearchParams params,
                                                const int beam_width,
                                                const int beam_group_idx,
                                                const int voc_parts,
                                                const int packed_top_kmd_size,
                                                const bool fuse_softmax) {
  const int thread_id = threadIdx.x;
  const int beam_group_size = K / 2;
  const int batch_id = blockIdx.x / beam_group_size;
  const int beam_group_sub_idx = blockIdx.x % beam_group_size;
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_idx;
  const int group_beam_batch_id = blockIdx.x;
  const int PACKED_TOP_KMD_SIZE = packed_top_kmd_size;

  // for dybatch
  const int seq_len = params.seq_lens[beam_batch_id];
  const bool finish = params.beam_finished[beam_batch_id];

  // Output slots of this beam; already offset by group_beam_batch_id * K.
  int *tmp_ids = params.tmp_ids + group_beam_batch_id * K;
  float *tmp_vals = params.tmp_vals + group_beam_batch_id * K;
  float *tmp_buffer = params.tmp_buffer;

  if (seq_len < 0 || finish) {
    return;
  }

  const T MAX_T_VAL = FLT_MAX;

  extern __shared__ char buf_s_[];
  float *buf_s = reinterpret_cast(buf_s_);
  // Advance to the packed results of all vocab parts of this batch beam.
  tmp_buffer += group_beam_batch_id * PACKED_TOP_KMD_SIZE * voc_parts;

  if (fuse_softmax) {
    typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    TopKSoftMax partial;
    for (int i = 0; i < K; ++i) {
      partial.topk.ids[i] = -1;
      partial.topk.vals[i] = -MAX_T_VAL;
    }
    partial.softmax_md.logit = -MAX_T_VAL;
    partial.softmax_md.score = 0.0F;

    for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts;
         idx += THREADBLOCK_SIZE) {
      buf_s[idx] = tmp_buffer[idx];
    }
    __syncthreads();

    if (threadIdx.x < voc_parts) {
      float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
      for (int i = 0; i < K; i++) {
        partial.topk.ids[i] = reinterpret_cast(b_s)[i];
        partial.topk.vals[i] = b_s[K + i];
      }
      partial.softmax_md.score = b_s[2 * K];
      partial.softmax_md.logit = b_s[2 * K + 1];
    }
    __syncthreads();

    TopKSoftMax total = BlockReduce(temp_storage)
                            .Reduce(partial, reduce_topk_softmax_op);

    if (thread_id == 0) {
      // log_softmax(x) = x - max - log(sum_exp)
      float d_total_log = logf(total.softmax_md.score);
      for (int i = 0; i < K; ++i) {
        float val = total.topk.vals[i] - total.softmax_md.logit - d_total_log;
        tmp_ids[i] = total.topk.ids[i];
        tmp_vals[i] = val + params.cum_scores[beam_batch_id];
      }
    }
  } else {
    typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    TopK partial;
    for (int i = 0; i < K; ++i) {
      partial.ids[i] = -1;
      partial.vals[i] = -MAX_T_VAL;
    }

    for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts;
         idx += THREADBLOCK_SIZE) {
      buf_s[idx] = tmp_buffer[idx];
    }
    __syncthreads();
    if (threadIdx.x < voc_parts) {
      float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
      for (int i = 0; i < K; i++) {
        partial.ids[i] = reinterpret_cast(b_s)[i];
        partial.vals[i] = b_s[K + i];
      }
    }
    __syncthreads();
    TopK total =
        BlockReduce(temp_storage).Reduce(partial, reduce_topk_op);

    if (thread_id == 0) {
      // tmp_ids / tmp_vals already point at this beam's slots; offsetting
      // them again here would write to slot 2 * group_beam_batch_id * K.
      for (int i = 0; i < K; ++i) {
        float val = total.vals[i];
        tmp_ids[i] = total.ids[i];
        tmp_vals[i] = val + params.cum_scores[beam_batch_id];
      }
    }
  }
}

// Host launcher for the generic stage-2 kernel. Block size is the smallest
// of 32/64/128/256 that is >= voc_parts.
// NOTE(review): nothing is launched when voc_parts > 256.
template
void invokeBeamSearchSoftmaxTopKStage2(BeamSearchParams *params,
                                       const int batch_size,
                                       const int beam_width,
                                       const int beam_group_idx,
                                       const int voc_parts,
                                       const int packed_top_kmd_size,
                                       const bool fuse_softmax,
                                       gpuStream_t stream) {
  int smem_stage2_size = voc_parts * packed_top_kmd_size * sizeof(float);
  const int beam_group_size = K / 2;
  if (voc_parts <= 32) {
    beam_search_softmax_topk_stage2<<>>(*params,
                                        beam_width,
                                        beam_group_idx,
                                        voc_parts,
                                        packed_top_kmd_size,
                                        fuse_softmax);
    return;
  }
  if (voc_parts <= 64) {
    beam_search_softmax_topk_stage2<<>>(*params,
                                        beam_width,
                                        beam_group_idx,
                                        voc_parts,
                                        packed_top_kmd_size,
                                        fuse_softmax);
    return;
  }
  if (voc_parts <= 128) {
    beam_search_softmax_topk_stage2<<>>(*params,
                                        beam_width,
                                        beam_group_idx,
                                        voc_parts,
                                        packed_top_kmd_size,
                                        fuse_softmax);
    return;
  }
  if (voc_parts <= 256) {
    beam_search_softmax_topk_stage2<<>>(*params,
                                        beam_width,
                                        beam_group_idx,
                                        voc_parts,
                                        packed_top_kmd_size,
                                        fuse_softmax);
    return;
  }
}

// For each batch (blockIdx.x; only thread 0 works), mark all K beams
// finished once the worst stored hypothesis score exceeds -1e8.
template
__global__ void update_beam_finished_early_stop(const T *beam_hyps_score_out,
                                                bool *beam_finished) {
  if (threadIdx.x == 0) {
    int batch_idx = blockIdx.x;
    const T *cur_beam_hyps_score = beam_hyps_score_out + batch_idx * K;
    bool *cur_beam_finished = beam_finished + batch_idx * K;
    if (cur_beam_hyps_score[K - 1] > -1e8) {
      for (int i = 0; i < K; i++) {
        cur_beam_finished[i] = true;
      }
    }
  }
}

// Final selection for one beam group of one batch (blockIdx.x = batch;
// only thread 0 works). Steps:
//   1. Merge the stage-2 candidates of the group's beams, applying the
//      length penalty (and the diversity penalty when !GROUP).
//   2. Move end-token candidates (or every candidate on the last decoding
//      step) into the finished-hypothesis list.
//   3. Write the surviving candidates to next_tokens / cum_scores_out /
//      parent_ids.
// At step 0 only the first beam's K candidates are read (presumably
// because all beams start identical).
template
__global__ void batch_topk(BeamSearchParams params,
                           const int beam_width,
                           const int beam_group_idx,
                           const int dec_stride) {
  const bool early_stop = params.early_stop;
  const int thread_id = threadIdx.x;
  const int batch_id = blockIdx.x;
  const int beam_group_size = K / 2;
  const int beam_group_start_id =
      batch_id * beam_width + beam_group_idx * beam_group_size;

  bool *beam_finished = params.beam_finished + beam_group_start_id;
  const int *step_ids = params.step_ids + beam_group_start_id;
  int *next_tokens = params.next_tokens + beam_group_start_id;
  float *cum_scores_out = params.cum_scores_out + beam_group_start_id;
  int *parent_ids = params.parent_ids + beam_group_start_id;
  float *beam_hyps_score_out = params.beam_hyps_score_out + beam_group_start_id;

  const bool finish = beam_finished[0];
  const int step_id = step_ids[0];
  const int seq_len = params.seq_lens[beam_group_start_id];
  const int max_dec_len = params.max_dec_lens[beam_group_start_id];
  const bool last_dec_step = (step_id + 1 == max_dec_len);

  if (thread_id == 0 && seq_len > 0 && !finish) {
    TopK partial;
    BeamHypothesesTopK beam_hyps;

    beam_hyps.init(params.beam_hyps_out + beam_group_start_id * dec_stride,
                   params.beam_hyps_score_out + beam_group_start_id,
                   dec_stride);

    for (int i = 0; i < K; ++i) {
      partial.ids[i] = -1;
      partial.vals[i] = -FLT_MAX;
      partial.parent_ids[i] = -1;
    }

    int index = batch_id * beam_group_size * K;
    if (step_id == 0) {
      for (int i = 0; i < K; i++) {
        float score_now = apply_length_penalty(params.tmp_vals[index + i],
                                               step_id + 1,
                                               params.length_penalty[batch_id]);
        if (!GROUP) {
          score_now -=
              params.diversity_penalty[batch_id] * static_cast(i + 1);
        }
        partial.insert((T)score_now, params.tmp_ids[index + i], i / K);
      }
    } else {
      for (int i = 0; i < beam_group_size * K; i++) {
        float score_now = apply_length_penalty(params.tmp_vals[index + i],
                                               step_id + 1,
                                               params.length_penalty[batch_id]);
        if (!GROUP) {
          score_now -=
              params.diversity_penalty[batch_id] * static_cast(i % K + 1);
        }
        partial.insert((T)score_now, params.tmp_ids[index + i], i / K);
      }
    }
    // Even the best candidate cannot beat the worst stored hypothesis:
    // the group is done. Nothing was inserted yet, so no scores to persist.
    if (partial.vals[0] < beam_hyps.hyps[beam_group_size - 1].score) {
      for (int i = 0; i < beam_group_size; i++) {
        beam_finished[i] = true;
      }
      return;
    }

    int next_step_num = 0;
    for (int i = 0; i < K && next_step_num < beam_group_size; i++) {
      int parent_id = partial.parent_ids[i];
      if (is_in_end(partial.ids[i], params.end_ids, params.end_ids_len) ||
          last_dec_step) {
        if (i < beam_group_size &&
            partial.vals[i] > beam_hyps.get_worst_score()) {
          const int *beam_cache_id = params.beam_cache_ids +
                                     beam_group_start_id * dec_stride +
                                     parent_id * dec_stride;
          beam_hyps.insert(beam_cache_id,
                           step_id,
                           last_dec_step ? params.end_ids[0] : partial.ids[i],
                           partial.vals[i]);
        }
        if (early_stop && beam_hyps.get_worst_score() > -1e8) {
          // Stop. insert() rewrote the token rows in place, so the score
          // copies held in registers must be persisted as well; otherwise
          // scores and sequences get out of sync.
          for (int j = 0; j < beam_group_size; j++) {
            beam_hyps_score_out[j] = beam_hyps.hyps[j].score;
            beam_finished[j] = true;
          }
          return;
        }
      } else {
        next_tokens[next_step_num] = partial.ids[i];
        cum_scores_out[next_step_num] = partial.vals[i];
        parent_ids[next_step_num] = parent_id;
        next_step_num += 1;
      }
    }
    for (int i = 0; i < beam_group_size; i++) {
      beam_hyps_score_out[i] = beam_hyps.hyps[i].score;
    }
    if (last_dec_step) {
      for (int i = 0; i < beam_group_size; i++) {
        beam_finished[i] = true;
      }
    }
  }  // if (thread_id == 0)
}

// Runs one beam-search step for one beam group:
//   1. Optional diversity penalty (group beam search, groups after the
//      first).
//   2. Stage 1: the fast shared-memory kernel or the generic kernel.
//   3. Stage 2: merge the vocab parts.
//   4. batch_topk: final per-group selection.
template
void invokeTopKSoftMaxLauncher(BeamSearchParams *params,
                               int beam_group_idx,
                               gpuStream_t stream) {
  const int batch_size = params->batch_size;
  const int beam_width = params->beam_width;
  const int beam_group_size = K / 2;
  const int vocab_size = params->vocab_size;
  const bool fuse_softmax = params->fuse_softmax;
  const int voc_parts = params->voc_parts;
  constexpr int dev_id = 0;
  // only in group_beam_search
  if (beam_width > beam_group_size && beam_group_idx != 0) {
    apply_group_diversity_penalty<<>>(
        *params, batch_size, beam_width, beam_group_idx, vocab_size);
  }
  // == Step1 == : stage1
  if (params->use_fast_kernel) {
    constexpr int block_size =
        (K < 16) ? ((K < 8) ? kBlockSizeForSmallBeamWidth : 128) : 64;
    const int vocab_chunk_size = (vocab_size + voc_parts - 1) / voc_parts;
    const int dyn_smem_size = sizeof(T) * vocab_chunk_size;
    // More than 48 KB of dynamic shared memory requires an explicit opt-in.
    if (dyn_smem_size >= (48 << 10)) {
      cudaFuncSetAttribute(beam_search_softmax_topk_stage1_fast,
                           cudaFuncAttributeMaxDynamicSharedMemorySize,
                           dyn_smem_size);
    }
    dim3 grid(batch_size * beam_group_size, voc_parts);
    beam_search_softmax_topk_stage1_fast<<>>(params->logits,
                                             params->tmp_buffer,
                                             params->end_ids,
                                             params->beam_finished,
                                             params->seq_lens,
                                             beam_width,
                                             beam_group_idx,
                                             vocab_size,
                                             vocab_chunk_size);
  } else {
    constexpr int block_size = 128;
    dim3 grid(batch_size * beam_group_size, voc_parts);
    cudaFuncSetAttribute(beam_search_softmax_topk_stage1,
                         cudaFuncAttributePreferredSharedMemoryCarveout,
                         cudaSharedmemCarveoutMaxL1);
    if (fuse_softmax) {
#ifdef PADDLE_WITH_CUDA
      cudaFuncSetAttribute(beam_search_softmax_topk_stage1,
                           cudaFuncAttributePreferredSharedMemoryCarveout,
                           cudaSharedmemCarveoutMaxL1);
#else
      // cudaSharedmemCarveoutMaxL1 equal to 0
      hipFuncSetAttribute(
          reinterpret_cast(beam_search_softmax_topk_stage1),
          hipFuncAttributePreferredSharedMemoryCarveout,
          0);
#endif
      // (bs, bm, voc_parts, 2 * K + 2)
      beam_search_softmax_topk_stage1<<>>(
          *params, beam_width, beam_group_idx, vocab_size, fuse_softmax);
    } else {
#ifdef PADDLE_WITH_CUDA
      cudaFuncSetAttribute(beam_search_softmax_topk_stage1,
                           cudaFuncAttributePreferredSharedMemoryCarveout,
                           cudaSharedmemCarveoutMaxL1);
#else
      // cudaSharedmemCarveoutMaxL1 equal to 0
      hipFuncSetAttribute(
          reinterpret_cast(beam_search_softmax_topk_stage1),
          hipFuncAttributePreferredSharedMemoryCarveout,
          0);
#endif
      // (bs, bm, voc_parts, 2 * K)
      // NOTE(review): stage 2 below always uses the fast stage-2 kernel,
      // which assumes a 2 * K + 2 packed layout and applies softmax.
      // Confirm this non-fused layout is compatible.
      beam_search_softmax_topk_stage1<<>>(
          *params, beam_width, beam_group_idx, vocab_size, fuse_softmax);
    }
  }
  beamSearchSoftmaxTopkStage2FastKernelLauncher(params,
                                                batch_size,
                                                beam_width,
                                                beam_group_idx,
                                                vocab_size,
                                                voc_parts,
                                                params->max_smem_per_block,
                                                stream);
  batch_topk<<>>(
      *params, beam_width, beam_group_idx, params->dec_stride);
}
template void invokeTopkSoftMax(BeamSearchParams *params,
                                int beam_group_idx,
                                gpuStream_t stream) {
  // Dispatch the runtime beam_group_size onto the compile-time launcher.
  // Only group sizes 1..16 are instantiated; anything else used to fall
  // through silently (no kernels launched, garbage outputs), so reject it.
  switch (params->beam_group_size) {
    CASE_K(1);
    CASE_K(2);
    CASE_K(3);
    CASE_K(4);
    CASE_K(5);
    CASE_K(6);
    CASE_K(7);
    CASE_K(8);
    CASE_K(9);
    CASE_K(10);
    CASE_K(11);
    CASE_K(12);
    CASE_K(13);
    CASE_K(14);
    CASE_K(15);
    CASE_K(16);
    default:
      PD_THROW("beam_search_softmax: unsupported beam_group_size ",
               params->beam_group_size,
               ", only 1 to 16 are supported.");
  }
}

// Decides how stage 1 splits each vocabulary row (params->voc_parts) and
// whether the shared-memory "fast" stage-1 kernel can be used
// (params->use_fast_kernel). Also records params->max_smem_per_block.
//
// Fast path: starting from (max_active_blocks - 1) blocks per SM and going
// down, pick the first (i.e. highest-occupancy) target whose per-block
// dynamic shared memory lets the vocabulary fit in at most
// kMaxVocabPartForStage1FastKernel parts.
// Fallback (no fused softmax, or no target fits): use the regular kernel with
// enough parts to give roughly 4 active blocks per SM, capped at 128.
template void ComputeVocParts(BeamSearchParams *params) {
  int dev_id = 0;
  const int block_size =
      (K < 16) ? ((K < 8) ? kBlockSizeForSmallBeamWidth : 128) : 64;
  int max_active_blocks = -1;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks,
      beam_search_softmax_topk_stage1_fast,
      block_size,
      0);

  int max_smem_per_sm = -1;
  int max_smem_per_block = -1;
  cudaDeviceGetAttribute(
      &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id);
  cudaDeviceGetAttribute(
      &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id);
  cudaFuncAttributes attr;
  cudaFuncGetAttributes(
      &attr, beam_search_softmax_topk_stage1_fast);
  const int static_smem = attr.sharedSizeBytes;
  // The former "vocab too large even at the maximum per-block size" check
  // had an empty body; that case is handled by the search below, which then
  // leaves voc_parts > kMaxVocabPartForStage1FastKernel and falls back.

  // NOTE(review): presumably the per-SM shared memory not available to a
  // single block (reserved per block by the driver) — confirm.
  const int driver_smem_per_block = max_smem_per_sm - max_smem_per_block;
  const int extra_smem = driver_smem_per_block + static_smem;

  int voc_parts = kMaxVocabPartForStage1FastKernel + 1;
  for (int n_block = max_active_blocks - 1;
       n_block > 0 && voc_parts > kMaxVocabPartForStage1FastKernel;
       --n_block) {
    int dyn_smem_size = max_smem_per_sm / n_block - extra_smem;
    // Too little dynamic shared memory for even one element at this
    // occupancy target. Previously a non-positive value went through the
    // unsigned modulo and into ceilDiv (divide-by-zero / negative parts).
    if (dyn_smem_size < static_cast<int>(sizeof(T))) {
      continue;
    }
    // Round down to a whole number of elements.
    dyn_smem_size -= dyn_smem_size % sizeof(T);
    voc_parts = ceilDiv(sizeof(T) * params->vocab_size, dyn_smem_size);
  }

  if (!params->fuse_softmax || voc_parts > kMaxVocabPartForStage1FastKernel) {
    params->use_fast_kernel = false;
    int sm_count;
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
    const int max_act_blocks_per_sm = 4;
    const int max_act_blocks_per_wave = sm_count * max_act_blocks_per_sm;
    // Stage-1 grid.x is batch_size * beam_group_size, and K / 2 is the
    // beam group size.
    const int gridx = params->batch_size * K / 2;
    const int max_part_num = (max_act_blocks_per_wave + gridx - 1) / gridx;
    voc_parts = min(128, max_part_num);
  }
  params->voc_parts = voc_parts;
  params->max_smem_per_block = max_smem_per_block;
}

// Runtime -> compile-time dispatch for ComputeVocParts. Unsupported group
// sizes used to leave params->voc_parts at 0 silently; reject them instead.
template void DispatchComputeVocParts(BeamSearchParams *params) {
  switch (params->beam_group_size) {
    DISPATCH_COMPUTE_PARTS_K(1);
    DISPATCH_COMPUTE_PARTS_K(2);
    DISPATCH_COMPUTE_PARTS_K(3);
    DISPATCH_COMPUTE_PARTS_K(4);
    DISPATCH_COMPUTE_PARTS_K(5);
    DISPATCH_COMPUTE_PARTS_K(6);
    DISPATCH_COMPUTE_PARTS_K(7);
    DISPATCH_COMPUTE_PARTS_K(8);
    DISPATCH_COMPUTE_PARTS_K(9);
    DISPATCH_COMPUTE_PARTS_K(10);
    DISPATCH_COMPUTE_PARTS_K(11);
    DISPATCH_COMPUTE_PARTS_K(12);
    DISPATCH_COMPUTE_PARTS_K(13);
    DISPATCH_COMPUTE_PARTS_K(14);
    DISPATCH_COMPUTE_PARTS_K(15);
    DISPATCH_COMPUTE_PARTS_K(16);
    default:
      PD_THROW("beam_search_softmax: unsupported beam_group_size ",
               params->beam_group_size,
               ", only 1 to 16 are supported.");
  }
}

// Reorders block tables and cache ids after a beam-search step.
// Grid: blockIdx.y = bb_id, one row per (batch, beam); threads along x cover
// time steps. Each unfinished row with seq_len >= 0 copies, from its parent
// beam src_beam = parent_ids[bb_id] in the same group:
//   - block_tables entries for time steps < min(max_seq_len, seq_len + 1),
//   - cache ids for time steps < min(step + 1, max_dec_len); at
//     time_step == step it writes next_tokens[bb_id] instead.
template __global__ void update_beam_search_params_kernel(BeamSearchParams params) {
  int bb_id = blockIdx.y;
  int time_step = threadIdx.x + blockIdx.x * blockDim.x;

  // Bounds check before any per-row loads; the row loads used to precede it.
  if (bb_id >= params.beam_width * params.batch_size) {
    return;
  }
  const bool finished = params.beam_finished[bb_id];
  const int seq_len = params.seq_lens[bb_id];
  if (finished || seq_len < 0) {
    return;
  }

  const int beam_group_size = params.beam_group_size;
  const int max_seq_len = params.max_seq_len;
  const int dec_stride = params.dec_stride;
  const int batch_group_id = bb_id / beam_group_size;
  const int max_dec_len = params.max_dec_lens[bb_id];
  const int src_beam = params.parent_ids[bb_id];
  const int step = params.step_ids[bb_id];
  const int *block_tables = params.block_tables;
  int *block_tables_out = params.block_tables_out;
  const int *cache_ids = params.beam_cache_ids;
  int *cache_ids_out = params.cache_ids_out;
  const int *next_tokens = params.next_tokens;
  const int beam_group_sub_id = bb_id % beam_group_size;
  // const int src_bb_id = batch_group_id * beam_group_size + src_beam;

  if (time_step < min(max_seq_len, seq_len + 1)) {
    const unsigned int block_tables_tgt_offset =
        batch_group_id * beam_group_size * max_seq_len +
        beam_group_sub_id * max_seq_len + time_step;
    const unsigned int block_tables_src_offset =
        batch_group_id * beam_group_size * max_seq_len +
        src_beam * max_seq_len + time_step;
    block_tables_out[block_tables_tgt_offset] =
        block_tables[block_tables_src_offset];
    if (time_step < min(step + 1, max_dec_len)) {
      const unsigned int cache_ids_tgt_offset =
          batch_group_id * beam_group_size * dec_stride +
          beam_group_sub_id * dec_stride + time_step;
      const unsigned int cache_ids_src_offset =
          batch_group_id * beam_group_size * dec_stride +
          src_beam * dec_stride + time_step;
      cache_ids_out[cache_ids_tgt_offset] =
          (time_step == step) ? next_tokens[bb_id]
                              : cache_ids[cache_ids_src_offset];
    }
  }
}

// One block per batch entry (blockIdx.x); only thread 0 works. If that entry
// is not already stopped and the first beam of every beam group is finished,
// set stop_flags for all beam_width beams of the entry.
template __global__ void update_stop_flags(BeamSearchParams params) {
  int bid = blockIdx.x;
  const int beam_width = params.beam_width;
  const int beam_group_size = params.beam_group_size;
  const bool *beam_finished = params.beam_finished + beam_width * bid;
  bool *stop_flags = params.stop_flags + beam_width * bid;
  bool finished = true;
  if (threadIdx.x == 0 && !stop_flags[0]) {
#pragma unroll
    for (int i = 0; i < beam_width; i += beam_group_size) {
      finished &= beam_finished[i];
    }

    if (finished) {
#pragma unroll
      for (int i = 0; i < beam_width; i++) {
        stop_flags[i] = true;
      }
    }
  }
}

// Launches update_beam_search_params_kernel over
// (ceil(max_seq_len / 32), batch_size * beam_width) blocks of 32 threads,
// then update_stop_flags with one block per batch entry.
template void updateBeamSearchParams(BeamSearchParams *params,
                                     cudaStream_t stream) {
  const dim3 block(32);
  const dim3 grid((params->max_seq_len + block.x - 1) / block.x,
                  params->batch_size * params->beam_width);

  update_beam_search_params_kernel<<>>(*params);

  const dim3 grid_2(params->batch_size);
  update_stop_flags<<>>(*params);
}

/*****
 * To fit the model structure of version 5.2 without adding a while op and
 * without slowing it down, this op uses a "fake inplace" approach: tensors
 * marked "// inplace" below are written through const_cast pointers, and
 * cum_scores, beam_cache_ids and block_tables are first snapshotted into
 * scratch tensors so the kernels can read the old values while writing the
 * new ones. Not elegant, but it works.
 *****/
std::vector BeamSearchSoftmax(const paddle::Tensor &logits,
                              const paddle::Tensor &seq_lens,
                              const paddle::Tensor &stop_flags,  // inplace
                              const paddle::Tensor &end_ids,
                              const paddle::Tensor &step_ids,
                              const paddle::Tensor &max_dec_lens,
                              const paddle::Tensor &block_tables,  // inplace
                              const paddle::Tensor &cum_scores,  // inplace
                              const paddle::Tensor &beam_cache_ids,  // inplace
                              const paddle::Tensor &beam_hyps,  // inplace
                              const paddle::Tensor &beam_hyps_score,  // inplace
                              const paddle::Tensor &beam_finished,  // inplace
                              const paddle::Tensor &beam_width,
                              const paddle::Tensor &beam_group_num,
                              const paddle::Tensor &length_penalty,
                              const paddle::Tensor &diversity_penalty,
                              bool fuse_softmax,
                              bool early_stop) {
  std::vector logits_shape = logits.shape();
  // logits_shape
  auto cu_stream = logits.stream();

  // beam_width / beam_group_num live on the device. These copies target
  // pageable stack memory; on CUDA, a device-to-pageable-host
  // cudaMemcpyAsync returns only after the copy has completed, so the
  // scalars are valid below.
  int beam_width_scalar;
  cudaMemcpyAsync(&beam_width_scalar,
                  beam_width.data(),
                  sizeof(int),
                  cudaMemcpyDeviceToHost,
                  cu_stream);

  int beam_group_num_scalar;
  cudaMemcpyAsync(&beam_group_num_scalar,
                  beam_group_num.data(),
                  sizeof(int),
                  cudaMemcpyDeviceToHost,
                  cu_stream);

  // Both scalars are used as divisors below.
  PD_CHECK(beam_width_scalar > 0,
           "beam_search_softmax: beam_width must be positive, got ",
           beam_width_scalar);
  PD_CHECK(beam_group_num_scalar > 0,
           "beam_search_softmax: beam_group_num must be positive, got ",
           beam_group_num_scalar);

  int beam_batch_size = logits_shape[0];
  int batch_size = beam_batch_size / beam_width_scalar;
  int vocab_size = logits_shape[1];
  const int max_seq_len = block_tables.dims()[1];
  // liuzichang: In some cases, the length of Tensor is longer than max_dec_lens
  const int dec_stride = beam_hyps.dims()[1];
  const int end_ids_len = end_ids.dims()[0];
  const int beam_group_size = beam_width_scalar / beam_group_num_scalar;

  auto next_tokens = paddle::full({logits_shape[0], 1}, 0, end_ids.type(), paddle::GPUPlace());
  auto parent_ids = paddle::full({logits_shape[0], 1}, 0, end_ids.type(), paddle::GPUPlace());

  // Snapshots of the in-place inputs that the kernels read while writing.
  auto cum_scores_ori = paddle::empty(cum_scores.shape(), logits.type(), paddle::GPUPlace());
  auto beam_cache_ids_ori = paddle::empty(beam_cache_ids.shape(), end_ids.type(), paddle::GPUPlace());
  auto block_tables_ori = paddle::empty(block_tables.shape(), end_ids.type(), paddle::GPUPlace());
  cudaMemcpyAsync(cum_scores_ori.mutable_data(),
                  cum_scores.data(),
                  sizeof(float)*cum_scores.numel(),
                  cudaMemcpyDeviceToDevice,
                  cu_stream);
  cudaMemcpyAsync(beam_cache_ids_ori.mutable_data(),
                  beam_cache_ids.data(),
                  sizeof(int)*beam_cache_ids.numel(),
                  cudaMemcpyDeviceToDevice,
                  cu_stream);
  cudaMemcpyAsync(block_tables_ori.mutable_data(),
                  block_tables.data(),
                  sizeof(int)*block_tables.numel(),
                  cudaMemcpyDeviceToDevice,
                  cu_stream);

  BeamSearchParams params;
  params.batch_size = batch_size;
  params.beam_width = beam_width_scalar;
  params.beam_group_size = beam_group_size;
  params.vocab_size = vocab_size;
  params.dec_stride = dec_stride;
  params.max_seq_len = max_seq_len;
  params.end_ids_len = end_ids_len;

  params.fuse_softmax = fuse_softmax;
  params.early_stop = early_stop;

  // Only Read
  params.step_ids = step_ids.data();
  params.seq_lens = seq_lens.data();
  params.max_dec_lens = max_dec_lens.data();
  params.end_ids = end_ids.data();
  params.length_penalty = length_penalty.data();
  params.diversity_penalty = diversity_penalty.data();
  params.cum_scores = cum_scores_ori.data();
  params.block_tables = block_tables_ori.data();
  params.beam_cache_ids = beam_cache_ids_ori.data();

  // Write
  params.logits = const_cast(logits.data());
  params.cache_ids_out = const_cast(beam_cache_ids.data());
  params.block_tables_out = const_cast(block_tables.data());
  params.cum_scores_out = const_cast(cum_scores.data());
  params.beam_hyps_out = const_cast(beam_hyps.data());
  params.beam_hyps_score_out = const_cast(beam_hyps_score.data());
  params.beam_finished = const_cast(beam_finished.data());
  params.stop_flags = const_cast(stop_flags.data());
  params.next_tokens = const_cast(next_tokens.data());
  params.parent_ids = const_cast(parent_ids.data());
  // params.tmp_ids / tmp_vals / tmp_buffer are carved out of the workspace
  // below. The separate tmp_topk_id / tmp_topk_val tensors that used to be
  // allocated here were always overwritten before use and have been removed.

  // Throws for unsupported beam_group_size before any workspace is sized.
  DispatchComputeVocParts(&params);

  // allocate workspace
  const int tmp_id_val_size = batch_size * beam_group_size * beam_group_size * 2;
  const int packed_top_kmd_size = fuse_softmax ? 2 * 2 * beam_group_size + 2 : 2 * 2 * beam_group_size;
  const int tmp_stage1_to_stage2_size = batch_size * beam_group_size * params.voc_parts * packed_top_kmd_size;
  const int workspace_size = tmp_id_val_size * 2 + tmp_stage1_to_stage2_size;
  auto wsp_buffer_tensor = paddle::full({workspace_size}, 0, logits.type(), paddle::GPUPlace());
  // Layout: [tmp_ids | tmp_vals | stage1 -> stage2 buffer].
  params.tmp_ids = reinterpret_cast(wsp_buffer_tensor.data());
  params.tmp_vals = wsp_buffer_tensor.data() + tmp_id_val_size;
  params.tmp_buffer = wsp_buffer_tensor.data() + 2 * tmp_id_val_size;

  for (int beam_group_idx = 0; beam_group_idx < beam_group_num_scalar; ++beam_group_idx) {
    // NOTE(review): the two calls below presumably differ only in template
    // arguments (plain vs. group beam search) that are not visible in this
    // copy; do not merge them.
    if (beam_group_num_scalar == 1) {
      invokeTopkSoftMax(
          &params, beam_group_idx, cu_stream);
    } else {
      invokeTopkSoftMax(
          &params, beam_group_idx, cu_stream);
    }
  }
  updateBeamSearchParams(&params, cu_stream);
  return {next_tokens, parent_ids};
}

// Both outputs are [logits[0], 1], i.e. one entry per (batch, beam) row.
std::vector> BeamSearchSoftmaxShape(
    const std::vector &logits,
    const std::vector &seq_lens,
    const std::vector &stop_flags,  // inplace
    const std::vector &end_ids,
    const std::vector &step_ids,
    const std::vector &max_dec_lens,
    const std::vector &block_tables,  // inplace
    const std::vector &cum_scores,  // inplace
    const std::vector &beam_cache_ids,  // inplace
    const std::vector &beam_hyps,  // inplace
    const std::vector &beam_hyps_score,  // inplace
    const std::vector &beam_finished,  // inplace
    const std::vector &beam_width,
    const std::vector &beam_group_num,
    const std::vector &length_penalty,
    const std::vector &diversity_penalty) {
  std::vector next_tokens = {logits[0],1};
  std::vector parent_ids = {logits[0],1};
  return {next_tokens,parent_ids};
}

// NOTE(review): the kernel allocates next_tokens / parent_ids with
// end_ids.type(); the INT32 declared here only matches when end_ids is
// int32 — confirm.
std::vector BeamSearchSoftmaxDtype(
    const paddle::DataType &logits,
    const paddle::DataType &seq_lens,
    const paddle::DataType &stop_flags,  // inplace
    const paddle::DataType &end_ids,
    const paddle::DataType &step_ids,
    const paddle::DataType &max_dec_lens,
    const paddle::DataType &block_tables,  // inplace
    const paddle::DataType &cum_scores,  // inplace
    const paddle::DataType &beam_cache_ids,  // inplace
    const paddle::DataType &beam_hyps,  // inplace
    const paddle::DataType &beam_hyps_score,  // inplace
    const paddle::DataType &beam_finished,  // inplace
    const paddle::DataType &beam_width,
    const paddle::DataType &beam_group_num,
    const paddle::DataType &length_penalty,
    const paddle::DataType &diversity_penalty) {
  return {paddle::DataType::INT32, paddle::DataType::INT32};
}

PD_BUILD_STATIC_OP(beam_search_softmax)
    .Inputs({"logits", "seq_lens", "stop_flags", "end_ids", "step_ids", "max_dec_lens", "block_tables"
    , "cum_scores", "beam_cache_ids", "beam_hyps", "beam_hyps_score", "beam_finished"
    , "beam_width", "beam_group_num", "length_penalty", "diversity_penalty"})
    .Outputs({"next_tokens", "parent_ids"})
    .Attrs({"fuse_softmax: bool", "early_stop: bool"})
    .SetKernelFn(PD_KERNEL(BeamSearchSoftmax))
    .SetInferShapeFn(PD_INFER_SHAPE(BeamSearchSoftmaxShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(BeamSearchSoftmaxDtype));