mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-25 17:40:35 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			1546 lines
		
	
	
		
			56 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			1546 lines
		
	
	
		
			56 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 | ||
| //
 | ||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | ||
| // you may not use this file except in compliance with the License.
 | ||
| // You may obtain a copy of the License at
 | ||
| //
 | ||
| //     http://www.apache.org/licenses/LICENSE-2.0
 | ||
| //
 | ||
| // Unless required by applicable law or agreed to in writing, software
 | ||
| // distributed under the License is distributed on an "AS IS" BASIS,
 | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||
| // See the License for the specific language governing permissions and
 | ||
| // limitations under the License.
 | ||
| 
 | ||
| #ifdef __NVCC__
 | ||
| #include <cub/cub.cuh>
 | ||
| #endif
 | ||
| #ifdef __HIPCC__
 | ||
| #include <hipcub/hipcub.hpp>
 | ||
| namespace cub = hipcub;
 | ||
| #endif
 | ||
| #include <fcntl.h>
 | ||
| #include <stdio.h>
 | ||
| #include <stdlib.h>
 | ||
| #include <string.h>
 | ||
| #include <sys/mman.h>
 | ||
| #include <sys/stat.h>
 | ||
| #include <sys/types.h>
 | ||
| #include <unistd.h>
 | ||
| #include <algorithm>
 | ||
| #include "stdint.h"
 | ||
| #include "helper.h"
 | ||
| 
 | ||
// Magnitude used by the kernels below as their "minus infinity" sentinel
// (they assign `MAX_T_VAL = FLT_MAX` and use `-MAX_T_VAL`).  NOTE: this is
// deliberately NOT the <cfloat> value (~3.4e38f).  The #undef avoids a
// macro-redefinition warning when <cfloat>/<float.h> was already pulled in by
// one of the headers above; the value seen by the rest of this file is the
// same 1e38 either way.
#ifdef FLT_MAX
#undef FLT_MAX
#endif
#define FLT_MAX 1e38

// Thread-block size for small beam widths (its use is outside this chunk).
static constexpr int kBlockSizeForSmallBeamWidth = 256;
// Upper bound on the number of vocab partitions for the fast stage-1 kernel;
// the stage-2 launcher's largest block size (128) presumably relies on it.
static constexpr int kMaxVocabPartForStage1FastKernel = 128;
 | ||
| 
 | ||
// `switch` helper: for beam-group size GS, launches the TopK/softmax path
// with a TopK width of 2 * GS.  The expansion site must provide `T`,
// `GROUP`, `params`, `beam_group_idx` and `stream`; the caller supplies the
// `;` that follows `break`.
#define CASE_K(GS)                                                   \
  case GS:                                                           \
    invokeTopKSoftMaxLauncher<T, 2 * GS, GROUP>(                     \
        params, beam_group_idx, stream);                             \
    break

// `switch` helper for ComputeVocParts with a TopK width of 2 * GS; the
// expansion site must provide `T` and `params`.
#define DISPATCH_COMPUTE_PARTS_K(GS)                                 \
  case GS:                                                           \
    ComputeVocParts<T, 2 * GS>(params);                              \
    break
 | ||
| 
 | ||
// Argument bundle shared by the beam-search kernels in this file: problem
// sizes, input/output tensors and scratch buffers.  Kernels receive it by
// value, so every pointer they dereference must be device-accessible.
// Shape annotations use BS = batch size, BM = beam width.
template <typename T>
struct BeamSearchParams {
  // ---- sizes and configuration -------------------------------------------
  int batch_size = 0;
  int beam_width = 0;
  int beam_group_size = 0;
  int beam_group_idx = 0;

  int vocab_size = 0;
  int dec_stride = 0;
  int max_seq_len = 0;
  int end_ids_len = 0;

  bool fuse_softmax = true;
  bool early_stop = false;

  int voc_parts = 0;
  bool use_fast_kernel = true;
  int max_smem_per_block = 0;

  // ---- inputs --------------------------------------------------------------
  // Non-const: apply_group_diversity_penalty edits logits in place.
  T *logits = nullptr;
  const int *step_ids = nullptr;  // [BS * BM, 1]
  const int *seq_lens = nullptr;  // [BS * BM, 1]

  const int *max_dec_lens = nullptr;
  const int *end_ids = nullptr;

  const T *cum_scores = nullptr;
  const int *block_tables = nullptr;
  const int *beam_cache_ids = nullptr;

  const float *length_penalty = nullptr;     // [BS, 1]
  const float *diversity_penalty = nullptr;  // [BS, 1]

  // ---- outputs --------------------------------------------------------------
  bool *stop_flags = nullptr;        // [BS, 1]
  int *cache_ids_out = nullptr;      // [BS * BM, max_dec_len]
  bool *beam_finished = nullptr;     // [BS * BM, 1]
  int *block_tables_out = nullptr;   // [BS * BM, max_seq_len]
  T *cum_scores_out = nullptr;       // [BS * BM, 1]
  int *beam_hyps_out = nullptr;      // [BS * BM, max_dec_len]
  T *beam_hyps_score_out = nullptr;  // [BS * BM, 1]

  // ---- per-step results ------------------------------------------------------
  int *next_tokens = nullptr;
  int *parent_ids = nullptr;

  // ---- scratch workspace -----------------------------------------------------
  int *tmp_ids = nullptr;
  T *tmp_vals = nullptr;
  T *tmp_buffer = nullptr;
};
 | ||
| 
 | ||
// Integer ceiling division (exact for a non-negative numerator and a
// positive denominator).  Restricted to integral operand types via SFINAE.
template <typename T,
          typename U,
          typename = std::enable_if_t<std::is_integral<T>::value>,
          typename = std::enable_if_t<std::is_integral<U>::value>>
auto constexpr ceilDiv(T numerator, U denominator) {
  // Bias the numerator so truncating division rounds up.
  auto const biased = numerator + denominator - 1;
  return biased / denominator;
}
 | ||
| 
 | ||
// Returns true iff `id` matches one of the first `length` entries of
// `end_ids`; a non-positive `length` always yields false.
__device__ bool is_in_end(const int id, const int *end_ids, int length) {
  int pos = 0;
  while (pos < length && end_ids[pos] != id) {
    ++pos;
  }
  // Stopped before the end only if a match was found.
  return pos < length;
}
 | ||
| 
 | ||
// Length-normalizes a log-probability:
//   score = log_prob / length ^ length_penalty
// A zero penalty or a length of exactly 1 returns log_prob unchanged
// (skipping the powf call).
template <typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob,
                                                  int length,
                                                  float length_penalty) {
  const bool scale = (length_penalty != 0.0f) && (length != 1);
  if (scale) {
    const float normalizer = powf(length, length_penalty);
    return log_prob / static_cast<T>(normalizer);
  }
  return log_prob;
}
 | ||
| 
 | ||
// Group-diversity penalty for group beam search.
// Launch: <<<batch_size, beam_group_size>>> — one block per batch entry, one
// thread per beam of group `beam_group_idx`.
// Each thread walks the tokens in `next_tokens` chosen by the beams of all
// earlier groups (0 .. beam_group_idx - 1) of its batch entry.  For every
// such beam that is not finished, it subtracts diversity_penalty[batch] from
// that token's logit in the thread's own beam row.  A token chosen by several
// earlier beams is penalised once per occurrence.
// K is the group's TopK width, i.e. 2 * beam_group_size.  `batch_size` is
// unused and kept only for interface compatibility.
template <typename T, int K>
__global__ void apply_group_diversity_penalty(BeamSearchParams<T> params,
                                              const int batch_size,
                                              const int beam_width,
                                              const int beam_group_idx,
                                              const int vocab_size) {
  const int beam_group_size = K / 2;
  const int batch_idx = blockIdx.x;
  const int beam_group_sub_idx = threadIdx.x;
  // Threads beyond the group size would otherwise penalise rows belonging to
  // other groups.
  if (beam_group_sub_idx >= beam_group_size) {
    return;
  }

  const bool *beam_finished = params.beam_finished + batch_idx * beam_width;
  const int *next_tokens = params.next_tokens + batch_idx * beam_width;

  // The flat logits offset (row * vocab_size) is computed in 64 bits because
  // batch * beam_width * vocab_size can exceed INT_MAX.
  const int64_t beam_row = static_cast<int64_t>(batch_idx) * beam_width +
                           beam_group_idx * beam_group_size +
                           beam_group_sub_idx;
  T *row_logits = params.logits + beam_row * vocab_size;

  // Beams of all earlier groups occupy the first num_prev_beams slots.
  const int num_prev_beams = beam_group_idx * beam_group_size;
  for (int prev = 0; prev < num_prev_beams; ++prev) {
    if (beam_finished[prev]) {
      continue;
    }
    const int token_id = next_tokens[prev];
    // Defensive: an out-of-range id would write outside this logits row.
    if (token_id < 0 || token_id >= vocab_size) {
      continue;
    }
    row_logits[token_id] -= params.diversity_penalty[batch_idx];
  }
}
 | ||
| 
 | ||
// Online-softmax partial state.
struct DySoftMaxStruct {
  float logit;  // running maximum of the logits seen so far
  float score;  // sum of exp(x - logit) over those logits
};

// Merges two partial softmax states: keeps the larger maximum and rescales
// the other side's sum onto it.  On equal maxima `b` is treated as the
// larger one.
__device__ __forceinline__ DySoftMaxStruct
reduce_softmax_op(DySoftMaxStruct a, DySoftMaxStruct b) {
  const bool a_wins = a.logit > b.logit;
  const DySoftMaxStruct &hi = a_wins ? a : b;
  const DySoftMaxStruct &lo = a_wins ? b : a;

  DySoftMaxStruct merged;
  merged.logit = hi.logit;
  merged.score = hi.score + lo.score * expf(lo.logit - hi.logit);
  return merged;
}
 | ||
| 
 | ||
// One stored hypothesis: its score plus a pointer into an externally owned
// token buffer.
template <typename T>
struct BeamHypothesis {
  T score;
  int *seq;
  int seq_len;

  // Points this hypothesis at `token_buf`, records `initial_score`, and sets
  // seq_len to the given maximum length.
  __device__ __forceinline__ void init(int *token_buf,
                                       T initial_score,
                                       const int max_len) {
    seq_len = max_len;
    score = initial_score;
    seq = token_buf;
  }
};
 | ||
| 
 | ||
// The K best stored hypotheses, kept in descending score order (hyps[0]
// best, hyps[K - 1] worst) as long as the scores loaded by init() are
// themselves sorted.  Hypothesis i uses row i (stride max_dec_len) of an
// external token buffer.
template <typename T, int K>
struct BeamHypothesesTopK {
  BeamHypothesis<T> hyps[K];
  int max_dec_len;

  // Binds hypothesis i to row i of `_beam_hyps` and loads its score from
  // `_beam_hyps_score[i]`.
  __device__ __forceinline__ void init(int *_beam_hyps,
                                       T *_beam_hyps_score,
                                       const int _max_dec_len) {
    max_dec_len = _max_dec_len;
    for (int i = 0; i < K; i++) {
      hyps[i].init(
          _beam_hyps + i * _max_dec_len, _beam_hyps_score[i], _max_dec_len);
    }
  }

  // Inserts the sequence token_ids[0 .. step - 1] followed by cur_token_id
  // with `score` if it beats the current worst hypothesis.  The worst entry
  // is overwritten, then bubbled towards the front to restore the order.
  // NOTE(review): assumes step < max_dec_len so seq[step] stays inside the
  // row — confirm with the caller.
  __device__ void insert(const int *token_ids,
                         int step,
                         int cur_token_id,
                         T score) {
    if (!(score > get_worst_score())) {
      return;
    }

    BeamHypothesis<T> &slot = hyps[K - 1];
    for (int i = 0; i < step; i++) {
      slot.seq[i] = token_ids[i];
    }
    slot.seq[step] = cur_token_id;
    slot.score = score;

    for (int k = K - 2; k >= 0; --k) {
      if (hyps[k + 1].score > hyps[k].score) {
        T tmp_score = hyps[k].score;
        hyps[k].score = hyps[k + 1].score;
        hyps[k + 1].score = tmp_score;

        // Swap the whole step + 1 prefix.  Every hypothesis stored so far is
        // at most step + 1 tokens long, so this moves all meaningful tokens.
        // An earlier version stopped as soon as both tokens were <= 0, which
        // left the tail unswapped whenever both sequences had token id 0 at
        // the same position.
        for (int i = 0; i <= step; i++) {
          const int tmp_val = hyps[k + 1].seq[i];
          hyps[k + 1].seq[i] = hyps[k].seq[i];
          hyps[k].seq[i] = tmp_val;
        }
      }
    }
  }

  __device__ __forceinline__ T get_worst_score() { return hyps[K - 1].score; }
};
 | ||
| 
 | ||
// Fixed-capacity top-K list kept in descending value order.  Ties go to the
// smaller id, and id == -1 marks an empty slot that any candidate may take.
// parent_ids is only maintained by the three-argument insert.
template <typename T, int K>
struct TopK {
  int ids[K];
  T vals[K];
  int parent_ids[K];

  // Ordering predicate: should (val, id) sit ahead of (other_val, other_id)?
  static __device__ __forceinline__ bool outranks(T val,
                                                  int id,
                                                  T other_val,
                                                  int other_id) {
    return (val > other_val) || (other_id == -1) ||
           ((val == other_val) && (id < other_id));
  }

  // One backward pass of adjacent compare-and-swap from the last slot to the
  // front.  When kMoveParents is set, parent_ids travel with their entries.
  template <bool kMoveParents>
  __device__ __forceinline__ void sift_last_up() {
    for (int k = K - 2; k >= 0; --k) {
      if (!outranks(vals[k + 1], ids[k + 1], vals[k], ids[k])) {
        continue;
      }
      const T displaced_val = vals[k];
      vals[k] = vals[k + 1];
      vals[k + 1] = displaced_val;

      const int displaced_id = ids[k];
      ids[k] = ids[k + 1];
      ids[k + 1] = displaced_id;

      if constexpr (kMoveParents) {
        const int displaced_parent = parent_ids[k];
        parent_ids[k] = parent_ids[k + 1];
        parent_ids[k + 1] = displaced_parent;
      }
    }
  }

  // Offers (elem, elem_id); it replaces the last slot if it outranks it.
  __device__ __forceinline__ void insert(T elem, int elem_id) {
    if (outranks(elem, elem_id, vals[K - 1], ids[K - 1])) {
      vals[K - 1] = elem;
      ids[K - 1] = elem_id;
    }
    sift_last_up<false>();
  }

  // Same as above, also recording parent_id for the inserted entry.
  __device__ __forceinline__ void insert(T elem, int elem_id, int parent_id) {
    if (outranks(elem, elem_id, vals[K - 1], ids[K - 1])) {
      vals[K - 1] = elem;
      ids[K - 1] = elem_id;
      parent_ids[K - 1] = parent_id;
    }
    sift_last_up<true>();
  }
};
 | ||
| 
 | ||
// Merges two top-K lists: every entry of `b` is offered to a copy of `a`.
// Uses the two-argument insert, so `b`'s parent_ids are not carried over.
template <typename T, int K>
__device__ __forceinline__ TopK<T, K> reduce_topk_op(const TopK<T, K> &a,
                                                     const TopK<T, K> &b) {
  TopK<T, K> merged(a);
  int j = 0;
  while (j < K) {
    merged.insert(b.vals[j], b.ids[j]);
    ++j;
  }
  return merged;
}
 | ||
| 
 | ||
// Per-thread state of the fused stage-1 scan: online-softmax statistics
// together with the running top-K list.
template <typename T, int K>
struct TopKSoftMax {
  DySoftMaxStruct softmax_md;
  TopK<T, K> topk;
};

// Block-reduction operator for TopKSoftMax; the two components are merged
// independently of each other.
template <typename T, int K>
__device__ __forceinline__ TopKSoftMax<T, K> reduce_topk_softmax_op(
    const TopKSoftMax<T, K> &a, const TopKSoftMax<T, K> &b) {
  return TopKSoftMax<T, K>{reduce_softmax_op(a.softmax_md, b.softmax_md),
                           reduce_topk_op(a.topk, b.topk)};
}
 | ||
| 
 | ||
// Online-softmax state used by the "fast" kernels: m is the running maximum,
// d the sum of exp(x - m).  8-byte aligned so it moves as one 64-bit unit.
struct __align__(8) MD {
  float m;
  float d;
};

// Merges two MD states onto the larger maximum (`b` on ties), rescaling the
// other sum with the fast __expf intrinsic.
__device__ __forceinline__ MD reduce_md_op(MD a, MD b) {
  if (a.m > b.m) {
    return MD{a.m, a.d + b.d * __expf(b.m - a.m)};
  }
  return MD{b.m, b.d + a.d * __expf(a.m - b.m)};
}
 | ||
| 
 | ||
// Stage 1 of the "fast" softmax + top-K path.
// Launch: grid.x = batch_size * beam_group_size (one beam of the current
// group per block), grid.y = number of vocab chunks, block = THREADBLOCK_SIZE.
// Dynamic shared memory must hold vocab_chunk_size elements of T.
// Per block:
//   1. stage the chunk's logits in shared memory while accumulating the
//      per-thread online-softmax state (MD) and argmax;
//   2. pop the 2 * beam_group_size (= K) largest logits one at a time: a block
//      argmax, then overwrite the winner with -MAX_T_VAL and let the owning
//      thread rescan its elements;
//   3. block-reduce the MD state.
// Output: PACKED_TOP_KMD_SIZE = 2K + 2 floats per block in tmp_buffer:
//   [0, K) global token ids (int bit patterns), [K, 2K) their logits,
//   [2K] softmax sum d, [2K + 1] running max m.
// Finished or inactive (seq_len < 0) beams return without writing.
// `end_ids` is currently unused.
template <typename T, int K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__
    void beam_search_softmax_topk_stage1_fast(const T *logits,
                                              float *tmp_buffer,
                                              const int *end_ids,
                                              const bool *beam_finished,
                                              const int *seq_lens,
                                              int beam_width,
                                              int beam_group_idx,
                                              int vocab_size,
                                              int vocab_chunk_size) {
  constexpr int PACKED_TOP_KMD_SIZE = 2 * K + 2;
  const int beam_group_size = K / 2;
  const int tid = threadIdx.x;
  const int group_beam_batch_id = blockIdx.x;
  const int batch_id = group_beam_batch_id / beam_group_size;
  const int beam_group_sub_id = group_beam_batch_id % beam_group_size;
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_id;

  const int seq_len = seq_lens[beam_batch_id];
  const bool finished = beam_finished[beam_batch_id];

  // Uniform across the block, so returning before the barriers is safe.
  if (seq_len < 0 || finished) {
    return;
  }

  const int section_start = vocab_chunk_size * blockIdx.y;
  const int section_end =
      std::min(section_start + vocab_chunk_size, vocab_size);
  const int valid_smem_length = section_end - section_start;
  T const MAX_T_VAL = 1e38;

  // "No candidate" key.  It is never a valid smemLogProbs slot, and it is
  // larger than every real slot, so real entries win argmax ties
  // (cub::ArgMax prefers the smaller key).  The previous sentinel,
  // vocab_size - 1, could win the argmax when fewer than K logits in the
  // chunk exceed -MAX_T_VAL (e.g. -inf-masked tokens), and was then used to
  // index shared memory out of bounds.
  const int kNoCandidate = valid_smem_length;

  extern __shared__ char smem[];
  T *smemLogProbs = reinterpret_cast<T *>(smem);

  MD partial_md{-MAX_T_VAL, 0.0f};

  using KVPair = cub::KeyValuePair<int, T>;
  KVPair topKVPairPartial{kNoCandidate, -MAX_T_VAL};
  cub::ArgMax argmax;

  // 64-bit row offset: beam_batch_id * vocab_size can exceed INT_MAX.
  T const *local_logits =
      logits + static_cast<int64_t>(beam_batch_id) * vocab_size;

  // Thread `tid` owns chunk slots tid, tid + THREADBLOCK_SIZE, ...
#pragma unroll 1
  for (int i = section_start + tid; i < section_end; i += THREADBLOCK_SIZE) {
    T const val = local_logits[i];
    const int smem_index = i - section_start;
    smemLogProbs[smem_index] = val;
    MD new_elem_md{val, 1.0F};
    partial_md = reduce_md_op(partial_md, new_elem_md);
    KVPair new_elem_topk{smem_index, val};
    topKVPairPartial = argmax(topKVPairPartial, new_elem_topk);
  }
  __syncthreads();

  __shared__ float smemOutput[PACKED_TOP_KMD_SIZE];
  __shared__ int threadToUpdate;

  using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
  using BlockReduceTopK = cub::BlockReduce<KVPair, THREADBLOCK_SIZE>;

  // Both reductions share one temp-storage allocation; a __syncthreads()
  // separates every use.
  __shared__ union {
    typename BlockReduceTopK::TempStorage topk;
    typename BlockReduceMD::TempStorage md;
  } smemReduceBuffer;

  for (int i = 0; i < 2 * beam_group_size; ++i) {
    // Pop the current block-wide maximum into smemOutput.
    KVPair topKVPair =
        BlockReduceTopK(smemReduceBuffer.topk).Reduce(topKVPairPartial, argmax);
    if (tid == 0) {
      const bool has_candidate = topKVPair.key < valid_smem_length;
      // Without a real candidate, emit the in-vocabulary placeholder id
      // vocab_size - 1 (the id stage 2 also gives to invalidated entries)
      // instead of the out-of-vocabulary section_start + vocab_size - 1.
      reinterpret_cast<int *>(smemOutput)[i] =
          has_candidate ? section_start + topKVPair.key : vocab_size - 1;
      smemOutput[K + i] = topKVPair.value;
      if (has_candidate) {
        smemLogProbs[topKVPair.key] = -MAX_T_VAL;  // mark as popped
        threadToUpdate = topKVPair.key % THREADBLOCK_SIZE;
      } else {
        threadToUpdate = -1;  // no thread's partial result changed
      }
    }
    __syncthreads();

    // The owner of the popped slot recomputes its partial argmax; this is
    // skipped after the last pop.
    if (tid == threadToUpdate && i < 2 * beam_group_size - 1) {
      topKVPairPartial.key = kNoCandidate;
      topKVPairPartial.value = -MAX_T_VAL;
      for (int index = tid; index < valid_smem_length;
           index += THREADBLOCK_SIZE) {
        topKVPairPartial =
            argmax(topKVPairPartial, {index, smemLogProbs[index]});
      }
    }
  }

  // Block-reduce the softmax statistics into the tail of smemOutput.
  auto reduce_md_func = [](const MD &a, const MD &b) {
    return reduce_md_op(a, b);
  };
  MD total_md =
      BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_func);
  if (tid == 0) {
    smemOutput[2 * K] = total_md.d;
    smemOutput[2 * K + 1] = total_md.m;
  }
  __syncthreads();

  // Copy the packed result for (beam, chunk) to global memory.
  float *local_temp_buffer =
      tmp_buffer + group_beam_batch_id * PACKED_TOP_KMD_SIZE * gridDim.y +
      blockIdx.y * PACKED_TOP_KMD_SIZE;
#pragma unroll
  for (int i = tid; i < PACKED_TOP_KMD_SIZE; i += THREADBLOCK_SIZE) {
    local_temp_buffer[i] = smemOutput[i];
  }
}
 | ||
| 
 | ||
// <<<(batch_size * beam_group_size, voc_parts), 128>>>
// Stage 1 of the generic (non-"fast") softmax + top-K path.
// Each block scans the vocab section blockIdx.y of one beam of the current
// group, using a per-thread TopK list (plus an online-softmax state when
// fuse_softmax is set).  It then block-reduces and writes
// PACKED_TOP_KMD_SIZE floats to params.tmp_buffer:
//   [0, K) token ids (int bit patterns; -1 if the section had fewer than K
//   elements), [K, 2K) logits, [2K] softmax sum, [2K + 1] running max.
// In the non-fused path the last two slots hold a neutral state (sum 0,
// max -MAX_T_VAL).  Finished or inactive (seq_len < 0) beams are skipped.
template <typename T, int K, int THREADBLOCK_SIZE, int PACKED_TOP_KMD_SIZE>
__global__ void beam_search_softmax_topk_stage1(BeamSearchParams<T> params,
                                                const int beam_width,
                                                const int beam_group_idx,
                                                const int vocab_size,
                                                const bool fuse_softmax) {
  const int thread_id = threadIdx.x;
  const int beam_group_size = K / 2;
  const int batch_id = blockIdx.x / beam_group_size;
  const int beam_group_sub_idx = blockIdx.x % beam_group_size;
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_idx;

  const bool finish = params.beam_finished[beam_batch_id];
  const int seq_len = params.seq_lens[beam_batch_id];

  // Inactive (dynamic-batch) or finished beam: uniform per block.
  if (seq_len < 0 || finish) {
    return;
  }

  // Packed result: 2 * K + 2 floats.
  __shared__ float buf_s[PACKED_TOP_KMD_SIZE];

  const T MAX_T_VAL = FLT_MAX;

  // This block's vocab section [section_start, section_end).
  const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y;
  const int section_start = v_local * blockIdx.y;
  int section_end = section_start + v_local;
  section_end = (section_end > vocab_size) ? vocab_size : section_end;

  // 64-bit row offset: beam_batch_id * vocab_size can exceed INT_MAX.
  T *logits = params.logits + static_cast<int64_t>(beam_batch_id) * vocab_size;

  if (fuse_softmax) {
    typedef cub::BlockReduce<TopKSoftMax<T, K>, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    TopKSoftMax<T, K> partial;
    for (int i = 0; i < K; ++i) {
      partial.topk.ids[i] = -1;
      partial.topk.vals[i] = -MAX_T_VAL;
    }
    partial.softmax_md.logit = -MAX_T_VAL;
    partial.softmax_md.score = 0.0F;

    // Per-thread scan of this section.
#pragma unroll 1
    for (int elem_id = section_start + thread_id; elem_id < section_end;
         elem_id += THREADBLOCK_SIZE) {
      T elem = logits[elem_id];
      DySoftMaxStruct new_elem{elem, 1.0F};
      partial.softmax_md = reduce_softmax_op(partial.softmax_md, new_elem);
      partial.topk.insert(elem, elem_id);
    }

    // Reduce across the block.
    TopKSoftMax<T, K> total =
        BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op<T, K>);

    if (thread_id == 0) {
      for (int i = 0; i < K; i++) {
        reinterpret_cast<int *>(buf_s)[i] = total.topk.ids[i];
        buf_s[K + i] = total.topk.vals[i];
      }
      buf_s[2 * K] = total.softmax_md.score;
      buf_s[2 * K + 1] = total.softmax_md.logit;
    }
  } else {
    typedef cub::BlockReduce<TopK<T, K>, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    TopK<T, K> partial;
    for (int i = 0; i < K; ++i) {
      partial.ids[i] = -1;
      partial.vals[i] = -MAX_T_VAL;
    }

#pragma unroll 1
    for (int elem_id = section_start + thread_id; elem_id < section_end;
         elem_id += THREADBLOCK_SIZE) {
      T elem = logits[elem_id];
      partial.insert(elem, elem_id);
    }

    TopK<T, K> total =
        BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, K>);

    if (thread_id == 0) {
      for (int i = 0; i < K; i++) {
        reinterpret_cast<int *>(buf_s)[i] = total.ids[i];
        buf_s[K + i] = total.vals[i];
      }
      // No softmax statistics on this path.  Write a neutral state rather
      // than copying uninitialized shared memory out below.
      if constexpr (PACKED_TOP_KMD_SIZE >= 2 * K + 2) {
        buf_s[2 * K] = 0.0F;
        buf_s[2 * K + 1] = -MAX_T_VAL;
      }
    }
  }
  __syncthreads();

  // Store this (beam, section) result in params.tmp_buffer.
  for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE;
       elem_id += THREADBLOCK_SIZE) {
    params.tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y +
                      blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] =
        buf_s[elem_id];
  }
}
 | ||
| 
 | ||
// Stage 2 of the "fast" softmax + top-K path.
// Launch: batch_size * beam_group_size blocks (one per beam of the current
// group), THREADBLOCK_SIZE threads.  With IS_FAST_KERNEL, dynamic shared
// memory must hold voc_parts * (2K + 2) floats.
// Input: the voc_parts packed stage-1 records of this beam in tmp_buffer,
// each being [K ids | K logits | d | m].
// Output: the K best candidates across all parts.
//   tmp_ids[beam * K + i] = token id,
//   tmp_vals[beam * K + i] = log_softmax(logit) + cum_scores[beam].
// Finished or inactive (seq_len < 0) beams are skipped.
template <typename T, int K, int THREADBLOCK_SIZE, bool IS_FAST_KERNEL>
__launch_bounds__(THREADBLOCK_SIZE) __global__
    void beam_search_softmax_topk_stage2_fast(
        int *__restrict tmp_ids,
        T *__restrict tmp_vals,
        float *__restrict tmp_buffer,
        const float *__restrict cum_scores,
        const bool *__restrict beam_finished,
        const int *__restrict seq_lens,
        const int beam_width,
        const int beam_group_idx,
        const int vocab_size,
        const int voc_parts) {
  constexpr int PACKED_TOP_KMD_SIZE = 2 * K + 2;
  constexpr int beam_group_size = K / 2;
  const int group_beam_batch_id = blockIdx.x;
  const int beam_group_sub_id = group_beam_batch_id % beam_group_size;
  const int batch_id = group_beam_batch_id / beam_group_size;
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_id;

  // Uniform per block, so returning before the barriers is safe.
  if (seq_lens[beam_batch_id] < 0 || beam_finished[beam_batch_id]) {
    return;
  }

  const int tid = threadIdx.x;
  T const MAX_T_VAL = FLT_MAX;

  // Number of packed floats for this beam.  It also serves as the
  // "no candidate" key: it is never a valid offset, and it is larger than
  // every real one, so real entries win argmax ties (cub::ArgMax prefers the
  // smaller key).  The previous sentinel, vocab_size - 1, could win whenever
  // all remaining candidates were <= -MAX_T_VAL (e.g. -inf logits), and was
  // then used as an offset into the packed buffer, out of bounds.
  const int num_packed = PACKED_TOP_KMD_SIZE * voc_parts;

  using KVPair = cub::KeyValuePair<int, T>;
  using BlockReduceTopK = cub::BlockReduce<KVPair, THREADBLOCK_SIZE>;
  using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;

  __shared__ KVPair buf_smem_kv[K];

  // Shared temp storage for both reductions; barriers separate the uses.
  __shared__ union {
    typename BlockReduceTopK::TempStorage topk;
    typename BlockReduceMD::TempStorage md;
  } smemReduceBuffer;

  cub::ArgMax argmax;

  auto reduce_md_func = [](const MD &a, const MD &b) {
    return reduce_md_op(a, b);
  };

  float *localTempBuffer =
      tmp_buffer + PACKED_TOP_KMD_SIZE * group_beam_batch_id * voc_parts;
  if constexpr (IS_FAST_KERNEL) {  // stage the records in shared memory
    extern __shared__ char smem[];
    float *smemVal = reinterpret_cast<float *>(smem);
    for (int idx = tid; idx < num_packed; idx += THREADBLOCK_SIZE) {
      smemVal[idx] = localTempBuffer[idx];
    }
    localTempBuffer = smemVal;
    __syncthreads();
  }

  // Select the top K across all parts, one per iteration.  Parts are strided
  // over the threads, so voc_parts > THREADBLOCK_SIZE is still fully covered.
  // With voc_parts <= THREADBLOCK_SIZE, each thread owns at most one part.
  for (int k = 0; k < K; ++k) {
    KVPair topKVPairPartial{num_packed, -MAX_T_VAL};
    for (int part = tid; part < voc_parts; part += THREADBLOCK_SIZE) {
      for (int i = 0; i < K; ++i) {
        const int current_index = part * PACKED_TOP_KMD_SIZE + i;
        T topValue = localTempBuffer[current_index + K];
        topKVPairPartial = argmax(topKVPairPartial, {current_index, topValue});
      }
    }

    KVPair topKVPair =
        BlockReduceTopK(smemReduceBuffer.topk).Reduce(topKVPairPartial, argmax);
    __syncthreads();

    if (tid == 0) {
      const int temp_offset = topKVPair.key;
      if (temp_offset < num_packed) {
        // Translate the packed offset to the stored token id, then
        // invalidate the entry so later iterations skip it.
        topKVPair.key = reinterpret_cast<int *>(localTempBuffer)[temp_offset];
        reinterpret_cast<int *>(localTempBuffer)[temp_offset] = vocab_size - 1;
        localTempBuffer[temp_offset + K] = -MAX_T_VAL;
      } else {
        // No real candidate is left: use the same placeholder id that
        // invalidated entries carry, without touching the buffer.
        topKVPair.key = vocab_size - 1;
      }
      buf_smem_kv[k] = topKVPair;
    }
    __syncthreads();
  }

  // Combine the per-part softmax statistics.  The first part a thread owns
  // is taken verbatim; any further parts are merged into it.
  MD partial_md{-MAX_T_VAL, 0.0f};
  for (int part = tid; part < voc_parts; part += THREADBLOCK_SIZE) {
    const MD part_md{localTempBuffer[part * PACKED_TOP_KMD_SIZE + 2 * K + 1],
                     localTempBuffer[part * PACKED_TOP_KMD_SIZE + 2 * K]};
    partial_md = (part == tid) ? part_md : reduce_md_op(partial_md, part_md);
  }
  __syncthreads();

  MD total_md =
      BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_func);

  if (tid == 0) {
    // log_softmax(x) = x - m - log(d)
    float d_total_log = logf(total_md.d);

    for (int i = 0; i < K; ++i) {
      float val =
          static_cast<float>(buf_smem_kv[i].value) - total_md.m - d_total_log;
      tmp_ids[group_beam_batch_id * K + i] = buf_smem_kv[i].key;
      tmp_vals[group_beam_batch_id * K + i] = val + cum_scores[beam_batch_id];
    }
  }
}
 | ||
| 
 | ||
// Launches beam_search_softmax_topk_stage2_fast with N_VOCAB_PART threads per
// block and then RETURNS from the enclosing function, so the first expansion
// reached in the launcher wins.  Fast variants stage the stage-1 records in
// nShareMemory bytes of dynamic shared memory, opting in to more than
// 48 KiB when needed.  The non-fast variant reads them from global memory.
// The expansion site must provide T, K, params, batch_size, beam_group_size,
// beam_width, beam_group_idx, vocab_size, voc_parts, nShareMemory and stream.
// NOTE(review): neither cudaFuncSetAttribute nor the launch is error-checked.
#define BEAM_STAGE2_KERNEL(N_VOCAB_PART, IS_FAST_KERNEL)                     \
  do {                                                                       \
    const bool needs_smem_opt_in_ =                                          \
        IS_FAST_KERNEL && nShareMemory >= (48 << 10);                        \
    if (needs_smem_opt_in_) {                                                \
      cudaFuncSetAttribute(                                                  \
          beam_search_softmax_topk_stage2_fast<T,                            \
                                               K,                            \
                                               N_VOCAB_PART,                 \
                                               IS_FAST_KERNEL>,              \
          cudaFuncAttributeMaxDynamicSharedMemorySize,                       \
          nShareMemory);                                                     \
    }                                                                        \
    beam_search_softmax_topk_stage2_fast<T, K, N_VOCAB_PART, IS_FAST_KERNEL> \
        <<<batch_size * beam_group_size,                                     \
           N_VOCAB_PART,                                                     \
           IS_FAST_KERNEL * nShareMemory,                                    \
           stream>>>(params->tmp_ids,                                        \
                     params->tmp_vals,                                       \
                     params->tmp_buffer,                                     \
                     params->cum_scores,                                     \
                     params->beam_finished,                                  \
                     params->seq_lens,                                       \
                     beam_width,                                             \
                     beam_group_idx,                                         \
                     vocab_size,                                             \
                     voc_parts);                                             \
    return;                                                                  \
  } while (0);
 | ||
| 
 | ||
// Chooses and launches the stage-2 kernel for the current beam group.
//  * If the packed stage-1 records (plus the K result pairs) fit in the
//    per-block shared-memory limit, use the shared-memory ("fast") variant
//    with the smallest block size (32 / 64 / 128 threads) covering voc_parts.
//  * Otherwise use the global-memory variant with 128 threads.
// Each BEAM_STAGE2_KERNEL expansion launches a kernel and returns.
template <typename T, int K>
__inline__ void beamSearchSoftmaxTopkStage2FastKernelLauncher(
    BeamSearchParams<T> *params,
    const int batch_size,
    const int beam_width,
    const int beam_group_idx,
    const int vocab_size,
    const int voc_parts,
    const int max_smem_per_block,
    cudaStream_t stream) {
  constexpr int beam_group_size = K / 2;
  size_t const nShareMemory = sizeof(float) * voc_parts * (2 * K + 2) +
                              sizeof(cub::KeyValuePair<int, T>) * K;
  // Compare in size_t only for a positive limit.  The previous mixed
  // signed/unsigned comparison turned a negative max_smem_per_block into a
  // huge unsigned value and wrongly enabled the shared-memory path.
  const bool fits_in_smem =
      max_smem_per_block > 0 &&
      nShareMemory < static_cast<size_t>(max_smem_per_block);
  if (fits_in_smem) {  // IS_FAST_KERNEL must be a compile-time constant
    if (voc_parts <= 32) {
      BEAM_STAGE2_KERNEL(32, true)
    }
    if (voc_parts <= 64) {
      BEAM_STAGE2_KERNEL(64, true)
    }
    BEAM_STAGE2_KERNEL(128, true)
    // No larger branch: voc_parts is expected to stay within
    // kMaxVocabPartForStage1FastKernel (128).
  }
  BEAM_STAGE2_KERNEL(128, false)
}
 | ||
| 
 | ||
// Stage 2 of the beam-search softmax + top-K: merges the per-vocab-part
// partial results produced by stage 1 for one beam.
//
// Launch (see invokeBeamSearchSoftmaxTopKStage2):
//   gridDim.x  == batch_size * (K / 2), one block per beam of the group,
//   blockDim.x == THREADBLOCK_SIZE, which must be >= voc_parts,
//   dynamic shared memory == voc_parts * packed_top_kmd_size floats.
//
// Each vocab part's packed record in params.tmp_buffer holds K ids (stored as
// int bit patterns), then K values and, when fuse_softmax, two softmax
// statistics at offsets 2K (score) and 2K+1 (logit). Thread t loads part t,
// the block reduces all parts, and thread 0 writes K ids and K scores (plus
// the beam's cumulative score) to params.tmp_ids / params.tmp_vals at offset
// blockIdx.x * K.
template <typename T, int K, int THREADBLOCK_SIZE>
__global__ void beam_search_softmax_topk_stage2(BeamSearchParams<T> params,
                                                const int beam_width,
                                                const int beam_group_idx,
                                                const int voc_parts,
                                                const int packed_top_kmd_size,
                                                const bool fuse_softmax) {
  const int thread_id = threadIdx.x;
  const int beam_group_size = K / 2;
  const int batch_id = blockIdx.x / beam_group_size;
  const int beam_group_sub_idx = blockIdx.x % beam_group_size;
  // Index of this beam among all batch_size * beam_width beams.
  const int beam_batch_id = batch_id * beam_width +
                            beam_group_idx * beam_group_size +
                            beam_group_sub_idx;
  // Index of this beam among the current group's beams.
  const int group_beam_batch_id = blockIdx.x;
  const int PACKED_TOP_KMD_SIZE = packed_top_kmd_size;

  // Skip inactive / finished beams (dynamic batching). The condition depends
  // only on blockIdx, so the whole block exits together and the barriers
  // below remain block-uniform.
  const int seq_len = params.seq_lens[beam_batch_id];
  const bool finish = params.beam_finished[beam_batch_id];
  if (seq_len < 0 || finish) {
    return;
  }

  // Output slots of this beam. These pointers are already offset by
  // group_beam_batch_id * K and must not be offset again.
  int *tmp_ids = params.tmp_ids + group_beam_batch_id * K;
  float *tmp_vals = params.tmp_vals + group_beam_batch_id * K;

  const T MAX_T_VAL = FLT_MAX;

  // Stage the packed records of all vocab parts of this beam in shared memory.
  extern __shared__ char buf_s_[];
  float *buf_s = reinterpret_cast<float *>(buf_s_);
  const float *tmp_buffer =
      params.tmp_buffer + group_beam_batch_id * PACKED_TOP_KMD_SIZE * voc_parts;
  for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts;
       idx += THREADBLOCK_SIZE) {
    buf_s[idx] = tmp_buffer[idx];
  }
  __syncthreads();

  if (fuse_softmax) {
    typedef cub::BlockReduce<TopKSoftMax<T, K>, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    // Identity element for threads that own no vocab part.
    TopKSoftMax<T, K> partial;
    for (int i = 0; i < K; ++i) {
      partial.topk.ids[i] = -1;
      partial.topk.vals[i] = -MAX_T_VAL;
    }
    partial.softmax_md.logit = -MAX_T_VAL;
    partial.softmax_md.score = 0.0F;

    if (thread_id < voc_parts) {
      const float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
      for (int i = 0; i < K; i++) {
        partial.topk.ids[i] = reinterpret_cast<const int *>(b_s)[i];
        partial.topk.vals[i] = b_s[K + i];
      }
      partial.softmax_md.score = b_s[2 * K];
      partial.softmax_md.logit = b_s[2 * K + 1];
    }
    __syncthreads();

    TopKSoftMax<T, K> total =
        BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op<T, K>);

    if (thread_id == 0) {
      // Normalise each kept value by the reduced softmax statistics
      // (log-space), then add the beam's cumulative score.
      float d_total_log = logf(total.softmax_md.score);
      for (int i = 0; i < K; ++i) {
        float val = total.topk.vals[i] - total.softmax_md.logit - d_total_log;
        tmp_ids[i] = total.topk.ids[i];
        tmp_vals[i] = val + params.cum_scores[beam_batch_id];
      }
    }
  } else {
    typedef cub::BlockReduce<TopK<T, K>, THREADBLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    // Identity element for threads that own no vocab part.
    TopK<T, K> partial;
    for (int i = 0; i < K; ++i) {
      partial.ids[i] = -1;
      partial.vals[i] = -MAX_T_VAL;
    }

    if (thread_id < voc_parts) {
      const float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
      for (int i = 0; i < K; i++) {
        partial.ids[i] = reinterpret_cast<const int *>(b_s)[i];
        partial.vals[i] = b_s[K + i];
      }
    }
    __syncthreads();

    TopK<T, K> total =
        BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, K>);

    if (thread_id == 0) {
      // tmp_ids / tmp_vals already point at this beam's slots.
      for (int i = 0; i < K; ++i) {
        float val = total.vals[i];
        tmp_ids[i] = total.ids[i];
        tmp_vals[i] = val + params.cum_scores[beam_batch_id];
      }
    }
  }
}
 | ||
| 
 | ||
// Launches beam_search_softmax_topk_stage2 with a compile-time block width.
// Opts in to more than 48 KB of dynamic shared memory when the staging
// buffer needs it (e.g. 256 parts with a large K).
template <typename T, int K, int THREADBLOCK_SIZE>
static void launchBeamSearchSoftmaxTopKStage2(BeamSearchParams<T> *params,
                                              const int grid_size,
                                              const int smem_size,
                                              const int beam_width,
                                              const int beam_group_idx,
                                              const int voc_parts,
                                              const int packed_top_kmd_size,
                                              const bool fuse_softmax,
                                              gpuStream_t stream) {
  if (smem_size >= (48 << 10)) {
    cudaFuncSetAttribute(
        beam_search_softmax_topk_stage2<T, K, THREADBLOCK_SIZE>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);
  }
  beam_search_softmax_topk_stage2<T, K, THREADBLOCK_SIZE>
      <<<grid_size, THREADBLOCK_SIZE, smem_size, stream>>>(*params,
                                                            beam_width,
                                                            beam_group_idx,
                                                            voc_parts,
                                                            packed_top_kmd_size,
                                                            fuse_softmax);
}

// Launches stage 2 with one block per beam of the current group
// (batch_size * K / 2 blocks). The block width is the smallest of
// 32 / 64 / 128 / 256 that covers voc_parts, because each thread loads one
// vocab part's packed record. More than 256 parts is not supported and is
// reported instead of being silently skipped.
template <typename T, int K>
void invokeBeamSearchSoftmaxTopKStage2(BeamSearchParams<T> *params,
                                       const int batch_size,
                                       const int beam_width,
                                       const int beam_group_idx,
                                       const int voc_parts,
                                       const int packed_top_kmd_size,
                                       const bool fuse_softmax,
                                       gpuStream_t stream) {
  const int smem_stage2_size = voc_parts * packed_top_kmd_size * sizeof(float);
  const int beam_group_size = K / 2;
  const int grid_size = batch_size * beam_group_size;

  if (voc_parts <= 32) {
    launchBeamSearchSoftmaxTopKStage2<T, K, 32>(params, grid_size,
                                                smem_stage2_size, beam_width,
                                                beam_group_idx, voc_parts,
                                                packed_top_kmd_size,
                                                fuse_softmax, stream);
    return;
  }
  if (voc_parts <= 64) {
    launchBeamSearchSoftmaxTopKStage2<T, K, 64>(params, grid_size,
                                                smem_stage2_size, beam_width,
                                                beam_group_idx, voc_parts,
                                                packed_top_kmd_size,
                                                fuse_softmax, stream);
    return;
  }
  if (voc_parts <= 128) {
    launchBeamSearchSoftmaxTopKStage2<T, K, 128>(params, grid_size,
                                                 smem_stage2_size, beam_width,
                                                 beam_group_idx, voc_parts,
                                                 packed_top_kmd_size,
                                                 fuse_softmax, stream);
    return;
  }
  if (voc_parts <= 256) {
    launchBeamSearchSoftmaxTopKStage2<T, K, 256>(params, grid_size,
                                                 smem_stage2_size, beam_width,
                                                 beam_group_idx, voc_parts,
                                                 packed_top_kmd_size,
                                                 fuse_softmax, stream);
    return;
  }
  fprintf(stderr,
          "[BeamSearchSoftmax] stage2: unsupported voc_parts %d (max 256); "
          "no kernel launched.\n",
          voc_parts);
}
 | ||
| 
 | ||
// Marks all K beams of a batch entry as finished once the score stored in its
// last (K-th) hypothesis slot is above -1e8.
//
// Launch: one block per batch entry (blockIdx.x); only thread 0 does work.
// Both arrays are indexed as [batch_idx * K + k].
template <typename T, int K>
__global__ void update_beam_finished_early_stop(const T *beam_hyps_score_out,
                                                bool *beam_finished) {
  if (threadIdx.x != 0) {
    return;
  }
  const int batch_idx = blockIdx.x;
  const T last_score = beam_hyps_score_out[batch_idx * K + K - 1];
  if (!(last_score > -1e8)) {
    return;
  }
  bool *batch_flags = beam_finished + batch_idx * K;
  for (int k = 0; k < K; ++k) {
    batch_flags[k] = true;
  }
}
 | ||
| 
 | ||
// <<<batch_size>>>
// Final per-batch selection step for one beam group.
//
// Launch: one block per batch entry (blockIdx.x == batch_id). Only thread 0
// does any work, so THREADBLOCK_SIZE merely sets the launch width.
//
// Thread 0 reads the candidates stage 2 left in params.tmp_ids /
// params.tmp_vals for this batch entry (beam_group_size * K of them, or only
// the first K on step 0). It applies apply_length_penalty and, for
// non-group search (GROUP == false), a rank-based diversity penalty, then
// inserts each candidate into `partial` (a TopK<T, K>; presumably it keeps
// the K best -- confirm against TopK::insert).
//
// Candidates that are end tokens -- or every candidate on the last decode
// step -- are offered to the finished-hypotheses buffer
// (params.beam_hyps_out / beam_hyps_score_out). The others become this
// group's next tokens, scores and parent ids, up to beam_group_size of them.
//
// The group's beams are marked finished when:
//   - the first candidate scores below the last stored hypothesis;
//   - early_stop is set and the worst stored hypothesis score exceeds -1e8;
//   - on the last decode step.
template <typename T, int K, int THREADBLOCK_SIZE, bool GROUP>
__global__ void batch_topk(BeamSearchParams<T> params,
                           const int beam_width,
                           const int beam_group_idx,
                           const int dec_stride) {
  const bool early_stop = params.early_stop;
  const int thread_id = threadIdx.x;
  const int batch_id = blockIdx.x;
  // int block_id = blockIdx.x;  // bs
  const int beam_group_size = K / 2;
  // Index of the first beam of this group within the flat
  // batch_size * beam_width beam layout.
  const int beam_group_start_id =
      batch_id * beam_width + beam_group_idx * beam_group_size;

  // Per-group views of the per-beam arrays.
  bool *beam_finished = params.beam_finished + beam_group_start_id;
  const int *step_ids = params.step_ids + beam_group_start_id;
  int *next_tokens = params.next_tokens + beam_group_start_id;
  float *cum_scores_out = params.cum_scores_out + beam_group_start_id;
  int *parent_ids = params.parent_ids + beam_group_start_id;
  float *beam_hyps_score_out = params.beam_hyps_score_out + beam_group_start_id;

  // Group-wide state is read from the group's first beam.
  const bool finish = beam_finished[0];
  const int step_id = step_ids[0];
  const int seq_len = params.seq_lens[beam_group_start_id];
  const int max_dec_len = params.max_dec_lens[beam_group_start_id];
  const bool last_dec_step = (step_id + 1 == max_dec_len);

  if (thread_id == 0 && seq_len > 0 && !finish) {
    TopK<T, K> partial;
    BeamHypothesesTopK<T, K / 2> beam_hyps;

    beam_hyps.init(params.beam_hyps_out + beam_group_start_id * dec_stride,
                   params.beam_hyps_score_out + beam_group_start_id,
                   dec_stride);

    for (int i = 0; i < K; ++i) {
      partial.ids[i] = -1;
      partial.vals[i] = -FLT_MAX;
      partial.parent_ids[i] = -1;
    }
    // Start of this batch entry's beam_group_size * K stage-2 candidates.
    int index = batch_id * beam_group_size * K;
    if (step_id == 0) {
      // First step: only the first beam's K candidates are considered, all
      // with parent 0 (i / K == 0 here).
      for (int i = 0; i < K; i++) {
        float score_now = apply_length_penalty(params.tmp_vals[index + i],
                                               step_id + 1,
                                               params.length_penalty[batch_id]);
        if (!GROUP) {
          // Penalise by the candidate's rank within its beam.
          score_now -=
              params.diversity_penalty[batch_id] * static_cast<float>(i + 1);
        }
        partial.insert((T)score_now, params.tmp_ids[index + i], i / K);
      }
    } else {
      // Later steps: K candidates from each beam; i / K is the parent beam
      // (within the group) and i % K the rank inside that beam.
      for (int i = 0; i < beam_group_size * K; i++) {
        float score_now = apply_length_penalty(params.tmp_vals[index + i],
                                               step_id + 1,
                                               params.length_penalty[batch_id]);
        if (!GROUP) {
          score_now -= params.diversity_penalty[batch_id] *
                       static_cast<float>(i % K + 1);
        }
        partial.insert((T)score_now, params.tmp_ids[index + i], i / K);
      }
    }
    // If even the first candidate scores below the last stored hypothesis,
    // finish the whole group.
    if (partial.vals[0] < beam_hyps.hyps[beam_group_size - 1].score) {
      for (int i = 0; i < beam_group_size; i++) {
        beam_finished[i] = true;
      }
      return;
    }

    int next_step_num = 0;
    for (int i = 0; i < K && next_step_num < beam_group_size; i++) {
      int parent_id = partial.parent_ids[i];
      if (is_in_end(partial.ids[i], params.end_ids, params.end_ids_len) ||
          last_dec_step) {
        // Only candidates ranked within the first beam_group_size that beat
        // the worst stored hypothesis are recorded; the parent's cached
        // token row is passed along with the final token.
        if (i < beam_group_size &&
            partial.vals[i] > beam_hyps.get_worst_score()) {
          const int *beam_cache_id = params.beam_cache_ids +
                                     beam_group_start_id * dec_stride +
                                     parent_id * dec_stride;
          beam_hyps.insert(beam_cache_id,
                           step_id,
                           last_dec_step ? params.end_ids[0] : partial.ids[i],
                           partial.vals[i]);
        }

        // -1e8 appears to act as the "empty slot" sentinel: once the worst
        // stored score is above it (presumably the buffer is full), early
        // stop finishes the group.
        if (early_stop && beam_hyps.get_worst_score() > -1e8) {
          // stop
          for (int i = 0; i < beam_group_size; i++) {
            beam_finished[i] = true;
          }
          return;
        }
      } else {
        // Not an end token: it continues as one of the next step's beams.
        next_tokens[next_step_num] = partial.ids[i];
        cum_scores_out[next_step_num] = partial.vals[i];
        parent_ids[next_step_num] = parent_id;
        next_step_num += 1;
      }
    }

    // Publish the (possibly updated) hypothesis scores.
    for (int i = 0; i < beam_group_size; i++) {
      beam_hyps_score_out[i] = beam_hyps.hyps[i].score;
    }

    if (last_dec_step) {
      for (int i = 0; i < beam_group_size; i++) {
        beam_finished[i] = true;
      }
    }
  }  // if (thread_id == 0)
}
 | ||
| 
 | ||
// Runs one beam group's selection for the current decode step. K is
// 2 * beam_group_size: each beam produces twice as many candidates as the
// group has beams.
//
//   (0) group beam search only (beam_width > beam_group_size): for every
//       group after the first, launch apply_group_diversity_penalty;
//   (1) stage 1: one block per (beam in group, vocab part) computes partial
//       top-K results -- with softmax statistics when fuse_softmax -- for
//       stage 2;
//   (2) stage 2: merge the partial results per beam (BEAM_STAGE2_KERNEL via
//       beamSearchSoftmaxTopkStage2FastKernelLauncher);
//   (3) batch_topk: choose the group's next tokens / finished hypotheses.
template <typename T, int K, bool GROUP>
void invokeTopKSoftMaxLauncher(BeamSearchParams<T> *params,
                               int beam_group_idx,
                               gpuStream_t stream) {
  const int batch_size = params->batch_size;
  const int beam_width = params->beam_width;
  const int beam_group_size = K / 2;
  const int vocab_size = params->vocab_size;
  const bool fuse_softmax = params->fuse_softmax;
  const int voc_parts = params->voc_parts;

  // == Step0 == : only in group_beam_search
  if (beam_width > beam_group_size && beam_group_idx != 0) {
    apply_group_diversity_penalty<T, K>
        <<<batch_size, beam_group_size, 0, stream>>>(
            *params, batch_size, beam_width, beam_group_idx, vocab_size);
  }

  // == Step1 == : stage1, grid = (beams in group, vocab parts)
  const dim3 grid(batch_size * beam_group_size, voc_parts);
  if (params->use_fast_kernel) {
    constexpr int block_size =
        (K < 16) ? ((K < 8) ? kBlockSizeForSmallBeamWidth : 128) : 64;
    // Each block holds one vocab chunk of logits in dynamic shared memory.
    const int vocab_chunk_size = (vocab_size + voc_parts - 1) / voc_parts;
    const int dyn_smem_size = sizeof(T) * vocab_chunk_size;
    if (dyn_smem_size >= (48 << 10)) {
      cudaFuncSetAttribute(
          beam_search_softmax_topk_stage1_fast<T, K, block_size>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          dyn_smem_size);
    }

    beam_search_softmax_topk_stage1_fast<T, K, block_size>
        <<<grid, block_size, dyn_smem_size, stream>>>(params->logits,
                                                      params->tmp_buffer,
                                                      params->end_ids,
                                                      params->beam_finished,
                                                      params->seq_lens,
                                                      beam_width,
                                                      beam_group_idx,
                                                      vocab_size,
                                                      vocab_chunk_size);
  } else {
    constexpr int block_size = 128;
    // The shared-memory carveout preference is set only on the instantiation
    // that is actually launched, with the HIP spelling on non-CUDA builds.
    if (fuse_softmax) {
#ifdef PADDLE_WITH_CUDA
      cudaFuncSetAttribute(
          beam_search_softmax_topk_stage1<float, K, block_size, 2 * K + 2>,
          cudaFuncAttributePreferredSharedMemoryCarveout,
          cudaSharedmemCarveoutMaxL1);
#else
      // cudaSharedmemCarveoutMaxL1 equal to 0
      hipFuncSetAttribute(
          reinterpret_cast<void *>(
              beam_search_softmax_topk_stage1<float, K, block_size, 2 * K + 2>),
          hipFuncAttributePreferredSharedMemoryCarveout,
          0);
#endif
      // (bs, bm, voc_parts, 2 * K + 2)
      beam_search_softmax_topk_stage1<float, K, block_size, 2 * K + 2>
          <<<grid, block_size, 0, stream>>>(
              *params, beam_width, beam_group_idx, vocab_size, fuse_softmax);
    } else {
#ifdef PADDLE_WITH_CUDA
      cudaFuncSetAttribute(
          beam_search_softmax_topk_stage1<float, K, block_size, 2 * K>,
          cudaFuncAttributePreferredSharedMemoryCarveout,
          cudaSharedmemCarveoutMaxL1);
#else
      // cudaSharedmemCarveoutMaxL1 equal to 0
      hipFuncSetAttribute(
          reinterpret_cast<void *>(
              beam_search_softmax_topk_stage1<float, K, block_size, 2 * K>),
          hipFuncAttributePreferredSharedMemoryCarveout,
          0);
#endif
      // (bs, bm, voc_parts, 2 * K)
      beam_search_softmax_topk_stage1<float, K, block_size, 2 * K>
          <<<grid, block_size, 0, stream>>>(
              *params, beam_width, beam_group_idx, vocab_size, fuse_softmax);
    }
  }

  // == Step2 == : stage2, merge the vocab parts of each beam
  beamSearchSoftmaxTopkStage2FastKernelLauncher<float, K>(
      params,
      batch_size,
      beam_width,
      beam_group_idx,
      vocab_size,
      voc_parts,
      params->max_smem_per_block,
      stream);

  // == Step3 == : per-batch selection across the group's beams
  batch_topk<T, K, 32, GROUP><<<batch_size, 32, 0, stream>>>(
      *params, beam_width, beam_group_idx, params->dec_stride);
}
 | ||
| 
 | ||
// Dispatches invokeTopKSoftMaxLauncher on params->beam_group_size (1..16).
// CASE_K(n) instantiates the launcher with K = 2 * n and expands in this
// scope, so `params`, `beam_group_idx` and `stream` must keep their names.
// Unsupported group sizes launch nothing; this is reported on stderr instead
// of failing silently.
template <typename T, bool GROUP>
void invokeTopkSoftMax(BeamSearchParams<T> *params,
                       int beam_group_idx,
                       gpuStream_t stream) {
  switch (params->beam_group_size) {
    CASE_K(1);
    CASE_K(2);
    CASE_K(3);
    CASE_K(4);
    CASE_K(5);
    CASE_K(6);
    CASE_K(7);
    CASE_K(8);
    CASE_K(9);
    CASE_K(10);
    CASE_K(11);
    CASE_K(12);
    CASE_K(13);
    CASE_K(14);
    CASE_K(15);
    CASE_K(16);
    default:
      fprintf(stderr,
              "[BeamSearchSoftmax] unsupported beam_group_size %d "
              "(expected 1..16); no kernel launched.\n",
              params->beam_group_size);
      break;
  }
}
 | ||
| 
 | ||
// Chooses how many vocabulary parts (params->voc_parts) stage 1 splits each
// beam's logits into and whether the fast stage-1 kernel is used
// (params->use_fast_kernel). Also stores the device's opt-in per-block
// shared-memory limit in params->max_smem_per_block for the stage-2 launcher.
//
// Fast path (fuse_softmax only): each part's logit chunk lives in dynamic
// shared memory. Starting one below the occupancy limit, try fewer resident
// blocks per SM -- so each gets more shared memory -- until the vocabulary
// splits into at most kMaxVocabPartForStage1FastKernel parts.
//
// Fallback: size voc_parts to roughly fill one wave of 4 blocks per SM,
// capped at 128 (the widest stage-2 block).
template <typename T, int K>
void ComputeVocParts(BeamSearchParams<T> *params) {
  // NOTE(review): attributes are always queried on device 0; confirm this is
  // the device the kernels run on in multi-GPU setups.
  int dev_id = 0;
  constexpr int block_size =
      (K < 16) ? ((K < 8) ? kBlockSizeForSmallBeamWidth : 128) : 64;

  int max_active_blocks = -1;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &max_active_blocks,
      beam_search_softmax_topk_stage1_fast<float, K, block_size>,
      block_size,
      0);

  int max_smem_per_sm = -1;
  int max_smem_per_block = -1;
  cudaDeviceGetAttribute(
      &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id);
  cudaDeviceGetAttribute(
      &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id);
  cudaFuncAttributes attr;
  cudaFuncGetAttributes(
      &attr, beam_search_softmax_topk_stage1_fast<float, K, block_size>);
  const int static_smem = attr.sharedSizeBytes;

  // Shared memory per SM not available to a block's dynamic allocation:
  // the per-SM / per-block difference (presumably a driver reservation)
  // plus the kernel's static shared memory.
  const int driver_smem_per_block = max_smem_per_sm - max_smem_per_block;
  const int extra_smem = driver_smem_per_block + static_smem;
  int voc_parts = kMaxVocabPartForStage1FastKernel + 1;

  for (int n_block = max_active_blocks - 1;
       n_block > 0 && voc_parts > kMaxVocabPartForStage1FastKernel;
       --n_block) {
    int dyn_smem_size = max_smem_per_sm / n_block - extra_smem;
    // Too little room for even one element: try fewer blocks per SM
    // (avoids a negative modulo and a division by zero in ceilDiv).
    if (dyn_smem_size < static_cast<int>(sizeof(T))) {
      continue;
    }
    dyn_smem_size -= dyn_smem_size % sizeof(T);
    voc_parts = ceilDiv(sizeof(T) * params->vocab_size, dyn_smem_size);
  }

  if (!params->fuse_softmax || voc_parts > kMaxVocabPartForStage1FastKernel) {
    params->use_fast_kernel = false;
    int sm_count = 1;
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
    const int max_act_blocks_per_sm = 4;
    const int max_act_blocks_per_wave = sm_count * max_act_blocks_per_sm;
    // Guard batch_size == 0 against a division by zero.
    const int gridx = std::max(1, params->batch_size * K / 2);
    const int max_part_num = (max_act_blocks_per_wave + gridx - 1) / gridx;
    // At least one part, at most 128 (the widest stage-2 block).
    voc_parts = std::max(1, std::min(128, max_part_num));
  } else {
    // State the decision explicitly instead of relying on the field's
    // previous value.
    params->use_fast_kernel = true;
  }
  params->voc_parts = voc_parts;
  params->max_smem_per_block = max_smem_per_block;
}
 | ||
| 
 | ||
// Dispatches ComputeVocParts on params->beam_group_size (1..16) through the
// DISPATCH_COMPUTE_PARTS_K macro, which expands in this scope and uses
// `params`. Unsupported group sizes leave params->voc_parts untouched; this
// is reported on stderr instead of being silently ignored.
template <typename T>
void DispatchComputeVocParts(BeamSearchParams<T> *params) {
  switch (params->beam_group_size) {
    DISPATCH_COMPUTE_PARTS_K(1);
    DISPATCH_COMPUTE_PARTS_K(2);
    DISPATCH_COMPUTE_PARTS_K(3);
    DISPATCH_COMPUTE_PARTS_K(4);
    DISPATCH_COMPUTE_PARTS_K(5);
    DISPATCH_COMPUTE_PARTS_K(6);
    DISPATCH_COMPUTE_PARTS_K(7);
    DISPATCH_COMPUTE_PARTS_K(8);
    DISPATCH_COMPUTE_PARTS_K(9);
    DISPATCH_COMPUTE_PARTS_K(10);
    DISPATCH_COMPUTE_PARTS_K(11);
    DISPATCH_COMPUTE_PARTS_K(12);
    DISPATCH_COMPUTE_PARTS_K(13);
    DISPATCH_COMPUTE_PARTS_K(14);
    DISPATCH_COMPUTE_PARTS_K(15);
    DISPATCH_COMPUTE_PARTS_K(16);
    default:
      fprintf(stderr,
              "[BeamSearchSoftmax] unsupported beam_group_size %d "
              "(expected 1..16); voc_parts not computed.\n",
              params->beam_group_size);
      break;
  }
}
 | ||
| 
 | ||
// Reorders per-beam state after a beam-search step.
//
// Launch (see updateBeamSearchParams): grid = (ceil(max_seq_len / blockDim.x),
// batch_size * beam_width). blockIdx.y selects the beam (bb_id); the x
// dimension covers time steps.
//
// For every active beam, its parent beam (params.parent_ids[bb_id], an index
// within the beam group) is the source:
//   - block_tables_out row for time steps [0, min(max_seq_len, seq_len + 1))
//     is copied from the parent's row in params.block_tables;
//   - cache_ids_out row for time steps [0, min(step + 1, max_dec_len)) is
//     copied from the parent's row in params.beam_cache_ids, except that
//     position `step` receives params.next_tokens[bb_id].
template <typename T>
__global__ void update_beam_search_params_kernel(BeamSearchParams<T> params) {
  const int bb_id = blockIdx.y;
  const int time_step = threadIdx.x + blockIdx.x * blockDim.x;

  // Bounds-check before touching any per-beam array.
  if (bb_id >= params.beam_width * params.batch_size) {
    return;
  }

  const bool finished = params.beam_finished[bb_id];
  const int seq_len = params.seq_lens[bb_id];
  if (finished || seq_len < 0) {
    return;
  }

  const int beam_group_size = params.beam_group_size;
  const int max_seq_len = params.max_seq_len;
  const int dec_stride = params.dec_stride;

  // Index of the (batch, group) pair this beam belongs to.
  const int batch_group_id = bb_id / beam_group_size;

  const int max_dec_len = params.max_dec_lens[bb_id];
  const int src_beam = params.parent_ids[bb_id];
  const int step = params.step_ids[bb_id];

  const int *block_tables = params.block_tables;
  int *block_tables_out = params.block_tables_out;
  const int *cache_ids = params.beam_cache_ids;
  int *cache_ids_out = params.cache_ids_out;
  const int *next_tokens = params.next_tokens;

  const int beam_group_sub_id = bb_id % beam_group_size;

  if (time_step < min(max_seq_len, seq_len + 1)) {
    const unsigned int block_tables_tgt_offset =
        batch_group_id * beam_group_size * max_seq_len +
        beam_group_sub_id * max_seq_len + time_step;
    const unsigned int block_tables_src_offset =
        batch_group_id * beam_group_size * max_seq_len +
        src_beam * max_seq_len + time_step;
    block_tables_out[block_tables_tgt_offset] =
        block_tables[block_tables_src_offset];
    if (time_step < min(step + 1, max_dec_len)) {
      const unsigned int cache_ids_tgt_offset =
          batch_group_id * beam_group_size * dec_stride +
          beam_group_sub_id * dec_stride + time_step;
      const unsigned int cache_ids_src_offset =
          batch_group_id * beam_group_size * dec_stride +
          src_beam * dec_stride + time_step;
      cache_ids_out[cache_ids_tgt_offset] =
          (time_step == step) ? next_tokens[bb_id]
                              : cache_ids[cache_ids_src_offset];
    }
  }
}
 | ||
| 
 | ||
// Sets all stop flags of a batch entry once every beam group of it is done.
//
// Launch: one block per batch entry (blockIdx.x); only thread 0 works.
// A group counts as done when its first beam is marked in
// params.beam_finished. Batch entries whose first stop flag is already set
// are left untouched.
template <typename T>
__global__ void update_stop_flags(BeamSearchParams<T> params) {
  if (threadIdx.x != 0) {
    return;
  }
  const int batch_idx = blockIdx.x;
  const int width = params.beam_width;
  const int group_size = params.beam_group_size;

  bool *batch_stop_flags = params.stop_flags + width * batch_idx;
  if (batch_stop_flags[0]) {
    return;
  }

  const bool *batch_finished = params.beam_finished + width * batch_idx;
  for (int first_beam = 0; first_beam < width; first_beam += group_size) {
    if (!batch_finished[first_beam]) {
      return;
    }
  }
  for (int beam = 0; beam < width; ++beam) {
    batch_stop_flags[beam] = true;
  }
}
 | ||
| 
 | ||
// Enqueues the end-of-step bookkeeping on `stream`:
//   1. update_beam_search_params_kernel: 32 threads per block, grid
//      (ceil(max_seq_len / 32), batch_size * beam_width);
//   2. update_stop_flags: one single-thread block per batch entry.
template <typename T>
void updateBeamSearchParams(BeamSearchParams<T> *params, cudaStream_t stream) {
  constexpr unsigned int kThreadsPerBlock = 32;
  const unsigned int time_blocks =
      (params->max_seq_len + kThreadsPerBlock - 1) / kThreadsPerBlock;
  const unsigned int beam_rows = params->batch_size * params->beam_width;

  update_beam_search_params_kernel<T>
      <<<dim3(time_blocks, beam_rows), dim3(kThreadsPerBlock), 0, stream>>>(
          *params);

  update_stop_flags<T><<<dim3(params->batch_size), 1, 0, stream>>>(*params);
}
 | ||
| 
 | ||
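/*****
To fit the model structure of version 5.2 without adding a while op and
without hurting speed, this op uses a "fake in-place" approach. Inputs marked
"// inplace" are modified through their (const) data pointers. Where the old
values are still needed (cum_scores, beam_cache_ids, block_tables), a copy is
taken first, so the kernels read the copy and write the results back into the
original buffer. Not elegant, but it works ︸_︸.
*****/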
 | ||
| 
 | ||
// Custom op: one decode step of (group) beam search with softmax + top-K.
//
// logits is [batch_size * beam_width, vocab_size]. beam_width and
// beam_group_num are device tensors holding one int each.
// Inputs marked "inplace" are written through their data pointers:
// cum_scores, beam_cache_ids and block_tables are snapshotted first (the
// *_ori tensors) so the kernels read the old values and write the new ones
// back into the originals. stop_flags, beam_hyps, beam_hyps_score and
// beam_finished are updated directly.
//
// Returns {next_tokens, parent_ids}, each [batch_size * beam_width, 1]; the
// parent ids index beams within their group.
std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
                             const paddle::Tensor &seq_lens,
                             const paddle::Tensor &stop_flags,       // inplace
                             const paddle::Tensor &end_ids,
                             const paddle::Tensor &step_ids,
                             const paddle::Tensor &max_dec_lens,
                             const paddle::Tensor &block_tables,     // inplace
                             const paddle::Tensor &cum_scores,       // inplace
                             const paddle::Tensor &beam_cache_ids,   // inplace
                             const paddle::Tensor &beam_hyps,        // inplace
                             const paddle::Tensor &beam_hyps_score,  // inplace
                             const paddle::Tensor &beam_finished,    // inplace
                             const paddle::Tensor &beam_width,
                             const paddle::Tensor &beam_group_num,
                             const paddle::Tensor &length_penalty,
                             const paddle::Tensor &diversity_penalty,
                             bool fuse_softmax,
                             bool early_stop) {
  const std::vector<int64_t> logits_shape = logits.shape();
  auto cu_stream = logits.stream();

  // Fetch the scalar beam_width / beam_group_num from the device. The
  // destinations are pageable host variables; a device-to-host
  // cudaMemcpyAsync into pageable memory returns only after the copy has
  // completed, so the values are valid right after these calls.
  int beam_width_scalar;
  cudaMemcpyAsync(&beam_width_scalar,
                  beam_width.data<int>(),
                  sizeof(int),
                  cudaMemcpyDeviceToHost,
                  cu_stream);
  int beam_group_num_scalar;
  cudaMemcpyAsync(&beam_group_num_scalar,
                  beam_group_num.data<int>(),
                  sizeof(int),
                  cudaMemcpyDeviceToHost,
                  cu_stream);

  const int beam_batch_size = logits_shape[0];
  const int batch_size = beam_batch_size / beam_width_scalar;
  const int vocab_size = logits_shape[1];
  const int max_seq_len = block_tables.dims()[1];
  // liuzichang: In some cases, the length of Tensor is longer than
  // max_dec_lens, so use the tensor's own second dimension as the stride.
  const int dec_stride = beam_hyps.dims()[1];
  const int end_ids_len = end_ids.dims()[0];
  const int beam_group_size = beam_width_scalar / beam_group_num_scalar;

  auto next_tokens = paddle::full({logits_shape[0], 1}, 0, end_ids.type(),
                                  paddle::GPUPlace());
  auto parent_ids = paddle::full({logits_shape[0], 1}, 0, end_ids.type(),
                                 paddle::GPUPlace());

  // Snapshots of the in-place inputs whose old values the kernels read.
  auto cum_scores_ori = paddle::empty(cum_scores.shape(), cum_scores.type(),
                                      paddle::GPUPlace());
  auto beam_cache_ids_ori = paddle::empty(
      beam_cache_ids.shape(), beam_cache_ids.type(), paddle::GPUPlace());
  auto block_tables_ori = paddle::empty(block_tables.shape(),
                                        block_tables.type(),
                                        paddle::GPUPlace());
  cudaMemcpyAsync(cum_scores_ori.mutable_data<float>(),
                  cum_scores.data<float>(),
                  sizeof(float) * cum_scores.numel(),
                  cudaMemcpyDeviceToDevice,
                  cu_stream);
  cudaMemcpyAsync(beam_cache_ids_ori.mutable_data<int>(),
                  beam_cache_ids.data<int>(),
                  sizeof(int) * beam_cache_ids.numel(),
                  cudaMemcpyDeviceToDevice,
                  cu_stream);
  cudaMemcpyAsync(block_tables_ori.mutable_data<int>(),
                  block_tables.data<int>(),
                  sizeof(int) * block_tables.numel(),
                  cudaMemcpyDeviceToDevice,
                  cu_stream);

  BeamSearchParams<float> params;
  params.batch_size = batch_size;
  params.beam_width = beam_width_scalar;
  params.beam_group_size = beam_group_size;

  params.vocab_size = vocab_size;
  params.dec_stride = dec_stride;
  params.max_seq_len = max_seq_len;
  params.end_ids_len = end_ids_len;

  params.fuse_softmax = fuse_softmax;
  params.early_stop = early_stop;

  // Only Read
  params.step_ids = step_ids.data<int>();
  params.seq_lens = seq_lens.data<int>();
  params.max_dec_lens = max_dec_lens.data<int>();
  params.end_ids = end_ids.data<int>();
  params.length_penalty = length_penalty.data<float>();
  params.diversity_penalty = diversity_penalty.data<float>();

  params.cum_scores = cum_scores_ori.data<float>();
  params.block_tables = block_tables_ori.data<int>();
  params.beam_cache_ids = beam_cache_ids_ori.data<int>();

  // Write
  params.logits = const_cast<float *>(logits.data<float>());
  params.cache_ids_out = const_cast<int *>(beam_cache_ids.data<int>());
  params.block_tables_out = const_cast<int *>(block_tables.data<int>());
  params.cum_scores_out = const_cast<float *>(cum_scores.data<float>());
  params.beam_hyps_out = const_cast<int *>(beam_hyps.data<int>());
  params.beam_hyps_score_out =
      const_cast<float *>(beam_hyps_score.data<float>());
  params.beam_finished = const_cast<bool *>(beam_finished.data<bool>());
  params.stop_flags = const_cast<bool *>(stop_flags.data<bool>());

  params.next_tokens = const_cast<int *>(next_tokens.data<int>());
  params.parent_ids = const_cast<int *>(parent_ids.data<int>());

  DispatchComputeVocParts<float>(&params);

  // Workspace, in 4-byte elements:
  //   [tmp_ids (int) | tmp_vals (float) | stage1 -> stage2 packed records].
  // tmp_ids / tmp_vals hold K = 2 * beam_group_size candidates per beam.
  const int tmp_id_val_size =
      batch_size * beam_group_size * beam_group_size * 2;
  const int packed_top_kmd_size =
      fuse_softmax ? 2 * 2 * beam_group_size + 2 : 2 * 2 * beam_group_size;
  const int tmp_stage1_to_stage2_size =
      batch_size * beam_group_size * params.voc_parts * packed_top_kmd_size;

  const int workspace_size = tmp_id_val_size * 2 + tmp_stage1_to_stage2_size;

  auto wsp_buffer_tensor = paddle::full({workspace_size}, 0, logits.type(),
                                        paddle::GPUPlace());

  params.tmp_ids = reinterpret_cast<int *>(wsp_buffer_tensor.data<float>());
  params.tmp_vals = wsp_buffer_tensor.data<float>() + tmp_id_val_size;
  params.tmp_buffer = wsp_buffer_tensor.data<float>() + 2 * tmp_id_val_size;

  for (int beam_group_idx = 0; beam_group_idx < beam_group_num_scalar;
       ++beam_group_idx) {
    if (beam_group_num_scalar == 1) {
      invokeTopkSoftMax<float, false>(&params, beam_group_idx, cu_stream);
    } else {
      invokeTopkSoftMax<float, true>(&params, beam_group_idx, cu_stream);
    }
  }
  updateBeamSearchParams<float>(&params, cu_stream);
  return {next_tokens, parent_ids};
}
 | ||
| 
 | ||
// Infer-shape function for the `beam_search_softmax` custom op.
//
// Only the leading dimension of `logits` is consulted. Both outputs
// (`next_tokens` and `parent_ids`) are reported with shape [logits[0], 1].
// Every other parameter exists only so the signature lines up with the op's
// registered input list; none of them affects the result.
// Parameters marked "inplace" are buffers the kernel writes into directly.
// NOTE(review): logits[0] is read without a rank check; this assumes logits
// always has rank >= 1 — confirm against the op's callers.
std::vector<std::vector<int64_t>> BeamSearchSoftmaxShape(
    const std::vector<int64_t> &logits,
    const std::vector<int64_t> &seq_lens,
    const std::vector<int64_t> &stop_flags,       // inplace
    const std::vector<int64_t> &end_ids,
    const std::vector<int64_t> &step_ids,
    const std::vector<int64_t> &max_dec_lens,
    const std::vector<int64_t> &block_tables,     // inplace
    const std::vector<int64_t> &cum_scores,       // inplace
    const std::vector<int64_t> &beam_cache_ids,   // inplace
    const std::vector<int64_t> &beam_hyps,        // inplace
    const std::vector<int64_t> &beam_hyps_score,  // inplace
    const std::vector<int64_t> &beam_finished,    // inplace
    const std::vector<int64_t> &beam_width,
    const std::vector<int64_t> &beam_group_num,
    const std::vector<int64_t> &length_penalty,
    const std::vector<int64_t> &diversity_penalty) {
  // One value per logits row; the same shape is used for both outputs.
  const int64_t row_count = logits[0];
  const std::vector<int64_t> one_per_row{row_count, 1};
  return {one_per_row, one_per_row};
}
 | ||
| 
 | ||
// Infer-dtype function for the `beam_search_softmax` custom op.
//
// Both outputs (`next_tokens`, `parent_ids`) are INT32 no matter what the
// input dtypes are. The kernel fills them through `int*` pointers.
// The parameters are present only to mirror the op's registered input list
// and are not inspected.
std::vector<paddle::DataType> BeamSearchSoftmaxDtype(
    const paddle::DataType &logits,
    const paddle::DataType &seq_lens,
    const paddle::DataType &stop_flags,       // inplace
    const paddle::DataType &end_ids,
    const paddle::DataType &step_ids,
    const paddle::DataType &max_dec_lens,
    const paddle::DataType &block_tables,     // inplace
    const paddle::DataType &cum_scores,       // inplace
    const paddle::DataType &beam_cache_ids,   // inplace
    const paddle::DataType &beam_hyps,        // inplace
    const paddle::DataType &beam_hyps_score,  // inplace
    const paddle::DataType &beam_finished,    // inplace
    const paddle::DataType &beam_width,
    const paddle::DataType &beam_group_num,
    const paddle::DataType &length_penalty,
    const paddle::DataType &diversity_penalty) {
  const paddle::DataType output_dtype = paddle::DataType::INT32;
  // Two outputs, same dtype.
  return std::vector<paddle::DataType>(2, output_dtype);
}
 | ||
| 
 | ||
// Registers the `beam_search_softmax` static custom op.
//
// The input names are positional. Their order must match the parameter order
// of BeamSearchSoftmax, BeamSearchSoftmaxShape and BeamSearchSoftmaxDtype.
// Attributes:
//   fuse_softmax - forwarded to the kernel params; it also changes the size
//                  of the packed stage-1 workspace.
//   early_stop   - forwarded to the kernel params.
// NOTE(review): the kernel writes into several inputs through const_cast:
// stop_flags, block_tables, cum_scores, beam_cache_ids, beam_hyps,
// beam_hyps_score and beam_finished. No SetInplaceMap is declared for them
// here. Confirm the framework tolerates this, e.g. no graph-level reuse of
// these input buffers.
PD_BUILD_STATIC_OP(beam_search_softmax)
    .Inputs({"logits",
             "seq_lens",
             "stop_flags",
             "end_ids",
             "step_ids",
             "max_dec_lens",
             "block_tables",
             "cum_scores",
             "beam_cache_ids",
             "beam_hyps",
             "beam_hyps_score",
             "beam_finished",
             "beam_width",
             "beam_group_num",
             "length_penalty",
             "diversity_penalty"})
    .Outputs({"next_tokens", "parent_ids"})
    .Attrs({"fuse_softmax: bool", "early_stop: bool"})
    .SetKernelFn(PD_KERNEL(BeamSearchSoftmax))
    .SetInferShapeFn(PD_INFER_SHAPE(BeamSearchSoftmaxShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(BeamSearchSoftmaxDtype));
 | 
