// FastDeploy/custom_ops/gpu_ops/beam_search_softmax.cu
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include "stdint.h"
#include "helper.h"
#define FLT_MAX 1e38
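//
// Beam-search softmax pipeline (run once per beam group):
//   stage 1: split the vocab into voc_parts chunks; each block finds its
//            chunk's top-2K candidates plus partial softmax (max, sum) stats.
//   stage 2: one block per beam merges all chunk results into the global
//            top-2K log-probabilities (tmp_ids / tmp_vals).
//   batch_topk: per batch, re-rank the candidates with length / diversity
//            penalties, pick the surviving beams, and maintain finished
//            hypotheses.
//   update kernels: gather block tables / cache ids from parent beams and
//            propagate stop flags.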
static constexpr int kBlockSizeForSmallBeamWidth = 256;
static constexpr int kMaxVocabPartForStage1FastKernel = 128;
#define CASE_K(K) \
case K: \
invokeTopKSoftMaxLauncher<T, 2 * K, GROUP>( \
params, beam_group_idx, stream); \
break
#define DISPATCH_COMPUTE_PARTS_K(K) \
case K: \
ComputeVocParts<T, 2 * K>(params); \
break
template <typename T>
struct BeamSearchParams {
// Scalar values (BS = batch_size, BM = beam_width in the shape comments)
int batch_size{0};
int beam_width{0};
int beam_group_size{0};
int beam_group_idx{0};
int vocab_size{0};
int dec_stride{0};
int max_seq_len{0};
int end_ids_len{0};
bool fuse_softmax{true};
bool early_stop{false};
int voc_parts{0};
bool use_fast_kernel{true};
int max_smem_per_block{0};
T *logits{nullptr};
const int *step_ids{nullptr}; // [BS * BM, 1]
const int *seq_lens{nullptr}; // [BS * BM, 1]
const int *max_dec_lens{nullptr};
const int *end_ids{nullptr};
const T *cum_scores{nullptr};
const int *block_tables{nullptr};
const int *beam_cache_ids{nullptr};
const float *length_penalty{nullptr}; // [BS, 1]
const float *diversity_penalty{nullptr}; // [BS, 1]
bool *stop_flags{nullptr}; // [BS, 1]
int *cache_ids_out{nullptr}; // [BS * BM, max_dec_len]
bool *beam_finished{nullptr}; // [BS * BM, 1]
int *block_tables_out{nullptr}; // [BS * BM, max_seq_len]
T *cum_scores_out{nullptr}; // [BS * BM, 1]
int *beam_hyps_out{nullptr}; // [BS * BM, max_dec_len]
T *beam_hyps_score_out{nullptr}; // [BS * BM, 1]
// op outputs
int *next_tokens{nullptr};
int *parent_ids{nullptr};
// workspace
int *tmp_ids{nullptr};
T *tmp_vals{nullptr};
T *tmp_buffer{nullptr};
};
template <typename T,
typename U,
typename = std::enable_if_t<std::is_integral<T>::value>,
typename = std::enable_if_t<std::is_integral<U>::value>>
auto constexpr ceilDiv(T numerator, U denominator) {
return (numerator + denominator - 1) / denominator;
}
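// e.g. ceilDiv(vocab_size, voc_parts) is the per-chunk vocab size used by
// stage 1.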
__device__ bool is_in_end(const int id, const int *end_ids, int length) {
for (int i = 0; i < length; i++) {
if (id == end_ids[i]) {
return true;
}
}
return false;
}
template <typename T>
__device__ __forceinline__ T apply_length_penalty(T log_prob,
int length,
float length_penalty) {
// score = log(prob) / (length)^length_penalty.
if (length_penalty == 0.0f || length == 1) {
return log_prob;
}
return log_prob / static_cast<T>(powf(length, length_penalty));
}
// <<<batch_size, beam_group_size>>>
template <typename T, int K>
__global__ void apply_group_diversity_penalty(BeamSearchParams<T> params,
const int batch_size,
const int beam_width,
const int beam_group_idx,
const int vocab_size) {
const int beam_group_size = K / 2;
const int batch_idx = blockIdx.x;
const int beam_group_sub_idx = threadIdx.x;
const bool *beam_finished = params.beam_finished + batch_idx * beam_width;
T *logits = params.logits + batch_idx * beam_width * vocab_size +
beam_group_idx * beam_group_size * vocab_size +
beam_group_sub_idx * vocab_size;
int *next_tokens = params.next_tokens + batch_idx * beam_width;
// apply previous group token ids penalty
#pragma unroll
for (int token_idx = 0; token_idx < beam_group_idx * beam_group_size;
++token_idx) {
const bool finished = beam_finished[token_idx];
if (!finished) {
const int token_id = next_tokens[token_idx];
logits[token_id] -= params.diversity_penalty[batch_idx];
}
}
}
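// Running (max logit, exp-sum) pair for online softmax: merging two partials
// keeps the larger max and rescales the smaller sum by exp(m_small - m_big),
// so no extra full pass over the logits is needed.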
struct DySoftMaxStruct {
float logit;
float score;
};
__device__ __forceinline__ DySoftMaxStruct
reduce_softmax_op(DySoftMaxStruct a, DySoftMaxStruct b) {
bool a_bigger = (a.logit > b.logit);
DySoftMaxStruct bigger_m = a_bigger ? a : b;
DySoftMaxStruct smaller_m = a_bigger ? b : a;
DySoftMaxStruct res;
res.score =
bigger_m.score + smaller_m.score * expf(smaller_m.logit - bigger_m.logit);
res.logit = bigger_m.logit;
return res;
}
template <typename T>
struct BeamHypothesis {
T score;
int *seq;
int seq_len;
__device__ __forceinline__ void init(int *_seq,
T _score,
const int _max_seq_len) {
seq = _seq;
score = _score;
seq_len = _max_seq_len;
}
};
template <typename T, int K>
struct BeamHypothesesTopK {
BeamHypothesis<T> hyps[K];
int max_dec_len;
__device__ __forceinline__ void init(int *_beam_hyps,
T *_beam_hyps_score,
const int _max_dec_len) {
max_dec_len = _max_dec_len;
for (int i = 0; i < K; i++) {
// Initialize each hypothesis with its sequence buffer, score, and max length
hyps[i].init(
_beam_hyps + i * _max_dec_len, _beam_hyps_score[i], _max_dec_len);
}
}
__device__ void insert(const int *token_ids,
int step,
int cur_token_id,
T score) {
if (score > get_worst_score()) {
for (int i = 0; i < step; i++) {
hyps[K - 1].seq[i] = token_ids[i];
}
hyps[K - 1].seq[step] = cur_token_id;
hyps[K - 1].score = score;
for (int k = K - 2; k >= 0; --k) {
if (hyps[k + 1].score > hyps[k].score) {
T tmp_score = hyps[k].score;
hyps[k].score = hyps[k + 1].score;
hyps[k + 1].score = tmp_score;
int tmp_val;
for (int i = 0;
i <= step && (hyps[k + 1].seq[i] > 0 || hyps[k].seq[i] > 0);
i++) {
tmp_val = hyps[k + 1].seq[i];
hyps[k + 1].seq[i] = hyps[k].seq[i];
hyps[k].seq[i] = tmp_val;
}
}
}
}
}
__device__ __forceinline__ T get_worst_score() { return hyps[K - 1].score; }
};
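// Insertion-sorted top-K of (value, id) pairs, largest value first. Ties are
// broken toward the smaller id; ids[i] == -1 marks an empty slot.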
template <typename T, int K>
struct TopK {
int ids[K];
T vals[K];
int parent_ids[K];
__device__ __forceinline__ void insert(T elem, int elem_id) {
if (elem > vals[K - 1] || (ids[K - 1] == -1) ||
((elem == vals[K - 1]) && (elem_id < ids[K - 1]))) {
vals[K - 1] = elem;
ids[K - 1] = elem_id;
}
for (int k = K - 2; k >= 0; --k) {
if ((vals[k + 1] > vals[k]) || (ids[k] == -1) ||
((vals[k + 1] == vals[k]) && (ids[k + 1] < ids[k]))) {
T tmp_val = vals[k];
int tmp_id = ids[k];
vals[k] = vals[k + 1];
ids[k] = ids[k + 1];
vals[k + 1] = tmp_val;
ids[k + 1] = tmp_id;
}
}
}
__device__ __forceinline__ void insert(T elem, int elem_id, int parent_id) {
if (elem > vals[K - 1] || (ids[K - 1] == -1) ||
((elem == vals[K - 1]) && (elem_id < ids[K - 1]))) {
vals[K - 1] = elem;
ids[K - 1] = elem_id;
parent_ids[K - 1] = parent_id;
}
for (int k = K - 2; k >= 0; --k) {
if ((vals[k + 1] > vals[k]) || (ids[k] == -1) ||
((vals[k + 1] == vals[k]) && (ids[k + 1] < ids[k]))) {
T tmp_val = vals[k];
int tmp_id = ids[k];
int parent_id2 = parent_ids[k];
vals[k] = vals[k + 1];
ids[k] = ids[k + 1];
parent_ids[k] = parent_ids[k + 1];
vals[k + 1] = tmp_val;
ids[k + 1] = tmp_id;
parent_ids[k + 1] = parent_id2;
}
}
}
};
template <typename T, int K>
__device__ __forceinline__ TopK<T, K> reduce_topk_op(const TopK<T, K> &a,
const TopK<T, K> &b) {
TopK<T, K> res = a;
for (int i = 0; i < K; ++i) res.insert(b.vals[i], b.ids[i]);
return res;
}
template <typename T, int K>
struct TopKSoftMax {
DySoftMaxStruct softmax_md;
TopK<T, K> topk;
};
template <typename T, int K>
__device__ __forceinline__ TopKSoftMax<T, K> reduce_topk_softmax_op(
const TopKSoftMax<T, K> &a, const TopKSoftMax<T, K> &b) {
TopKSoftMax<T, K> res;
// max_logit in block
res.softmax_md = reduce_softmax_op(a.softmax_md, b.softmax_md);
res.topk = reduce_topk_op(a.topk, b.topk);
return res;
}
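// (m, d) = (max logit, exp-sum) pair used by the fast-path kernels; the same
// online-softmax merge as DySoftMaxStruct above, aligned for 8-byte loads.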
struct __align__(8) MD {
float m;
float d;
};
__device__ __forceinline__ MD reduce_md_op(MD a, MD b) {
bool const isABigger = a.m > b.m;
MD const bigger = isABigger ? a : b;
MD const smaller = isABigger ? b : a;
MD res{bigger.m, bigger.d + smaller.d * __expf(smaller.m - bigger.m)};
return res;
}
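// <<<(batch_size * beam_group_size, voc_parts), THREADBLOCK_SIZE>>>
// Dynamic smem caches one vocab chunk of logits (vocab_chunk_size * sizeof(T)).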
template <typename T, int K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__
void beam_search_softmax_topk_stage1_fast(const T *logits,
float *tmp_buffer,
const int *end_ids,
const bool *beam_finished,
const int *seq_lens,
int beam_width,
int beam_group_idx,
int vocab_size,
int vocab_chunk_size) {
constexpr int PACKED_TOP_KMD_SIZE = 2 * K + 2;
const int beam_group_size = K / 2;
const int tid = threadIdx.x;
const int group_beam_batch_id = blockIdx.x;
const int batch_id = group_beam_batch_id / beam_group_size;
const int beam_group_sub_id = group_beam_batch_id % beam_group_size;
const int beam_batch_id = batch_id * beam_width +
beam_group_idx * beam_group_size +
beam_group_sub_id;
const int seq_len = seq_lens[beam_batch_id];
const bool finished = beam_finished[beam_batch_id];
if (seq_len < 0 || finished) {
return;
}
const int section_start = vocab_chunk_size * blockIdx.y;
const int section_end =
std::min(section_start + vocab_chunk_size, vocab_size);
const int valid_smem_length = section_end - section_start;
T const MAX_T_VAL = FLT_MAX;
// Load element from logits to smemLogProbs, doing reduce_md and argmax
// meanwhile Each thread is responsible for `vocab_chunk_size /
// THREADBLOCK_SIZE` elements
extern __shared__ char smem[];
T *smemLogProbs = reinterpret_cast<T *>(smem);
MD partial_md{-MAX_T_VAL, 0.0f};
using KVPair = cub::KeyValuePair<int, T>;
KVPair topKVPairPartial{vocab_size - 1, -MAX_T_VAL};
cub::ArgMax argmax;
T const *local_logits = logits + beam_batch_id * vocab_size;
#pragma unroll 1
for (int i = section_start + tid; i < section_end; i += THREADBLOCK_SIZE) {
T const val = local_logits[i];
const int smem_index = i - section_start;
smemLogProbs[smem_index] = val;
MD new_elem_md{val, 1.0F};
partial_md = reduce_md_op(partial_md, new_elem_md);
KVPair new_elem_topk{smem_index, val};
topKVPairPartial = argmax(topKVPairPartial, new_elem_topk);
}
__syncthreads();
// Search the top 2K elements among `vocab_chunk_size` elements of this
// ThreadBlock and write into smemOutput
__shared__ float smemOutput[PACKED_TOP_KMD_SIZE];
__shared__ int threadToUpdate;
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
using BlockReduceTopK = cub::BlockReduce<KVPair, THREADBLOCK_SIZE>;
__shared__ union {
typename BlockReduceTopK::TempStorage topk;
typename BlockReduceMD::TempStorage md;
} smemReduceBuffer;
for (int i = 0; i < 2 * beam_group_size; ++i) {
// Pop the element with largest value to "smemOutput" per iteration
KVPair topKVPair =
BlockReduceTopK(smemReduceBuffer.topk).Reduce(topKVPairPartial, argmax);
if (tid == 0) {
const int index = section_start + topKVPair.key;
reinterpret_cast<int *>(smemOutput)[i] = index;
smemOutput[K + i] = topKVPair.value;
smemLogProbs[topKVPair.key] =
-MAX_T_VAL; // pollute the value of the popped element
threadToUpdate = topKVPair.key % THREADBLOCK_SIZE;
}
__syncthreads();
if (tid == threadToUpdate && i < 2 * beam_group_size - 1) {
// The thread that popped the element needs to refresh its topKVPairPartial;
// this is unnecessary in the last iteration
topKVPairPartial.key = vocab_size - 1;
topKVPairPartial.value = -MAX_T_VAL;
for (int index = tid; index < valid_smem_length;
index += THREADBLOCK_SIZE) {
topKVPairPartial =
argmax(topKVPairPartial, {index, smemLogProbs[index]});
}
}
}
// Do reduce_md among the top 2K elements in the smemOutput and write into
// tail of smemOutput
auto reduce_md_func = [](const MD &a, const MD &b) {
return reduce_md_op(a, b);
};
MD total_md =
BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_func);
if (tid == 0) {
smemOutput[2 * K] = total_md.d;
smemOutput[2 * K + 1] = total_md.m;
}
__syncthreads();
// Write the smemOutput into tmp_buffer
float *local_temp_buffer =
tmp_buffer + group_beam_batch_id * PACKED_TOP_KMD_SIZE * gridDim.y +
blockIdx.y * PACKED_TOP_KMD_SIZE;
#pragma unroll
for (int i = tid; i < PACKED_TOP_KMD_SIZE; i += THREADBLOCK_SIZE) {
local_temp_buffer[i] = smemOutput[i];
}
}
// <<<(batch_size * beam_group_size, voc_parts), 128>>>
template <typename T, int K, int THREADBLOCK_SIZE, int PACKED_TOP_KMD_SIZE>
__global__ void beam_search_softmax_topk_stage1(BeamSearchParams<T> params,
const int beam_width,
const int beam_group_idx,
const int vocab_size,
const bool fuse_softmax) {
const int thread_id = threadIdx.x;
const int beam_group_size = K / 2;
const int batch_id = blockIdx.x / beam_group_size;
const int beam_group_sub_idx = blockIdx.x % beam_group_size;
const int beam_batch_id = batch_id * beam_width +
beam_group_idx * beam_group_size +
beam_group_sub_idx;
const bool finish = params.beam_finished[beam_batch_id];
const int seq_len = params.seq_lens[beam_batch_id];
// for dybatch
if (seq_len < 0 || finish) {
return;
}
// 2 * K + 2
__shared__ float buf_s[PACKED_TOP_KMD_SIZE];
const T MAX_T_VAL = FLT_MAX;
const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y;
const int section_start = v_local * blockIdx.y;
int section_end = section_start + v_local;
section_end = (section_end > vocab_size) ? vocab_size : section_end;
T *logits = params.logits + beam_batch_id * vocab_size;
if (fuse_softmax) {
typedef cub::BlockReduce<TopKSoftMax<T, K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopKSoftMax<T, K> partial;
for (int i = 0; i < K; ++i) {
partial.topk.ids[i] = -1;
partial.topk.vals[i] = -MAX_T_VAL;
}
partial.softmax_md.logit = -MAX_T_VAL;
partial.softmax_md.score = 0.0F;
// process voc_parts
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end;
elem_id += THREADBLOCK_SIZE) {
T elem = logits[elem_id];
DySoftMaxStruct new_elem{elem, 1.0F};
partial.softmax_md = reduce_softmax_op(partial.softmax_md, new_elem);
partial.topk.insert(elem, elem_id);
}
// reduce voc_parts
TopKSoftMax<T, K> total =
BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op<T, K>);
if (thread_id == 0) {
for (int i = 0; i < K; i++) {
reinterpret_cast<int *>(buf_s)[i] = total.topk.ids[i];
buf_s[K + i] = total.topk.vals[i];
}
buf_s[2 * K] = total.softmax_md.score;
buf_s[2 * K + 1] = total.softmax_md.logit;
}
} else {
typedef cub::BlockReduce<TopK<T, K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopK<T, K> partial;
for (int i = 0; i < K; ++i) {
partial.ids[i] = -1;
partial.vals[i] = -MAX_T_VAL;
}
#pragma unroll 1
for (int elem_id = section_start + thread_id; elem_id < section_end;
elem_id += THREADBLOCK_SIZE) {
T elem = logits[elem_id];
partial.insert(elem, elem_id);
}
TopK<T, K> total =
BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, K>);
if (thread_id == 0) {
for (int i = 0; i < K; i++) {
reinterpret_cast<int *>(buf_s)[i] = total.ids[i];
buf_s[K + i] = total.vals[i];
}
}
}
__syncthreads();
// write all the voc_parts results to tmp_buffer
for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE;
elem_id += THREADBLOCK_SIZE) {
params.tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y +
blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] =
buf_s[elem_id];
}
}
template <typename T, int K, int THREADBLOCK_SIZE, bool IS_FAST_KERNEL>
__launch_bounds__(THREADBLOCK_SIZE) __global__
void beam_search_softmax_topk_stage2_fast(
int *__restrict tmp_ids,
T *__restrict tmp_vals,
float *__restrict tmp_buffer,
const float *__restrict cum_scores,
const bool *__restrict beam_finished,
const int *__restrict seq_lens,
const int beam_width,
const int beam_group_idx,
const int vocab_size,
const int voc_parts) {
constexpr int PACKED_TOP_KMD_SIZE = 2 * K + 2;
constexpr int beam_group_size = K / 2;
const int group_beam_batch_id = blockIdx.x;
const int beam_group_sub_id = blockIdx.x % beam_group_size;
const int batch_id = group_beam_batch_id / beam_group_size;
const int beam_batch_id = batch_id * beam_width +
beam_group_idx * beam_group_size +
beam_group_sub_id;
if (seq_lens[beam_batch_id] < 0 || beam_finished[beam_batch_id]) {
return;
}
const int tid = threadIdx.x;
T const MAX_T_VAL = FLT_MAX;
using KVPair = cub::KeyValuePair<int, T>;
using BlockReduceTopK = cub::BlockReduce<KVPair, THREADBLOCK_SIZE>;
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
__shared__ KVPair buf_smem_kv[K];
__shared__ union {
typename BlockReduceTopK::TempStorage topk;
typename BlockReduceMD::TempStorage md;
} smemReduceBuffer;
cub::ArgMax argmax;
MD partial_md{-MAX_T_VAL, 0.0f};
KVPair topKVPair{vocab_size - 1, -MAX_T_VAL};
auto reduce_md_func = [](const MD &a, const MD &b) {
return reduce_md_op(a, b);
};
// Load and unpack into registers through smem
float *localTempBuffer =
tmp_buffer + PACKED_TOP_KMD_SIZE * group_beam_batch_id * voc_parts;
if constexpr (IS_FAST_KERNEL) { // Use shared memory instead of global memory
extern __shared__ char smem[];
float *smemVal = reinterpret_cast<float *>(smem);
for (int idx = tid; idx < PACKED_TOP_KMD_SIZE * voc_parts;
idx += THREADBLOCK_SIZE) {
smemVal[idx] = localTempBuffer[idx];
}
localTempBuffer = smemVal;
__syncthreads();
}
// Find the top 2K across all voc_parts
for (int k = 0; k < K; ++k) {
KVPair topKVPairPartial{vocab_size - 1, -MAX_T_VAL};
// Only threads responsible for a chunk will do the computation
if (tid < voc_parts) {
for (int i = 0; i < K; ++i) {
const int current_index = tid * PACKED_TOP_KMD_SIZE + i;
T topValue = localTempBuffer[current_index + K];
topKVPairPartial = argmax(topKVPairPartial, {current_index, topValue});
}
}
KVPair topKVPair =
BlockReduceTopK(smemReduceBuffer.topk).Reduce(topKVPairPartial, argmax);
__syncthreads();
if (tid == 0) {
// Store kv pairs in shared mem buffer
int temp_offset = topKVPair.key;
int global_offset = reinterpret_cast<int *>(localTempBuffer)[temp_offset];
topKVPair.key = global_offset;
buf_smem_kv[k] = topKVPair;
// Invalidate the maximum value within the chunk
reinterpret_cast<int *>(localTempBuffer)[temp_offset] =
vocab_size - 1; // id in shared memory
localTempBuffer[temp_offset + K] = -MAX_T_VAL; // value in shared memory
}
__syncthreads();
}
// Extract and reduce MD values across the chunks
if (tid < voc_parts) {
partial_md.d = localTempBuffer[tid * PACKED_TOP_KMD_SIZE + 2 * K];
partial_md.m = localTempBuffer[tid * PACKED_TOP_KMD_SIZE + 2 * K + 1];
}
__syncthreads();
MD total_md =
BlockReduceMD(smemReduceBuffer.md).Reduce(partial_md, reduce_md_func);
if (tid == 0) {
float d_total_log = logf(total_md.d);
for (int i = 0; i < K; ++i) {
float val =
static_cast<float>(buf_smem_kv[i].value) - total_md.m - d_total_log;
tmp_ids[group_beam_batch_id * K + i] = buf_smem_kv[i].key;
tmp_vals[group_beam_batch_id * K + i] = val + cum_scores[beam_batch_id];
}
}
}
#define BEAM_STAGE2_KERNEL(N_VOCAB_PART, IS_FAST_KERNEL) \
do { \
if (IS_FAST_KERNEL && nShareMemory >= (48 << 10)) { \
cudaFuncSetAttribute( \
beam_search_softmax_topk_stage2_fast<T, \
K, \
N_VOCAB_PART, \
IS_FAST_KERNEL>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
nShareMemory); \
} \
beam_search_softmax_topk_stage2_fast<T, K, N_VOCAB_PART, IS_FAST_KERNEL> \
<<<batch_size * beam_group_size, \
N_VOCAB_PART, \
IS_FAST_KERNEL * nShareMemory, \
stream>>>(params->tmp_ids, \
params->tmp_vals, \
params->tmp_buffer, \
params->cum_scores, \
params->beam_finished, \
params->seq_lens, \
beam_width, \
beam_group_idx, \
vocab_size, \
voc_parts); \
} while (0); \
return;
template <typename T, int K>
__inline__ void beamSearchSoftmaxTopkStage2FastKernelLauncher(
BeamSearchParams<T> *params,
const int batch_size,
const int beam_width,
const int beam_group_idx,
const int vocab_size,
const int voc_parts,
const int max_smem_per_block,
cudaStream_t stream) {
constexpr int beam_group_size = K / 2;
size_t const nShareMemory = sizeof(float) * voc_parts * (2 * K + 2) +
sizeof(cub::KeyValuePair<int, T>) * K;
if (nShareMemory < max_smem_per_block) { // IS_FAST_KERNEL must be a
// compilation-time constant
if (voc_parts <= 32) {
BEAM_STAGE2_KERNEL(32, true)
}
if (voc_parts <= 64) {
BEAM_STAGE2_KERNEL(64, true)
}
BEAM_STAGE2_KERNEL(128, true)
// No larger branch since voc_parts <= kMaxVocabPartForStage1FastKernel
}
BEAM_STAGE2_KERNEL(128, false)
}
template <typename T, int K, int THREADBLOCK_SIZE>
__global__ void beam_search_softmax_topk_stage2(BeamSearchParams<T> params,
const int beam_width,
const int beam_group_idx,
const int voc_parts,
const int packed_top_kmd_size,
const bool fuse_softmax) {
const int thread_id = threadIdx.x;
const int beam_group_size = K / 2;
const int batch_id = blockIdx.x / beam_group_size;
const int beam_group_sub_idx = blockIdx.x % beam_group_size;
const int beam_batch_id = batch_id * beam_width +
beam_group_idx * beam_group_size +
beam_group_sub_idx;
const int group_beam_batch_id = blockIdx.x;
const int PACKED_TOP_KMD_SIZE = packed_top_kmd_size;
// for dybatch
const int seq_len = params.seq_lens[beam_batch_id];
const bool finish = params.beam_finished[beam_batch_id];
int *tmp_ids = params.tmp_ids + group_beam_batch_id * K;
float *tmp_vals = params.tmp_vals + group_beam_batch_id * K;
float *tmp_buffer = params.tmp_buffer;
const T *cum_scores = params.cum_scores + beam_batch_id;
if (seq_len < 0 || finish) {
return;
}
const T MAX_T_VAL = FLT_MAX;
extern __shared__ char buf_s_[];
float *buf_s = reinterpret_cast<float *>(buf_s_);
// All voc_parts partial results of the current batch-beam
tmp_buffer += group_beam_batch_id * PACKED_TOP_KMD_SIZE * voc_parts;
if (fuse_softmax) {
typedef cub::BlockReduce<TopKSoftMax<T, K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopKSoftMax<T, K> partial;
for (int i = 0; i < K; ++i) {
partial.topk.ids[i] = -1;
partial.topk.vals[i] = -MAX_T_VAL;
}
partial.softmax_md.logit = -MAX_T_VAL;
partial.softmax_md.score = 0.0F;
for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts;
idx += THREADBLOCK_SIZE) {
buf_s[idx] = tmp_buffer[idx];
}
__syncthreads();
if (threadIdx.x < voc_parts) {
float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
for (int i = 0; i < K; i++) {
partial.topk.ids[i] = reinterpret_cast<int *>(b_s)[i];
partial.topk.vals[i] = b_s[K + i];
}
partial.softmax_md.score = b_s[2 * K];
partial.softmax_md.logit = b_s[2 * K + 1];
}
__syncthreads();
TopKSoftMax<T, K> total =
BlockReduce(temp_storage).Reduce(partial, reduce_topk_softmax_op<T, K>);
if (thread_id == 0) {
float d_total_log = logf(total.softmax_md.score);
for (int i = 0; i < K; ++i) {
// Keep scores in log space (no expf) to avoid underflow.
float val = total.topk.vals[i] - total.softmax_md.logit - d_total_log;
tmp_ids[i] = total.topk.ids[i];
tmp_vals[i] = val + params.cum_scores[beam_batch_id];
}
}
} else {
typedef cub::BlockReduce<TopK<T, K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
TopK<T, K> partial;
for (int i = 0; i < K; ++i) {
partial.ids[i] = -1;
partial.vals[i] = -MAX_T_VAL;
}
for (int idx = thread_id; idx < PACKED_TOP_KMD_SIZE * voc_parts;
idx += THREADBLOCK_SIZE) {
buf_s[idx] = tmp_buffer[idx];
}
__syncthreads();
if (threadIdx.x < voc_parts) {
float *b_s = buf_s + thread_id * PACKED_TOP_KMD_SIZE;
for (int i = 0; i < K; i++) {
partial.ids[i] = reinterpret_cast<int *>(b_s)[i];
partial.vals[i] = b_s[K + i];
}
}
__syncthreads();
TopK<T, K> total =
BlockReduce(temp_storage).Reduce(partial, reduce_topk_op<T, K>);
if (thread_id == 0) {
// tmp_ids / tmp_vals are already offset by group_beam_batch_id * K above;
// offsetting again here would write out of bounds.
for (int i = 0; i < K; ++i) {
float val = total.vals[i];
tmp_ids[i] = total.ids[i];
tmp_vals[i] = val + params.cum_scores[beam_batch_id];
}
}
}
}
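// Generic stage-2 dispatch, selecting a block size that covers voc_parts.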
template <typename T, int K>
void invokeBeamSearchSoftmaxTopKStage2(BeamSearchParams<T> *params,
const int batch_size,
const int beam_width,
const int beam_group_idx,
const int voc_parts,
const int packed_top_kmd_size,
const bool fuse_softmax,
gpuStream_t stream) {
int smem_stage2_size = voc_parts * packed_top_kmd_size * sizeof(float);
const int beam_group_size = K / 2;
if (voc_parts <= 32) {
beam_search_softmax_topk_stage2<T, K, 32>
<<<batch_size * beam_group_size, 32, smem_stage2_size, stream>>>(
*params,
beam_width,
beam_group_idx,
voc_parts,
packed_top_kmd_size,
fuse_softmax);
return;
}
if (voc_parts <= 64) {
beam_search_softmax_topk_stage2<T, K, 64>
<<<batch_size * beam_group_size, 64, smem_stage2_size, stream>>>(
*params,
beam_width,
beam_group_idx,
voc_parts,
packed_top_kmd_size,
fuse_softmax);
return;
}
if (voc_parts <= 128) {
beam_search_softmax_topk_stage2<T, K, 128>
<<<batch_size * beam_group_size, 128, smem_stage2_size, stream>>>(
*params,
beam_width,
beam_group_idx,
voc_parts,
packed_top_kmd_size,
fuse_softmax);
return;
}
if (voc_parts <= 256) {
beam_search_softmax_topk_stage2<T, K, 256>
<<<batch_size * beam_group_size, 256, smem_stage2_size, stream>>>(
*params,
beam_width,
beam_group_idx,
voc_parts,
packed_top_kmd_size,
fuse_softmax);
return;
}
}
template <typename T, int K>
__global__ void update_beam_finished_early_stop(const T *beam_hyps_score_out,
bool *beam_finished) {
if (threadIdx.x == 0) {
int batch_idx = blockIdx.x;
const T *cur_beam_hyps_score = beam_hyps_score_out + batch_idx * K;
bool *cur_beam_finished = beam_finished + batch_idx * K;
if (cur_beam_hyps_score[K - 1] > -1e8) {
for (int i = 0; i < K; i++) {
cur_beam_finished[i] = true;
}
}
}
}
// <<<batch_size, 32>>>
template <typename T, int K, int THREADBLOCK_SIZE, bool GROUP>
__global__ void batch_topk(BeamSearchParams<T> params,
const int beam_width,
const int beam_group_idx,
const int dec_stride) {
const bool early_stop = params.early_stop;
const int thread_id = threadIdx.x;
const int batch_id = blockIdx.x;
// int block_id = blockIdx.x; // bs
const int beam_group_size = K / 2;
const int beam_group_start_id =
batch_id * beam_width + beam_group_idx * beam_group_size;
bool *beam_finished = params.beam_finished + beam_group_start_id;
const int *step_ids = params.step_ids + beam_group_start_id;
int *next_tokens = params.next_tokens + beam_group_start_id;
float *cum_scores_out = params.cum_scores_out + beam_group_start_id;
int *parent_ids = params.parent_ids + beam_group_start_id;
float *beam_hyps_score_out = params.beam_hyps_score_out + beam_group_start_id;
const bool finish = beam_finished[0];
const int step_id = step_ids[0];
const int seq_len = params.seq_lens[beam_group_start_id];
const int max_dec_len = params.max_dec_lens[beam_group_start_id];
const bool last_dec_step = (step_id + 1 == max_dec_len);
if (thread_id == 0 && seq_len > 0 && !finish) {
TopK<T, K> partial;
BeamHypothesesTopK<T, K / 2> beam_hyps;
beam_hyps.init(params.beam_hyps_out + beam_group_start_id * dec_stride,
params.beam_hyps_score_out + beam_group_start_id,
dec_stride);
for (int i = 0; i < K; ++i) {
partial.ids[i] = -1;
partial.vals[i] = -FLT_MAX;
partial.parent_ids[i] = -1;
}
int index = batch_id * beam_group_size * K;
if (step_id == 0) {
for (int i = 0; i < K; i++) {
float score_now = apply_length_penalty(params.tmp_vals[index + i],
step_id + 1,
params.length_penalty[batch_id]);
if (!GROUP) {
score_now -=
params.diversity_penalty[batch_id] * static_cast<float>(i + 1);
}
partial.insert((T)score_now, params.tmp_ids[index + i], i / K);
}
} else {
for (int i = 0; i < beam_group_size * K; i++) {
float score_now = apply_length_penalty(params.tmp_vals[index + i],
step_id + 1,
params.length_penalty[batch_id]);
if (!GROUP) {
score_now -= params.diversity_penalty[batch_id] *
static_cast<float>(i % K + 1);
}
partial.insert((T)score_now, params.tmp_ids[index + i], i / K);
}
}
if (partial.vals[0] < beam_hyps.hyps[beam_group_size - 1].score) {
for (int i = 0; i < beam_group_size; i++) {
beam_finished[i] = true;
}
return;
}
int next_step_num = 0;
for (int i = 0; i < K && next_step_num < beam_group_size; i++) {
int parent_id = partial.parent_ids[i];
if (is_in_end(partial.ids[i], params.end_ids, params.end_ids_len) ||
last_dec_step) {
if (i < beam_group_size &&
partial.vals[i] > beam_hyps.get_worst_score()) {
const int *beam_cache_id = params.beam_cache_ids +
beam_group_start_id * dec_stride +
parent_id * dec_stride;
beam_hyps.insert(beam_cache_id,
step_id,
last_dec_step ? params.end_ids[0] : partial.ids[i],
partial.vals[i]);
}
if (early_stop && beam_hyps.get_worst_score() > -1e8) {
// stop
for (int j = 0; j < beam_group_size; j++) {
beam_finished[j] = true;
}
return;
}
} else {
next_tokens[next_step_num] = partial.ids[i];
cum_scores_out[next_step_num] = partial.vals[i];
parent_ids[next_step_num] = parent_id;
next_step_num += 1;
}
}
for (int i = 0; i < beam_group_size; i++) {
beam_hyps_score_out[i] = beam_hyps.hyps[i].score;
}
if (last_dec_step) {
for (int i = 0; i < beam_group_size; i++) {
beam_finished[i] = true;
}
}
} // if (thread_id == 0)
}
template <typename T, int K, bool GROUP>
void invokeTopKSoftMaxLauncher(BeamSearchParams<T> *params,
int beam_group_idx,
gpuStream_t stream) {
const int batch_size = params->batch_size;
const int beam_width = params->beam_width;
const int beam_group_size = K / 2;
const int vocab_size = params->vocab_size;
const bool fuse_softmax = params->fuse_softmax;
const int voc_parts = params->voc_parts;
// only in group_beam_search
if (beam_width > beam_group_size && beam_group_idx != 0) {
apply_group_diversity_penalty<T, K>
<<<batch_size, beam_group_size, 0, stream>>>(
*params, batch_size, beam_width, beam_group_idx, vocab_size);
}
// == Step1 == : stage1
if (params->use_fast_kernel) {
constexpr int block_size =
(K < 16) ? ((K < 8) ? kBlockSizeForSmallBeamWidth : 128) : 64;
const int vocab_chunk_size = (vocab_size + voc_parts - 1) / voc_parts;
const int dyn_smem_size = sizeof(T) * vocab_chunk_size;
if (dyn_smem_size >= (48 << 10)) {
cudaFuncSetAttribute(
beam_search_softmax_topk_stage1_fast<T, K, block_size>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
dyn_smem_size);
}
dim3 grid(batch_size * beam_group_size, voc_parts);
beam_search_softmax_topk_stage1_fast<T, K, block_size>
<<<grid, block_size, dyn_smem_size, stream>>>(params->logits,
params->tmp_buffer,
params->end_ids,
params->beam_finished,
params->seq_lens,
beam_width,
beam_group_idx,
vocab_size,
vocab_chunk_size);
} else {
constexpr int block_size = 128;
dim3 grid(batch_size * beam_group_size, voc_parts);
if (fuse_softmax) {
#ifdef PADDLE_WITH_CUDA
cudaFuncSetAttribute(
beam_search_softmax_topk_stage1<float, K, block_size, 2 * K + 2>,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxL1);
#else
// cudaSharedmemCarveoutMaxL1 equals 0
hipFuncSetAttribute(
reinterpret_cast<void *>(
beam_search_softmax_topk_stage1<float, K, block_size, 2 * K + 2>),
hipFuncAttributePreferredSharedMemoryCarveout,
0);
#endif
// bs, bm, voc_parts, 2 * K + 2
beam_search_softmax_topk_stage1<float, K, block_size, 2 * K + 2>
<<<grid, block_size, 0, stream>>>(
*params, beam_width, beam_group_idx, vocab_size, fuse_softmax);
} else {
#ifdef PADDLE_WITH_CUDA
cudaFuncSetAttribute(
beam_search_softmax_topk_stage1<float, K, block_size, 2 * K>,
cudaFuncAttributePreferredSharedMemoryCarveout,
cudaSharedmemCarveoutMaxL1);
#else
// cudaSharedmemCarveoutMaxL1 equals 0
hipFuncSetAttribute(
reinterpret_cast<void *>(
beam_search_softmax_topk_stage1<float, K, block_size, 2 * K>),
hipFuncAttributePreferredSharedMemoryCarveout,
0);
#endif
// bs, bm, voc_parts, 2 * K
beam_search_softmax_topk_stage1<float, K, block_size, 2 * K>
<<<grid, block_size, 0, stream>>>(
*params, beam_width, beam_group_idx, vocab_size, fuse_softmax);
}
}
beamSearchSoftmaxTopkStage2FastKernelLauncher<float, K>(
params,
batch_size,
beam_width,
beam_group_idx,
vocab_size,
voc_parts,
params->max_smem_per_block,
stream);
batch_topk<T, K, 32, GROUP><<<batch_size, 32, 0, stream>>>(
*params, beam_width, beam_group_idx, params->dec_stride);
}
template <typename T, bool GROUP>
void invokeTopkSoftMax(BeamSearchParams<T> *params,
int beam_group_idx,
gpuStream_t stream) {
switch (params->beam_group_size) {
CASE_K(1);
CASE_K(2);
CASE_K(3);
CASE_K(4);
CASE_K(5);
CASE_K(6);
CASE_K(7);
CASE_K(8);
CASE_K(9);
CASE_K(10);
CASE_K(11);
CASE_K(12);
CASE_K(13);
CASE_K(14);
CASE_K(15);
CASE_K(16);
}
}
template <typename T, int K>
void ComputeVocParts(BeamSearchParams<T> *params) {
int dev_id = 0;
const int block_size =
(K < 16) ? ((K < 8) ? kBlockSizeForSmallBeamWidth : 128) : 64;
int max_active_blocks = -1;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
beam_search_softmax_topk_stage1_fast<float, K, block_size>,
block_size,
0);
int max_smem_per_sm = -1;
int max_smem_per_block = -1;
cudaDeviceGetAttribute(
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id);
cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_id);
cudaFuncAttributes attr;
cudaFuncGetAttributes(
&attr, beam_search_softmax_topk_stage1_fast<float, K, block_size>);
const int static_smem = attr.sharedSizeBytes;
const int max_dyn_smem_per_block = max_smem_per_block - static_smem;
// If the logits cannot fit in shared memory even with the maximum number
// of chunks, the voc_parts check below falls back to the generic kernel.
const int driver_smem_per_block = max_smem_per_sm - max_smem_per_block;
const int extra_smem = driver_smem_per_block + static_smem;
int voc_parts = kMaxVocabPartForStage1FastKernel + 1;
for (int n_block = max_active_blocks - 1;
n_block > 0 && voc_parts > kMaxVocabPartForStage1FastKernel;
--n_block) {
int dyn_smem_size = max_smem_per_sm / n_block - extra_smem;
dyn_smem_size -= dyn_smem_size % sizeof(T);
voc_parts = ceilDiv(sizeof(T) * params->vocab_size, dyn_smem_size);
}
if (!params->fuse_softmax || voc_parts > kMaxVocabPartForStage1FastKernel) {
params->use_fast_kernel = false;
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
const int max_act_blocks_per_sm = 4;
const int max_act_blocks_per_wave = sm_count * max_act_blocks_per_sm;
const int gridx = params->batch_size * K / 2;
const int max_part_num = (max_act_blocks_per_wave + gridx - 1) / gridx;
voc_parts = min(128, max_part_num);
}
params->voc_parts = voc_parts;
params->max_smem_per_block = max_smem_per_block;
}
template <typename T>
void DispatchComputeVocParts(BeamSearchParams<T> *params) {
switch (params->beam_group_size) {
DISPATCH_COMPUTE_PARTS_K(1);
DISPATCH_COMPUTE_PARTS_K(2);
DISPATCH_COMPUTE_PARTS_K(3);
DISPATCH_COMPUTE_PARTS_K(4);
DISPATCH_COMPUTE_PARTS_K(5);
DISPATCH_COMPUTE_PARTS_K(6);
DISPATCH_COMPUTE_PARTS_K(7);
DISPATCH_COMPUTE_PARTS_K(8);
DISPATCH_COMPUTE_PARTS_K(9);
DISPATCH_COMPUTE_PARTS_K(10);
DISPATCH_COMPUTE_PARTS_K(11);
DISPATCH_COMPUTE_PARTS_K(12);
DISPATCH_COMPUTE_PARTS_K(13);
DISPATCH_COMPUTE_PARTS_K(14);
DISPATCH_COMPUTE_PARTS_K(15);
DISPATCH_COMPUTE_PARTS_K(16);
}
}
template <typename T>
__global__ void update_beam_search_params_kernel(BeamSearchParams<T> params) {
int bb_id = blockIdx.y;
int time_step = threadIdx.x + blockIdx.x * blockDim.x;
if (bb_id >= params.beam_width * params.batch_size) {
return;
}
const bool finished = params.beam_finished[bb_id];
const int seq_len = params.seq_lens[bb_id];
if (finished || seq_len < 0) {
return;
}
const int beam_group_size = params.beam_group_size;
const int max_seq_len = params.max_seq_len;
const int dec_stride = params.dec_stride;
const int batch_group_id = bb_id / beam_group_size;
const int max_dec_len = params.max_dec_lens[bb_id];
const int src_beam = params.parent_ids[bb_id];
const int step = params.step_ids[bb_id];
const int *block_tables = params.block_tables;
int *block_tables_out = params.block_tables_out;
const int *cache_ids = params.beam_cache_ids;
int *cache_ids_out = params.cache_ids_out;
const int *next_tokens = params.next_tokens;
const int beam_group_sub_id = bb_id % beam_group_size;
if (time_step < min(max_seq_len, seq_len + 1)) {
const unsigned int block_tables_tgt_offset =
batch_group_id * beam_group_size * max_seq_len +
beam_group_sub_id * max_seq_len + time_step;
const unsigned int block_tables_src_offset =
batch_group_id * beam_group_size * max_seq_len +
src_beam * max_seq_len + time_step;
block_tables_out[block_tables_tgt_offset] =
block_tables[block_tables_src_offset];
if (time_step < min(step + 1, max_dec_len)) {
const unsigned int cache_ids_tgt_offset =
batch_group_id * beam_group_size * dec_stride +
beam_group_sub_id * dec_stride + time_step;
const unsigned int cache_ids_src_offset =
batch_group_id * beam_group_size * dec_stride +
src_beam * dec_stride + time_step;
cache_ids_out[cache_ids_tgt_offset] =
(time_step == step) ? next_tokens[bb_id]
: cache_ids[cache_ids_src_offset];
}
}
}
template <typename T>
__global__ void update_stop_flags(BeamSearchParams<T> params) {
int bid = blockIdx.x;
const int beam_width = params.beam_width;
const int beam_group_size = params.beam_group_size;
const bool *beam_finished = params.beam_finished + beam_width * bid;
bool *stop_flags = params.stop_flags + beam_width * bid;
bool finished = true;
if (threadIdx.x == 0 && !stop_flags[0]) {
#pragma unroll
for (int i = 0; i < beam_width; i += beam_group_size) {
finished &= beam_finished[i];
}
if (finished) {
#pragma unroll
for (int i = 0; i < beam_width; i++) {
stop_flags[i] = true;
}
}
}
}
template <typename T>
void updateBeamSearchParams(BeamSearchParams<T> *params, cudaStream_t stream) {
const dim3 block(32);
const dim3 grid((params->max_seq_len + block.x - 1) / block.x,
params->batch_size * params->beam_width);
update_beam_search_params_kernel<<<grid, block, 0, stream>>>(*params);
const dim3 grid_2(params->batch_size);
update_stop_flags<<<grid_2, 1, 0, stream>>>(*params);
}
/*****
To fit the 5.2 model structure without adding a while op and without
hurting speed, a "fake inplace" method is used here: the inplace inputs
are first copied, then overwritten with this step's results. Not elegant,
but it works.
*****/
std::vector<paddle::Tensor> BeamSearchSoftmax(const paddle::Tensor &logits,
const paddle::Tensor &seq_lens,
const paddle::Tensor &stop_flags, // inplace
const paddle::Tensor &end_ids,
const paddle::Tensor &step_ids,
const paddle::Tensor &max_dec_lens,
const paddle::Tensor &block_tables, // inplace
const paddle::Tensor &cum_scores, // inplace
const paddle::Tensor &beam_cache_ids, // inplace
const paddle::Tensor &beam_hyps, // inplace
const paddle::Tensor &beam_hyps_score, // inplace
const paddle::Tensor &beam_finished, // inplace
const paddle::Tensor &beam_width,
const paddle::Tensor &beam_group_num,
const paddle::Tensor &length_penalty,
const paddle::Tensor &diversity_penalty,
bool fuse_softmax,
bool early_stop) {
std::vector<int64_t> logits_shape = logits.shape();
// logits_shape
auto cu_stream = logits.stream();
int beam_width_scalar;
cudaMemcpyAsync(&beam_width_scalar,
beam_width.data<int>(),
sizeof(int),
cudaMemcpyDeviceToHost,
cu_stream);
int beam_group_num_scalar;
cudaMemcpyAsync(&beam_group_num_scalar,
beam_group_num.data<int>(),
sizeof(int),
cudaMemcpyDeviceToHost,
cu_stream);
// The scalars are read on the host right below, so wait for the copies.
cudaStreamSynchronize(cu_stream);
int beam_batch_size = logits_shape[0];
int batch_size = beam_batch_size / beam_width_scalar;
int vocab_size = logits_shape[1];
const int max_seq_len = block_tables.dims()[1];
// liuzichang: in some cases the tensor length is larger than max_dec_lens
const int dec_stride = beam_hyps.dims()[1];
const int end_ids_len = end_ids.dims()[0];
const int beam_group_size = beam_width_scalar / beam_group_num_scalar;
auto next_tokens = paddle::full({logits_shape[0], 1}, 0, end_ids.type(),
paddle::GPUPlace());
auto parent_ids = paddle::full({logits_shape[0], 1}, 0, end_ids.type(),
paddle::GPUPlace());
auto cum_scores_ori = paddle::empty(cum_scores.shape(), logits.type(),
paddle::GPUPlace());
auto beam_cache_ids_ori = paddle::empty(beam_cache_ids.shape(), end_ids.type(),
paddle::GPUPlace());
auto block_tables_ori = paddle::empty(block_tables.shape(), end_ids.type(),
paddle::GPUPlace());
cudaMemcpyAsync(cum_scores_ori.mutable_data<float>(),
cum_scores.data<float>(),
sizeof(float)*cum_scores.numel(),
cudaMemcpyDeviceToDevice,
cu_stream);
cudaMemcpyAsync(beam_cache_ids_ori.mutable_data<int>(),
beam_cache_ids.data<int>(),
sizeof(int)*beam_cache_ids.numel(),
cudaMemcpyDeviceToDevice,
cu_stream);
cudaMemcpyAsync(block_tables_ori.mutable_data<int>(),
block_tables.data<int>(),
sizeof(int)*block_tables.numel(),
cudaMemcpyDeviceToDevice,
cu_stream);
// tmp_ids / tmp_vals / tmp_buffer all live in the single workspace tensor
// allocated below, once voc_parts is known.
BeamSearchParams<float> params;
params.batch_size = batch_size;
params.beam_width = beam_width_scalar;
params.beam_group_size = beam_group_size;
params.vocab_size = vocab_size;
params.dec_stride = dec_stride;
params.max_seq_len = max_seq_len;
params.end_ids_len = end_ids_len;
params.fuse_softmax = fuse_softmax;
params.early_stop = early_stop;
// Only Read
params.step_ids = step_ids.data<int>();
params.seq_lens = seq_lens.data<int>();
params.max_dec_lens = max_dec_lens.data<int>();
params.end_ids = end_ids.data<int>();
params.length_penalty = length_penalty.data<float>();
params.diversity_penalty = diversity_penalty.data<float>();
params.cum_scores = cum_scores_ori.data<float>();
params.block_tables = block_tables_ori.data<int>();
params.beam_cache_ids = beam_cache_ids_ori.data<int>();
// Write
params.logits = const_cast<float *>(logits.data<float>());
params.cache_ids_out = const_cast<int *>(beam_cache_ids.data<int>());
params.block_tables_out = const_cast<int *>(block_tables.data<int>());
params.cum_scores_out = const_cast<float *>(cum_scores.data<float>());
params.beam_hyps_out = const_cast<int *>(beam_hyps.data<int>());
params.beam_hyps_score_out = const_cast<float *>(beam_hyps_score.data<float>());
params.beam_finished = const_cast<bool *>(beam_finished.data<bool>());
params.stop_flags = const_cast<bool *>(stop_flags.data<bool>());
params.next_tokens = const_cast<int *>(next_tokens.data<int>());
params.parent_ids = const_cast<int *>(parent_ids.data<int>());
DispatchComputeVocParts<float>(&params);
// allocate workspace
const int tmp_id_val_size =
batch_size * beam_group_size * beam_group_size * 2;
const int packed_top_kmd_size =
fuse_softmax ? 2 * 2 * beam_group_size + 2 : 2 * 2 * beam_group_size;
const int tmp_stage1_to_stage2_size =
batch_size * beam_group_size * params.voc_parts * packed_top_kmd_size;
const int workspace_size = tmp_id_val_size * 2 + tmp_stage1_to_stage2_size;
auto wsp_buffer_tensor = paddle::full({workspace_size}, 0, logits.type(),
paddle::GPUPlace());
params.tmp_ids = reinterpret_cast<int *>(wsp_buffer_tensor.data<float>());
params.tmp_vals = wsp_buffer_tensor.data<float>() + tmp_id_val_size;
params.tmp_buffer = wsp_buffer_tensor.data<float>() + 2 * tmp_id_val_size;
for (int beam_group_idx = 0; beam_group_idx < beam_group_num_scalar;
++beam_group_idx) {
if (beam_group_num_scalar == 1) {
invokeTopkSoftMax<float, false>(
&params, beam_group_idx, cu_stream);
} else {
invokeTopkSoftMax<float, true>(
&params, beam_group_idx, cu_stream);
}
}
updateBeamSearchParams<float>(&params, cu_stream);
return {next_tokens, parent_ids};
}
std::vector<std::vector<int64_t>> BeamSearchSoftmaxShape(
const std::vector<int64_t> &logits,
const std::vector<int64_t> &seq_lens,
const std::vector<int64_t> &stop_flags, // inplace
const std::vector<int64_t> &end_ids,
const std::vector<int64_t> &step_ids,
const std::vector<int64_t> &max_dec_lens,
const std::vector<int64_t> &block_tables, // inplace
const std::vector<int64_t> &cum_scores, // inplace
const std::vector<int64_t> &beam_cache_ids, // inplace
const std::vector<int64_t> &beam_hyps, // inplace
const std::vector<int64_t> &beam_hyps_score, // inplace
const std::vector<int64_t> &beam_finished, // inplace
const std::vector<int64_t> &beam_width,
const std::vector<int64_t> &beam_group_num,
const std::vector<int64_t> &length_penalty,
const std::vector<int64_t> &diversity_penalty) {
std::vector<int64_t> next_tokens = {logits[0], 1};
std::vector<int64_t> parent_ids = {logits[0], 1};
return {next_tokens, parent_ids};
}
std::vector<paddle::DataType> BeamSearchSoftmaxDtype(
const paddle::DataType &logits,
const paddle::DataType &seq_lens,
const paddle::DataType &stop_flags, // inplace
const paddle::DataType &end_ids,
const paddle::DataType &step_ids,
const paddle::DataType &max_dec_lens,
const paddle::DataType &block_tables, // inplace
const paddle::DataType &cum_scores, // inplace
const paddle::DataType &beam_cache_ids, // inplace
const paddle::DataType &beam_hyps, // inplace
const paddle::DataType &beam_hyps_score, // inplace
const paddle::DataType &beam_finished, // inplace
const paddle::DataType &beam_width,
const paddle::DataType &beam_group_num,
const paddle::DataType &length_penalty,
const paddle::DataType &diversity_penalty) {
return {paddle::DataType::INT32, paddle::DataType::INT32};
}
PD_BUILD_STATIC_OP(beam_search_softmax)
.Inputs({"logits", "seq_lens", "stop_flags", "end_ids", "step_ids", "max_dec_lens", "block_tables"
, "cum_scores", "beam_cache_ids", "beam_hyps", "beam_hyps_score", "beam_finished"
, "beam_width", "beam_group_num", "length_penalty", "diversity_penalty"})
.Outputs({"next_tokens", "parent_ids"})
.Attrs({"fuse_softmax: bool",
"early_stop: bool"})
.SetKernelFn(PD_KERNEL(BeamSearchSoftmax))
.SetInferShapeFn(PD_INFER_SHAPE(BeamSearchSoftmaxShape))
.SetInferDtypeFn(PD_INFER_DTYPE(BeamSearchSoftmaxDtype));