[Feature] support top_k_top_p sampling (#2753)

* support top_k_top_p sampling

* fix

* add api param

* add api para

* fix

* fix

* fix

* fix

* fix

* fix

* fix
This commit is contained in:
Sunny-bot1
2025-07-10 11:58:58 +08:00
committed by GitHub
parent b0f525955c
commit e45050cae3
15 changed files with 501 additions and 53 deletions

View File

@@ -279,7 +279,8 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
float* top_p_arr, IdType* top_k_arr,
uint32_t d, uint64_t philox_seed,
uint64_t philox_offset) {
const uint32_t batch_size = gridDim.x;
@@ -287,7 +288,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, flo
curandStatePhilox4_32_10_t state;
curand_init(philox_seed, bx, philox_offset, &state);
const uint32_t row_idx = bx;
const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
extern __shared__ __align__(
@@ -479,7 +480,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
if (aggregate_gt_pivot_0 < top_p) {
// case 1: pivot_0 accepted
break;
}
}
if (aggregate_gt_pivot_1 < top_p) {
// case 2: pivot_0 rejected, pivot_1 accepted
low = pivot_0;
@@ -497,6 +498,183 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
}
}
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
typename TempStorage>
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
TempStorage& temp_storage) {
const uint32_t tx = threadIdx.x;
vec_t<float, VEC_SIZE> in_data_vec;
float max_val = 0;
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
in_data_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
float in_data_[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
in_data_[j] = in_data_vec[j];
}
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
__syncthreads();
}
if (tx == 0) {
temp_storage.max_val = max_val;
}
__syncthreads();
return temp_storage.max_val;
}
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
struct RenormTempStorage {
union {
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce_value_count;
} block_prim;
struct {
float max_val;
float min_val;
union {
struct {
float values[2];
};
struct {
int counts[2];
};
struct {
ValueCount<float> pairs[2];
};
} block_aggregate;
};
};
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
vec_t<float, VEC_SIZE> probs_vec;
if (k < d) {
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
uint8_t smem_renorm[];
auto& temp_storage =
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
temp_storage.max_val = 0;
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);
double low = 0, high = max_val;
float min_gt_low, max_le_high;
float sum_low = 1;
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
// loop invariant:
// - f(low) >= k, f(high) < k
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
// stopping condition: min_gt_low == max_le_high
// - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
do {
double pivot_0 = (high + 2 * low) / 3;
double pivot_1 = (2 * high + low) / 3;
ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
min_gt_low = high;
max_le_high = low;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0_pair[j] = {
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_1_pair[j] = {
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
min_gt_low = min(min_gt_low, probs_vec[j]);
}
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
max_le_high = max(max_le_high, probs_vec[j]);
}
}
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
__syncthreads();
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
__syncthreads();
}
min_gt_low =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(min_gt_low, cub::Min());
__syncthreads();
max_le_high =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(max_le_high, cub::Max());
if (tx == 0) {
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
temp_storage.min_val = min_gt_low;
temp_storage.max_val = max_le_high;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pairs[0];
aggregate_gt_pivot_1 = temp_storage.block_aggregate.pairs[1];
min_gt_low = temp_storage.min_val;
max_le_high = temp_storage.max_val;
if (aggregate_gt_pivot_1.count >= k) {
low = pivot_1;
sum_low = float(aggregate_gt_pivot_1.value);
} else if (aggregate_gt_pivot_0.count >= k) {
low = pivot_0;
high = min(pivot_1, max_le_high);
sum_low = float(aggregate_gt_pivot_0.value);
} else {
high = min(pivot_0, max_le_high);
}
} while (min_gt_low != max_le_high);
normalizer = ptx_rcp(max(sum_low, 1e-8));
pivot = low;
}
// normalize
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
}
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
}
template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val,
@@ -529,7 +707,7 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val,
uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
uint32_t d, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset,
cudaStream_t stream = 0) {
@@ -540,7 +718,7 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &output, &top_p_val,
void* args[] = {&probs, &output, &top_p_val, &top_k_val,
&d, &philox_seed, &philox_offset};
DISPATCH_ALIGNED_VEC_SIZE(
@@ -556,4 +734,26 @@ cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
});
}
} // namespace sampling
template <typename DType, typename IdType>
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
uint32_t batch_size, uint32_t d,
cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});
return cudaSuccess;
});
}
} // namespace sampling