mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Metax] refactor cutlass moe and optimize flash attention (#5361)
* [Metax] refactor moe and flash attention backend --------- Co-authored-by: zhangchenyi_dl <16219492+zhangchenyidl@user.noreply.gitee.com>
This commit is contained in:
@@ -48,14 +48,15 @@ constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
|
||||
#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
||||
#endif
|
||||
|
||||
template <typename T> struct Pair {
|
||||
template <typename T>
|
||||
struct Pair {
|
||||
T value;
|
||||
int count;
|
||||
|
||||
__device__ Pair operator+(const Pair &other) const {
|
||||
__device__ Pair operator+(const Pair& other) const {
|
||||
return {value + other.value, count + other.count};
|
||||
}
|
||||
__device__ Pair &operator+=(const Pair &other) {
|
||||
__device__ Pair& operator+=(const Pair& other) {
|
||||
value += other.value;
|
||||
count += other.count;
|
||||
return *this;
|
||||
@@ -78,22 +79,25 @@ struct ValueCount {
|
||||
};
|
||||
|
||||
struct BoolDiffOp {
|
||||
__device__ __forceinline__ bool operator()(const bool &lhs,
|
||||
const bool &rhs) const {
|
||||
__device__ __forceinline__ bool operator()(const bool& lhs,
|
||||
const bool& rhs) const {
|
||||
return lhs != rhs;
|
||||
}
|
||||
};
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM>
|
||||
struct SamplingTempStorage {
|
||||
union {
|
||||
float deterministic_scan[BLOCK_THREADS / 32];
|
||||
typename BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
|
||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
|
||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
|
||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce_value_count;
|
||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce;
|
||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce_int;
|
||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::
|
||||
TempStorage reduce_value_count;
|
||||
typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
|
||||
} block_prim;
|
||||
struct {
|
||||
@@ -112,14 +116,17 @@ struct SamplingTempStorage {
|
||||
* algorithm. \note This implementation is slower than the cub::BlockScan, but
|
||||
* it is deterministic.
|
||||
*/
|
||||
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
|
||||
template <uint32_t VEC_SIZE,
|
||||
uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
|
||||
__device__ __forceinline__ void
|
||||
DeterministicInclusiveSum(const T *in_data, T *out_data,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM,
|
||||
REDUCE_ALGORITHM> *temp_storage) {
|
||||
T *smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
typename T>
|
||||
__device__ __forceinline__ void DeterministicInclusiveSum(
|
||||
const T* in_data,
|
||||
T* out_data,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
|
||||
temp_storage) {
|
||||
T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
|
||||
T thread_data[VEC_SIZE];
|
||||
T thread_sum = 0;
|
||||
#pragma unroll
|
||||
@@ -138,8 +145,8 @@ DeterministicInclusiveSum(const T *in_data, T *out_data,
|
||||
}
|
||||
}
|
||||
|
||||
T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum,
|
||||
threadIdx.x | 0xffffffff);
|
||||
T warp_sum = __shfl_sync(
|
||||
0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff);
|
||||
if (threadIdx.x % 32 == 31) {
|
||||
thread_exclusive_prefix_sum = 0;
|
||||
}
|
||||
@@ -197,12 +204,21 @@ DeterministicInclusiveSum(const T *in_data, T *out_data,
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
|
||||
template <uint32_t VEC_SIZE,
|
||||
uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
bool DETERMINISTIC,
|
||||
typename Predicate>
|
||||
__device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
uint32_t i, uint32_t d, Predicate pred, float u, vec_t<float, VEC_SIZE> prob_vec,
|
||||
uint32_t i,
|
||||
uint32_t d,
|
||||
Predicate pred,
|
||||
float u,
|
||||
vec_t<float, VEC_SIZE> prob_vec,
|
||||
float& aggregate,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>* temp_storage) {
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
|
||||
temp_storage) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
float prob_greater_than_threshold[VEC_SIZE];
|
||||
float inclusive_cdf[VEC_SIZE];
|
||||
@@ -212,14 +228,14 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
|
||||
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
|
||||
}
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
float aggregate_local =
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
|
||||
.Sum(prob_greater_than_threshold);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
float aggregate_local = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage->block_prim.reduce)
|
||||
.Sum(prob_greater_than_threshold);
|
||||
#else
|
||||
float aggregate_local =
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(prob_greater_than_threshold);
|
||||
float aggregate_local = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage->block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(prob_greater_than_threshold);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage->block_aggregate.value = aggregate_local;
|
||||
@@ -229,14 +245,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
|
||||
if (aggregate + aggregate_local > u) {
|
||||
if constexpr (DETERMINISTIC) {
|
||||
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
|
||||
DeterministicInclusiveSum<VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
SCAN_ALGORITHM,
|
||||
REDUCE_ALGORITHM>(
|
||||
prob_greater_than_threshold, inclusive_cdf, temp_storage);
|
||||
} else {
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(
|
||||
temp_storage->block_prim.scan)
|
||||
.InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
|
||||
#else
|
||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
|
||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(
|
||||
temp_storage->block_prim.scan)
|
||||
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
|
||||
#endif
|
||||
|
||||
@@ -250,28 +271,35 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
|
||||
bool greater_than_u_diff[VEC_SIZE];
|
||||
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||
#else
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||
#endif
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||
temp_storage->block_prim.adj_diff)
|
||||
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||
#else
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||
#else
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||
#endif
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||
temp_storage->block_prim.adj_diff)
|
||||
.SubtractLeft<VEC_SIZE>(
|
||||
greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||
#endif
|
||||
#else
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||
temp_storage->block_prim.adj_diff)
|
||||
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||
#else
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||
temp_storage->block_prim.adj_diff)
|
||||
.FlagHeads<VEC_SIZE>(
|
||||
greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||
#endif
|
||||
#endif
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
if (greater_than_u_diff[j]) {
|
||||
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
|
||||
atomicMin(&(temp_storage->sampled_id),
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -287,9 +315,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
valid_index[j] = -1;
|
||||
}
|
||||
}
|
||||
int max_valid_index =
|
||||
BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce_int)
|
||||
.Reduce(valid_index, cub::Max());
|
||||
int max_valid_index = BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage->block_prim.reduce_int)
|
||||
.Reduce(valid_index, cub::Max());
|
||||
if (tx == 0 && max_valid_index != -1) {
|
||||
temp_storage->last_valid_id = max_valid_index;
|
||||
}
|
||||
@@ -297,15 +325,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
aggregate += aggregate_local;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
|
||||
typename DType, typename IdType>
|
||||
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
float* top_p_arr, IdType* top_k_arr,
|
||||
uint32_t d, uint64_t philox_seed,
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void TopKTopPSamplingFromProbKernel(DType* probs,
|
||||
IdType* output,
|
||||
float* top_p_arr,
|
||||
IdType* top_k_arr,
|
||||
uint32_t d,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset) {
|
||||
const uint32_t batch_size = gridDim.x;
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
@@ -315,12 +347,12 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
||||
const float p = top_p_arr[row_idx];
|
||||
|
||||
extern __shared__ __align__(
|
||||
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
extern __shared__ __align__(alignof(
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
uint8_t smem_sampling[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
auto& temp_storage = reinterpret_cast<
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
|
||||
vec_t<float, VEC_SIZE> probs_vec;
|
||||
float aggregate;
|
||||
@@ -336,12 +368,22 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
|
||||
DeviceSamplingFromProb<VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
SCAN_ALGORITHM,
|
||||
REDUCE_ALGORITHM,
|
||||
DETERMINISTIC>(
|
||||
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
|
||||
i,
|
||||
d,
|
||||
[&](float x) { return x > low; },
|
||||
u,
|
||||
probs_vec,
|
||||
aggregate,
|
||||
&temp_storage);
|
||||
if (aggregate > u) {
|
||||
break;
|
||||
}
|
||||
@@ -362,28 +404,29 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
ValueCount<float> probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
probs_gt_pivot_0[j] = {
|
||||
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
probs_gt_pivot_1[j] = {
|
||||
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
probs_gt_pivot_0[j] = {(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_0 &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
probs_gt_pivot_1[j] = {(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_1 &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_0);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_0);
|
||||
#else
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
|
||||
@@ -391,14 +434,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
__syncthreads();
|
||||
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_1);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_1);
|
||||
#else
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
|
||||
@@ -427,14 +470,19 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC, typename DType, typename IdType>
|
||||
__global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
float* top_p_arr, uint32_t d,
|
||||
uint64_t philox_seed, uint64_t philox_offset) {
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void TopPSamplingFromProbKernel(DType* probs,
|
||||
IdType* output,
|
||||
float* top_p_arr,
|
||||
uint32_t d,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset) {
|
||||
const uint32_t batch_size = gridDim.x;
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
curandStatePhilox4_32_10_t state;
|
||||
@@ -442,12 +490,12 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
const uint32_t row_idx = bx;
|
||||
float top_p = top_p_arr[row_idx];
|
||||
|
||||
extern __shared__ __align__(
|
||||
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
extern __shared__ __align__(alignof(
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
uint8_t smem_sampling[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
auto& temp_storage = reinterpret_cast<
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
|
||||
vec_t<float, VEC_SIZE> probs_vec;
|
||||
float aggregate;
|
||||
@@ -463,12 +511,22 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
|
||||
DeviceSamplingFromProb<VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
SCAN_ALGORITHM,
|
||||
REDUCE_ALGORITHM,
|
||||
DETERMINISTIC>(
|
||||
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
|
||||
i,
|
||||
d,
|
||||
[&](float x) { return x > low; },
|
||||
u,
|
||||
probs_vec,
|
||||
aggregate,
|
||||
&temp_storage);
|
||||
if (aggregate > u) {
|
||||
break;
|
||||
}
|
||||
@@ -489,7 +547,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
|
||||
@@ -499,12 +558,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum(probs_gt_pivot_0);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum(probs_gt_pivot_0);
|
||||
#else
|
||||
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
|
||||
@@ -512,12 +573,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
__syncthreads();
|
||||
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum(probs_gt_pivot_1);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum(probs_gt_pivot_1);
|
||||
#else
|
||||
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
|
||||
@@ -546,9 +609,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
template <uint32_t VEC_SIZE,
|
||||
uint32_t BLOCK_THREADS,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
typename TempStorage>
|
||||
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
|
||||
__device__ __forceinline__ float GetMaxValue(float* in_data,
|
||||
uint32_t row_idx,
|
||||
uint32_t d,
|
||||
TempStorage& temp_storage) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
vec_t<float, VEC_SIZE> in_data_vec;
|
||||
@@ -557,21 +624,24 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
in_data_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
in_data_vec.cast_load(in_data + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
float in_data_[VEC_SIZE];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
in_data_[j] = in_data_vec[j];
|
||||
}
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
max_val = max(
|
||||
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||
.Reduce(in_data_, cub::Max()));
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
max_val = max(max_val,
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce(in_data_, cub::Max()));
|
||||
#else
|
||||
max_val = max(
|
||||
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
||||
max_val = max(max_val,
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
||||
#endif
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -585,10 +655,12 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
|
||||
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
|
||||
struct RenormTempStorage {
|
||||
union {
|
||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
|
||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
|
||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce_value_count;
|
||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce;
|
||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce_int;
|
||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::
|
||||
TempStorage reduce_value_count;
|
||||
} block_prim;
|
||||
struct {
|
||||
float max_val;
|
||||
@@ -607,24 +679,33 @@ struct RenormTempStorage {
|
||||
};
|
||||
};
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
|
||||
typename DType,typename IdType>
|
||||
__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
|
||||
DType* renormed_prob,uint32_t d) {
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void MinPSamplingFromProbKernel(DType* probs,
|
||||
const float* min_p_arr,
|
||||
DType* renormed_prob,
|
||||
uint32_t d) {
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
|
||||
const uint32_t row_idx = bx;
|
||||
|
||||
extern __shared__ __align__(
|
||||
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
extern __shared__ __align__(alignof(
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
uint8_t smem_sampling[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
auto& temp_storage = reinterpret_cast<
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
|
||||
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
|
||||
float max_val = GetMaxValue<
|
||||
VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
REDUCE_ALGORITHM,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
|
||||
probs, row_idx, d, temp_storage);
|
||||
float pivot = max_val * p;
|
||||
|
||||
@@ -633,7 +714,8 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -641,42 +723,51 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
|
||||
probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
|
||||
}
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
probs_vec.store(renormed_prob + row_idx * d +
|
||||
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
|
||||
typename DType, typename IdType>
|
||||
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
uint32_t VEC_SIZE,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void TopKRenormProbKernel(DType* probs,
|
||||
DType* renormed_prob,
|
||||
IdType* top_k_arr,
|
||||
uint32_t d) {
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
const uint32_t row_idx = bx;
|
||||
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
double pivot = std::numeric_limits<float>::infinity(), normalizer = 1;
|
||||
#else
|
||||
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
|
||||
#endif
|
||||
vec_t<float, VEC_SIZE> probs_vec;
|
||||
if (k < d) {
|
||||
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
|
||||
uint8_t smem_renorm[];
|
||||
extern __shared__ __align__(alignof(
|
||||
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>)) uint8_t smem_renorm[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
|
||||
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(
|
||||
smem_renorm);
|
||||
temp_storage.max_val = 0;
|
||||
|
||||
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
|
||||
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
|
||||
probs, row_idx, d, temp_storage);
|
||||
float max_val =
|
||||
GetMaxValue<VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
REDUCE_ALGORITHM,
|
||||
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
|
||||
probs, row_idx, d, temp_storage);
|
||||
|
||||
double low = 0, high = max_val;
|
||||
float min_gt_low, max_le_high;
|
||||
float sum_low = 1;
|
||||
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
|
||||
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
|
||||
// loop invariant:
|
||||
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs |
|
||||
// p <= high} loop invariant:
|
||||
// - f(low) >= k, f(high) < k
|
||||
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
|
||||
// stopping condition: min_gt_low == max_le_high
|
||||
@@ -692,55 +783,65 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
}
|
||||
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
|
||||
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE],
|
||||
probs_gt_pivot_1_pair[VEC_SIZE];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
probs_gt_pivot_0_pair[j] = {
|
||||
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
(probs_vec[j] > pivot_0 &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
probs_gt_pivot_1_pair[j] = {
|
||||
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
(probs_vec[j] > pivot_1 &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
|
||||
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||
if (probs_vec[j] > low &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||
min_gt_low = min(min_gt_low, probs_vec[j]);
|
||||
}
|
||||
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||
if (probs_vec[j] <= high &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||
max_le_high = max(max_le_high, probs_vec[j]);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_0_pair);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_0_pair);
|
||||
#else
|
||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
|
||||
#endif
|
||||
__syncthreads();
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_1_pair);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_1_pair);
|
||||
#else
|
||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
|
||||
#endif
|
||||
__syncthreads();
|
||||
}
|
||||
min_gt_low =
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||
.Reduce(min_gt_low, cub::Min());
|
||||
min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce(min_gt_low, cub::Min());
|
||||
__syncthreads();
|
||||
max_le_high =
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||
.Reduce(max_le_high, cub::Max());
|
||||
max_le_high = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce(max_le_high, cub::Max());
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
|
||||
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
|
||||
@@ -774,23 +875,29 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
|
||||
tx * VEC_SIZE);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
|
||||
}
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
probs_vec.store(renormed_prob + row_idx * d +
|
||||
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IdType>
|
||||
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
|
||||
uint32_t batch_size, const T *top_p_val,
|
||||
uint32_t d, bool deterministic,
|
||||
uint64_t philox_seed, uint64_t philox_offset,
|
||||
cudaError_t TopPSamplingFromProb(T* probs,
|
||||
IdType* output,
|
||||
uint32_t batch_size,
|
||||
const T* top_p_val,
|
||||
uint32_t d,
|
||||
bool deterministic,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset,
|
||||
cudaStream_t stream = 0) {
|
||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||
@@ -799,99 +906,139 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
|
||||
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
void* args[] = {&probs, &output, &top_p_val,
|
||||
&d, &philox_seed, &philox_offset};
|
||||
void* args[] = {
|
||||
&probs, &output, &top_p_val, &d, &philox_seed, &philox_offset};
|
||||
|
||||
DISPATCH_ALIGNED_VEC_SIZE(
|
||||
vec_size, VEC_SIZE,
|
||||
vec_size,
|
||||
VEC_SIZE,
|
||||
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel =
|
||||
TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
|
||||
VEC_SIZE, DETERMINISTIC, T, IdType>;
|
||||
auto kernel = TopPSamplingFromProbKernel<BLOCK_THREADS,
|
||||
SCAN_ALGO,
|
||||
REDUCE_ALGO,
|
||||
VEC_SIZE,
|
||||
DETERMINISTIC,
|
||||
T,
|
||||
IdType>;
|
||||
CUDA_CALL(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
|
||||
smem_size, stream));
|
||||
CUDA_CALL(cudaLaunchKernel(
|
||||
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
})});
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
template <typename T,typename IdType>
|
||||
cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,T *renormed_prob,
|
||||
template <typename T, typename IdType>
|
||||
cudaError_t MinPSamplingFromProb(T* probs,
|
||||
const T* min_p_arr,
|
||||
T* renormed_prob,
|
||||
uint32_t batch_size,
|
||||
uint32_t d, bool deterministic,
|
||||
cudaStream_t stream = 0){
|
||||
uint32_t d,
|
||||
bool deterministic,
|
||||
cudaStream_t stream = 0) {
|
||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||
|
||||
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
const uint32_t smem_size =
|
||||
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
void* args[] = {&probs, &min_p_arr,&renormed_prob,&d};
|
||||
void* args[] = {&probs, &min_p_arr, &renormed_prob, &d};
|
||||
DISPATCH_ALIGNED_VEC_SIZE(
|
||||
vec_size, VEC_SIZE,
|
||||
vec_size,
|
||||
VEC_SIZE,
|
||||
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel =
|
||||
MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
|
||||
VEC_SIZE, DETERMINISTIC, T,IdType>;
|
||||
auto kernel = MinPSamplingFromProbKernel<BLOCK_THREADS,
|
||||
SCAN_ALGO,
|
||||
REDUCE_ALGO,
|
||||
VEC_SIZE,
|
||||
DETERMINISTIC,
|
||||
T,
|
||||
IdType>;
|
||||
CUDA_CALL(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
|
||||
smem_size, stream));
|
||||
CUDA_CALL(cudaLaunchKernel(
|
||||
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
})});
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename IdType>
|
||||
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
|
||||
uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
|
||||
uint32_t d, bool deterministic,
|
||||
uint64_t philox_seed, uint64_t philox_offset,
|
||||
cudaError_t TopKTopPSamplingFromProb(T* probs,
|
||||
IdType* output,
|
||||
uint32_t batch_size,
|
||||
const T* top_p_val,
|
||||
const IdType* top_k_val,
|
||||
uint32_t d,
|
||||
bool deterministic,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset,
|
||||
cudaStream_t stream = 0) {
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||
|
||||
auto compute_capacity = GetCudaComputeCapability();
|
||||
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
|
||||
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
const uint32_t smem_size =
|
||||
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
void* args[] = {&probs, &output, &top_p_val, &top_k_val,
|
||||
&d, &philox_seed, &philox_offset};
|
||||
void* args[] = {&probs,
|
||||
&output,
|
||||
&top_p_val,
|
||||
&top_k_val,
|
||||
&d,
|
||||
&philox_seed,
|
||||
&philox_offset};
|
||||
|
||||
DISPATCH_ALIGNED_VEC_SIZE(
|
||||
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
|
||||
VEC_SIZE, DETERMINISTIC, T, IdType>;
|
||||
CUDA_CALL(
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(
|
||||
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
vec_size,
|
||||
VEC_SIZE,
|
||||
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS,
|
||||
SCAN_ALGO,
|
||||
REDUCE_ALGO,
|
||||
VEC_SIZE,
|
||||
DETERMINISTIC,
|
||||
T,
|
||||
IdType>;
|
||||
CUDA_CALL(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel(
|
||||
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
})});
|
||||
return cudaSuccess;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DType, typename IdType>
|
||||
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
|
||||
uint32_t batch_size, uint32_t d,
|
||||
cudaError_t TopKRenormProb(DType* probs,
|
||||
DType* renormed_prob,
|
||||
IdType* top_k_arr,
|
||||
uint32_t batch_size,
|
||||
uint32_t d,
|
||||
cudaStream_t stream = 0) {
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
||||
|
||||
auto compute_capacity = GetCudaComputeCapability();
|
||||
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
|
||||
const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
|
||||
const uint32_t smem_size =
|
||||
sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
|
||||
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
|
||||
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
|
||||
CUDA_CALL(
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
auto kernel = TopKRenormProbKernel<BLOCK_THREADS,
|
||||
REDUCE_ALGO,
|
||||
VEC_SIZE,
|
||||
DType,
|
||||
IdType>;
|
||||
CUDA_CALL(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel(
|
||||
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
});
|
||||
return cudaSuccess;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace sampling
|
||||
} // namespace sampling
|
||||
|
||||
Reference in New Issue
Block a user