Sync v2.0 version of code to github repo

This commit is contained in:
Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -0,0 +1,559 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This code is partially inspired by and references the implementation found
// in FlashInfer.Specifically, the implementation of Top-p Sampling
// functionality in this code is inspired by the logic of FlashInfers
// flashinfer.sampling.top_p_sampling_from_probs . For more details on
// FlashInfers documentation, please refer to:
// https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html
#pragma once
#include <cub/block/block_adjacent_difference.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_scan.cuh>
#include <numeric>
#include "sample_kernels/utils.cuh"
namespace sampling {
using namespace cub;
#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \
if (compute_capacity.first >= 8) { \
constexpr uint32_t BLOCK_THREADS = 1024; \
__VA_ARGS__ \
} else { \
constexpr uint32_t BLOCK_THREADS = 512; \
__VA_ARGS__ \
}
constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100)
#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED
#endif
template <typename T> struct Pair {
T value;
int count;
__device__ Pair operator+(const Pair &other) const {
return {value + other.value, count + other.count};
}
__device__ Pair &operator+=(const Pair &other) {
value += other.value;
count += other.count;
return *this;
}
};
template <typename T>
struct ValueCount {
T value;
int count;
__device__ ValueCount operator+(const ValueCount& other) const {
return {value + other.value, count + other.count};
}
__device__ ValueCount& operator+=(const ValueCount& other) {
value += other.value;
count += other.count;
return *this;
}
};
struct BoolDiffOp {
__device__ __forceinline__ bool operator()(const bool &lhs,
const bool &rhs) const {
return lhs != rhs;
}
};
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM>
struct SamplingTempStorage {
union {
float deterministic_scan[BLOCK_THREADS / 32];
typename BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce_value_count;
typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
} block_prim;
struct {
int32_t sampled_id;
int32_t last_valid_id;
float max_val;
union {
float value;
ValueCount<float> pair;
} block_aggregate;
};
};
/*!
* \brief Deterministic inclusive scan implementation, use Belloch scan
* algorithm. \note This implementation is slower than the cub::BlockScan, but
* it is deterministic.
*/
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
__device__ __forceinline__ void
DeterministicInclusiveSum(const T *in_data, T *out_data,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM,
REDUCE_ALGORITHM> *temp_storage) {
T *smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
T thread_data[VEC_SIZE];
T thread_sum = 0;
#pragma unroll
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
thread_sum += in_data[i];
thread_data[i] = thread_sum;
}
T thread_exclusive_prefix_sum = thread_sum;
#pragma unroll
for (uint32_t offset = 1; offset < 32; offset *= 2) {
T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
if ((threadIdx.x + 1) % (offset * 2) == 0) {
thread_exclusive_prefix_sum += tmp;
}
}
T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum,
threadIdx.x | 0xffffffff);
if (threadIdx.x % 32 == 31) {
thread_exclusive_prefix_sum = 0;
}
#pragma unroll
for (uint32_t offset = 16; offset >= 1; offset /= 2) {
T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
if ((threadIdx.x + 1) % (offset * 2) == 0) {
thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum;
}
if ((threadIdx.x + 1) % (offset * 2) == offset) {
thread_exclusive_prefix_sum = tmp;
}
}
smem_prefix_sum[threadIdx.x / 32] = warp_sum;
__syncthreads();
if (threadIdx.x < 32) {
T warp_exclusive_prefix_sum =
(threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0;
#pragma unroll
for (uint32_t offset = 1; offset < 32; offset *= 2) {
T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
if ((threadIdx.x + 1) % (offset * 2) == 0) {
warp_exclusive_prefix_sum += tmp;
}
}
if (threadIdx.x % 32 == 31) {
warp_exclusive_prefix_sum = 0;
}
#pragma unroll
for (uint32_t offset = 16; offset >= 1; offset /= 2) {
T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
if ((threadIdx.x + 1) % (offset * 2) == 0) {
warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum;
}
if ((threadIdx.x + 1) % (offset * 2) == offset) {
warp_exclusive_prefix_sum = tmp;
}
}
if (threadIdx.x < BLOCK_THREADS / 32) {
smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum;
}
}
__syncthreads();
#pragma unroll
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
out_data[i] = smem_prefix_sum[threadIdx.x / 32] +
thread_exclusive_prefix_sum + thread_data[i];
}
}
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
__device__ __forceinline__ void DeviceSamplingFromProb(
uint32_t i, uint32_t d, Predicate pred, float u, vec_t<float, VEC_SIZE> prob_vec,
float& aggregate,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>* temp_storage) {
const uint32_t tx = threadIdx.x;
float prob_greater_than_threshold[VEC_SIZE];
float inclusive_cdf[VEC_SIZE];
bool greater_than_u[VEC_SIZE], valid[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
}
float aggregate_local =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
.Sum<VEC_SIZE>(prob_greater_than_threshold);
if (tx == 0) {
temp_storage->block_aggregate.value = aggregate_local;
}
__syncthreads();
aggregate_local = temp_storage->block_aggregate.value;
if (aggregate + aggregate_local > u) {
if constexpr (DETERMINISTIC) {
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
prob_greater_than_threshold, inclusive_cdf, temp_storage);
} else {
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
__syncthreads();
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j];
}
bool greater_than_u_diff[VEC_SIZE];
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#endif
__syncthreads();
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
if (greater_than_u_diff[j]) {
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
}
}
__syncthreads();
}
// update the last valid index
int valid_index[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
if (valid[j]) {
valid_index[j] = (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
} else {
valid_index[j] = -1;
}
}
int max_valid_index =
BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce_int)
.Reduce(valid_index, cub::Max());
if (tx == 0 && max_valid_index != -1) {
temp_storage->last_valid_id = max_valid_index;
}
__syncthreads();
aggregate += aggregate_local;
}
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
uint32_t d, uint64_t philox_seed,
uint64_t philox_offset) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(philox_seed, bx, philox_offset, &state);
const uint32_t row_idx = bx;
const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = 1.f;
int sampled_id;
do {
temp_storage.sampled_id = d;
__syncthreads();
float u = curand_uniform(&state) * q;
aggregate = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DETERMINISTIC>(
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
if (aggregate > u) {
break;
}
}
__syncthreads();
sampled_id = temp_storage.sampled_id;
if (sampled_id == d) {
// NOTE(Zihao): this would happen when u is very close to 1
// and the sum of probabilities is smaller than u
// In this case, we use the last valid index as the sampled id
sampled_id = temp_storage.last_valid_id;
}
double pivot_0 = probs[row_idx * d + sampled_id];
double pivot_1 = (pivot_0 + high) / 2;
ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
ValueCount<float> probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0[j] = {
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_1[j] = {
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
}
aggregate_gt_pivot_0 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
if (tx == 0) {
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
aggregate_gt_pivot_1 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
if (tx == 0) {
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
}
__syncthreads();
aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair;
}
if (aggregate_gt_pivot_0.count < k && aggregate_gt_pivot_0.value < p) {
// case 1: pivot_0 accepted
break;
}
if (aggregate_gt_pivot_1.count < k && aggregate_gt_pivot_1.value < p) {
// case 2: pivot_0 rejected, pivot_1 accepted
low = pivot_0;
high = pivot_1;
q = aggregate_gt_pivot_0.value;
} else {
// case 3: pivot_0 rejected, pivot_1 rejected
low = pivot_1;
q = aggregate_gt_pivot_1.value;
}
} while (low < high);
__syncthreads();
if (tx == 0) {
output[bx] = sampled_id;
}
}
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
bool DETERMINISTIC, typename DType, typename IdType>
__global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
float* top_p_arr, uint32_t d,
uint64_t philox_seed, uint64_t philox_offset) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(philox_seed, bx, philox_offset, &state);
const uint32_t row_idx = bx;
float top_p = top_p_arr[row_idx];
extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = 1.f;
int sampled_id;
do {
temp_storage.sampled_id = d;
__syncthreads();
float u = curand_uniform(&state) * q;
aggregate = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DETERMINISTIC>(
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
if (aggregate > u) {
break;
}
}
__syncthreads();
sampled_id = temp_storage.sampled_id;
if (sampled_id == d) {
// NOTE(Zihao): this would happen when u is very close to 1
// and the sum of probabilities is smaller than u
// In this case, we use the last valid index as the sampled id
sampled_id = temp_storage.last_valid_id;
}
double pivot_0 = probs[row_idx * d + sampled_id];
double pivot_1 = (pivot_0 + high) / 2;
float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0;
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
}
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
}
__syncthreads();
aggregate_gt_pivot_1 = temp_storage.block_aggregate.value;
}
if (aggregate_gt_pivot_0 < top_p) {
// case 1: pivot_0 accepted
break;
}
if (aggregate_gt_pivot_1 < top_p) {
// case 2: pivot_0 rejected, pivot_1 accepted
low = pivot_0;
high = pivot_1;
q = aggregate_gt_pivot_0;
} else {
// case 3: pivot_0 rejected, pivot_1 rejected
low = pivot_1;
q = aggregate_gt_pivot_1;
}
} while (low < high);
__syncthreads();
if (tx == 0) {
output[bx] = sampled_id;
}
}
template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val,
uint32_t d, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
const uint32_t smem_size =
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &output, &top_p_val,
&d, &philox_seed, &philox_offset};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE,
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel =
TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
VEC_SIZE, DETERMINISTIC, T, IdType>;
CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
smem_size, stream));
})});
return cudaSuccess;
}
template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val,
uint32_t d, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset,
cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &output, &top_p_val,
&d, &philox_seed, &philox_offset};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
VEC_SIZE, DETERMINISTIC, T, IdType>;
CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
return cudaSuccess;
});
}
} // namespace sampling