Sync v2.0 version of code to GitHub repo

Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

File diff suppressed because it is too large.

View File

@@ -0,0 +1,73 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"
std::vector<paddle::Tensor> TopPSamplingReject(const paddle::Tensor &probs,
const paddle::Tensor &top_p,
int seed) {
std::vector<int64_t> probs_shape = probs.shape();
unsigned int batch_size = probs_shape[0];
unsigned int vocab_size = probs_shape[1];
uint64_t philox_seed = seed;
uint64_t philox_offset = 0;
auto cu_stream = probs.stream();
  // seed == -1 requests fresh per-batch randomness from Paddle's CUDA generator
  if (seed == -1) {
phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(probs.place()));
auto gen_cuda = dev_ctx->GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(32 * batch_size);
philox_seed = seed_offset.first;
philox_offset = seed_offset.second;
}
auto samples =
paddle::empty({batch_size, 1}, paddle::DataType::INT64, probs.place());
cudaError_t status;
status = sampling::TopKTopPSamplingFromProb<float, int64_t>(
const_cast<float *>(probs.data<float>()), samples.data<int64_t>(),
batch_size, top_p.data<float>(), vocab_size,
true, philox_seed, philox_offset, cu_stream);
  PD_CHECK(status == cudaSuccess,
           "TopKTopPSamplingFromProb failed with error: " +
               std::string(cudaGetErrorString(status)));
return {samples};
}
std::vector<std::vector<int64_t>>
TopPSamplingRejectInferShape(const std::vector<int64_t> &probs_shape,
const std::vector<int64_t> &top_p_shape) {
int64_t bs = probs_shape[0];
return {{bs, 1}};
}
std::vector<paddle::DataType>
TopPSamplingRejectInferDtype(const paddle::DataType &probs_dtype,
                             const paddle::DataType &top_p_dtype) {
return {paddle::DataType::INT64};
}
PD_BUILD_STATIC_OP(rejection_top_p_sampling)
.Inputs({"probs", "top_p"})
.Outputs({"samples"})
.Attrs({"seed: int"})
.SetKernelFn(PD_KERNEL(TopPSamplingReject))
.SetInferShapeFn(PD_INFER_SHAPE(TopPSamplingRejectInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(TopPSamplingRejectInferDtype));

View File

@@ -0,0 +1,559 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This code is partially inspired by and references the implementation found
// in FlashInfer. Specifically, the Top-p Sampling functionality in this code
// is inspired by the logic of FlashInfer's
// flashinfer.sampling.top_p_sampling_from_probs. For more details, see
// FlashInfer's documentation:
// https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html
#pragma once
#include <cub/block/block_adjacent_difference.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_scan.cuh>
#include <numeric>
#include "sample_kernels/utils.cuh"
namespace sampling {
using namespace cub;
#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) \
if (compute_capacity.first >= 8) { \
constexpr uint32_t BLOCK_THREADS = 1024; \
__VA_ARGS__ \
} else { \
constexpr uint32_t BLOCK_THREADS = 512; \
__VA_ARGS__ \
}
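// The dispatch above selects 1024 threads per block on compute capability 8.0
// and newer, and falls back to 512 threads on older GPUs.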
constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100)
#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED
#endif
template <typename T> struct Pair {
T value;
int count;
__device__ Pair operator+(const Pair &other) const {
return {value + other.value, count + other.count};
}
__device__ Pair &operator+=(const Pair &other) {
value += other.value;
count += other.count;
return *this;
}
};
template <typename T>
struct ValueCount {
T value;
int count;
__device__ ValueCount operator+(const ValueCount& other) const {
return {value + other.value, count + other.count};
}
__device__ ValueCount& operator+=(const ValueCount& other) {
value += other.value;
count += other.count;
return *this;
}
};
struct BoolDiffOp {
__device__ __forceinline__ bool operator()(const bool &lhs,
const bool &rhs) const {
return lhs != rhs;
}
};
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM>
struct SamplingTempStorage {
union {
float deterministic_scan[BLOCK_THREADS / 32];
typename BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce_value_count;
typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
} block_prim;
struct {
int32_t sampled_id;
int32_t last_valid_id;
float max_val;
union {
float value;
ValueCount<float> pair;
} block_aggregate;
};
};
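// Shared-memory layout for the sampling kernels: the union lets the cub block
// primitives (scan, reduce, adjacent-difference) reuse the same bytes, since
// at most one of them is active between __syncthreads() barriers. The trailing
// scalars broadcast per-round results (sampled index, last valid index, block
// aggregates) to every thread in the block.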
/*!
 * \brief Deterministic inclusive scan implementation using the Blelloch scan
 * algorithm.
 * \note This implementation is slower than cub::BlockScan, but it is
 * deterministic.
 */
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
__device__ __forceinline__ void
DeterministicInclusiveSum(const T *in_data, T *out_data,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM,
REDUCE_ALGORITHM> *temp_storage) {
T *smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
T thread_data[VEC_SIZE];
T thread_sum = 0;
#pragma unroll
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
thread_sum += in_data[i];
thread_data[i] = thread_sum;
}
T thread_exclusive_prefix_sum = thread_sum;
#pragma unroll
for (uint32_t offset = 1; offset < 32; offset *= 2) {
T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
if ((threadIdx.x + 1) % (offset * 2) == 0) {
thread_exclusive_prefix_sum += tmp;
}
}
  // broadcast lane 31's inclusive sum (the warp total) to every lane
  T warp_sum =
      __shfl_sync(0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0x1f);
if (threadIdx.x % 32 == 31) {
thread_exclusive_prefix_sum = 0;
}
#pragma unroll
for (uint32_t offset = 16; offset >= 1; offset /= 2) {
T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
if ((threadIdx.x + 1) % (offset * 2) == 0) {
thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum;
}
if ((threadIdx.x + 1) % (offset * 2) == offset) {
thread_exclusive_prefix_sum = tmp;
}
}
smem_prefix_sum[threadIdx.x / 32] = warp_sum;
__syncthreads();
if (threadIdx.x < 32) {
T warp_exclusive_prefix_sum =
(threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0;
#pragma unroll
for (uint32_t offset = 1; offset < 32; offset *= 2) {
T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
if ((threadIdx.x + 1) % (offset * 2) == 0) {
warp_exclusive_prefix_sum += tmp;
}
}
if (threadIdx.x % 32 == 31) {
warp_exclusive_prefix_sum = 0;
}
#pragma unroll
for (uint32_t offset = 16; offset >= 1; offset /= 2) {
T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
if ((threadIdx.x + 1) % (offset * 2) == 0) {
warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum;
}
if ((threadIdx.x + 1) % (offset * 2) == offset) {
warp_exclusive_prefix_sum = tmp;
}
}
if (threadIdx.x < BLOCK_THREADS / 32) {
smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum;
}
}
__syncthreads();
#pragma unroll
for (uint32_t i = 0; i < VEC_SIZE; ++i) {
out_data[i] = smem_prefix_sum[threadIdx.x / 32] +
thread_exclusive_prefix_sum + thread_data[i];
}
}
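// Determinism comes from the fixed combination order above: thread-local
// running sums, then shuffle-based warp scans, then a scan over the per-warp
// totals. Floating-point additions always happen in the same order, so
// repeated runs are bit-identical. Thread t owns VEC_SIZE consecutive values,
// and out_data receives the inclusive prefix sums over all
// BLOCK_THREADS * VEC_SIZE values in thread order.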
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
__device__ __forceinline__ void DeviceSamplingFromProb(
uint32_t i, uint32_t d, Predicate pred, float u, vec_t<float, VEC_SIZE> prob_vec,
float& aggregate,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>* temp_storage) {
const uint32_t tx = threadIdx.x;
float prob_greater_than_threshold[VEC_SIZE];
float inclusive_cdf[VEC_SIZE];
bool greater_than_u[VEC_SIZE], valid[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
}
float aggregate_local =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
.Sum<VEC_SIZE>(prob_greater_than_threshold);
if (tx == 0) {
temp_storage->block_aggregate.value = aggregate_local;
}
__syncthreads();
aggregate_local = temp_storage->block_aggregate.value;
if (aggregate + aggregate_local > u) {
if constexpr (DETERMINISTIC) {
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
prob_greater_than_threshold, inclusive_cdf, temp_storage);
} else {
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
__syncthreads();
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
greater_than_u[j] = (inclusive_cdf[j] + aggregate > u) && valid[j];
}
bool greater_than_u_diff[VEC_SIZE];
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#endif
__syncthreads();
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
if (greater_than_u_diff[j]) {
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
}
}
__syncthreads();
}
// update the last valid index
int valid_index[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
if (valid[j]) {
valid_index[j] = (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
} else {
valid_index[j] = -1;
}
}
int max_valid_index =
BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce_int)
.Reduce(valid_index, cub::Max());
if (tx == 0 && max_valid_index != -1) {
temp_storage->last_valid_id = max_valid_index;
}
__syncthreads();
aggregate += aggregate_local;
}
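// On return, temp_storage->sampled_id holds the smallest index whose masked
// inclusive CDF exceeded u (unchanged if none did in this chunk),
// temp_storage->last_valid_id the largest index passing the predicate, and
// `aggregate` the probability mass accumulated so far across chunks.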
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, float* top_p_arr,
uint32_t d, uint64_t philox_seed,
uint64_t philox_offset) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(philox_seed, bx, philox_offset, &state);
const uint32_t row_idx = bx;
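  // top_p == 0 degenerates to near-greedy decoding: keep only the top-1
  // candidate (k = 1) with a tiny tail mass (p = 1e-6); otherwise cap the
  // candidate set at the top 20 tokens and use the requested top-p mass.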
const uint32_t k = top_p_arr[row_idx] == 0 ? 1 : 20;
const float p = top_p_arr[row_idx] == 0 ? 1e-6 : top_p_arr[row_idx];
extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = 1.f;
int sampled_id;
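  // Rejection loop: draw a candidate from the mass above `low`, then test two
  // pivots: the candidate's own probability (pivot_0) and its midpoint with
  // `high` (pivot_1). If the mass and count above pivot_0 already satisfy the
  // top-k/top-p constraints, the draw is accepted; otherwise [low, high) is
  // narrowed and the draw is retried on the reduced support.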
do {
temp_storage.sampled_id = d;
__syncthreads();
float u = curand_uniform(&state) * q;
aggregate = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DETERMINISTIC>(
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
if (aggregate > u) {
break;
}
}
__syncthreads();
sampled_id = temp_storage.sampled_id;
if (sampled_id == d) {
// NOTE(Zihao): this would happen when u is very close to 1
// and the sum of probabilities is smaller than u
// In this case, we use the last valid index as the sampled id
sampled_id = temp_storage.last_valid_id;
}
double pivot_0 = probs[row_idx * d + sampled_id];
double pivot_1 = (pivot_0 + high) / 2;
ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
ValueCount<float> probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0[j] = {
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_1[j] = {
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
}
aggregate_gt_pivot_0 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
if (tx == 0) {
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
aggregate_gt_pivot_1 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
if (tx == 0) {
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
}
__syncthreads();
aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair;
}
if (aggregate_gt_pivot_0.count < k && aggregate_gt_pivot_0.value < p) {
// case 1: pivot_0 accepted
break;
}
if (aggregate_gt_pivot_1.count < k && aggregate_gt_pivot_1.value < p) {
// case 2: pivot_0 rejected, pivot_1 accepted
low = pivot_0;
high = pivot_1;
q = aggregate_gt_pivot_0.value;
} else {
// case 3: pivot_0 rejected, pivot_1 rejected
low = pivot_1;
q = aggregate_gt_pivot_1.value;
}
} while (low < high);
__syncthreads();
if (tx == 0) {
output[bx] = sampled_id;
}
}
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
bool DETERMINISTIC, typename DType, typename IdType>
__global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
float* top_p_arr, uint32_t d,
uint64_t philox_seed, uint64_t philox_offset) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(philox_seed, bx, philox_offset, &state);
const uint32_t row_idx = bx;
float top_p = top_p_arr[row_idx];
extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
float q = 1;
double low = 0, high = 1.f;
int sampled_id;
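  // Same rejection scheme as TopKTopPSamplingFromProbKernel above, but only
  // the top-p mass constraint is enforced.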
do {
temp_storage.sampled_id = d;
__syncthreads();
float u = curand_uniform(&state) * q;
aggregate = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DETERMINISTIC>(
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
if (aggregate > u) {
break;
}
}
__syncthreads();
sampled_id = temp_storage.sampled_id;
if (sampled_id == d) {
// NOTE(Zihao): this would happen when u is very close to 1
// and the sum of probabilities is smaller than u
// In this case, we use the last valid index as the sampled id
sampled_id = temp_storage.last_valid_id;
}
double pivot_0 = probs[row_idx * d + sampled_id];
double pivot_1 = (pivot_0 + high) / 2;
float aggregate_gt_pivot_0 = 0, aggregate_gt_pivot_1 = 0;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}
float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0;
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
}
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
}
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
}
__syncthreads();
aggregate_gt_pivot_1 = temp_storage.block_aggregate.value;
}
if (aggregate_gt_pivot_0 < top_p) {
// case 1: pivot_0 accepted
break;
}
if (aggregate_gt_pivot_1 < top_p) {
// case 2: pivot_0 rejected, pivot_1 accepted
low = pivot_0;
high = pivot_1;
q = aggregate_gt_pivot_0;
} else {
// case 3: pivot_0 rejected, pivot_1 rejected
low = pivot_1;
q = aggregate_gt_pivot_1;
}
} while (low < high);
__syncthreads();
if (tx == 0) {
output[bx] = sampled_id;
}
}
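// Host-side launcher: one thread block per batch row. The vector width is
// gcd(16 / sizeof(T), d), so it always divides the row length d and never
// exceeds a 16-byte load.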
template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val,
uint32_t d, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
const uint32_t smem_size =
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &output, &top_p_val,
&d, &philox_seed, &philox_offset};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE,
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel =
TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
VEC_SIZE, DETERMINISTIC, T, IdType>;
CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
smem_size, stream));
})});
return cudaSuccess;
}
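// A minimal usage sketch (assuming d_probs, d_samples, and d_top_p are device
// buffers and each row of d_probs is already normalized):
//
//   cudaError_t st = sampling::TopPSamplingFromProb<float, int64_t>(
//       d_probs, d_samples, batch_size, d_top_p, vocab_size,
//       /*deterministic=*/true, philox_seed, /*philox_offset=*/0, stream);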
template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val,
uint32_t d, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset,
cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &output, &top_p_val,
&d, &philox_seed, &philox_offset};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
VEC_SIZE, DETERMINISTIC, T, IdType>;
CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
return cudaSuccess;
});
}
} // namespace sampling

View File

@@ -0,0 +1,269 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// This code is partially inspired by and references the implementation found
// in FlashInfer. Specifically, the Top-p Sampling functionality in this code
// is inspired by the logic of FlashInfer's
// flashinfer.sampling.top_p_sampling_from_probs function. For more details,
// see FlashInfer's documentation:
// https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_sampling_from_probs.html#flashinfer-sampling-top-p-sampling-from_probs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_device_runtime_api.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <type_traits>  // std::is_same_v, used by the cast helpers below
#include <utility>      // std::pair, used by GetCudaComputeCapability
#include <vector>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
/******************* utils *******************/
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
#ifndef NDEBUG
#define CUDA_CALL(func, ...) \
{ \
cudaError_t e = (func); \
if (e != cudaSuccess) { \
std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
<< ") " << __FILE__ << ": line " << __LINE__ \
<< " at function " << STR(func) << std::endl; \
return e; \
} \
}
#else
#define CUDA_CALL(func, ...) \
{ \
cudaError_t e = (func); \
if (e != cudaSuccess) { \
return e; \
} \
}
#endif
#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
if (deterministic) { \
constexpr bool DETERMINISTIC = true; \
__VA_ARGS__ \
} else { \
constexpr bool DETERMINISTIC = false; \
__VA_ARGS__ \
}
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
switch (aligned_vec_size) { \
case 16: { \
constexpr size_t ALIGNED_VEC_SIZE = 16; \
__VA_ARGS__ \
break; \
} \
case 8: { \
constexpr size_t ALIGNED_VEC_SIZE = 8; \
__VA_ARGS__ \
break; \
} \
case 4: { \
constexpr size_t ALIGNED_VEC_SIZE = 4; \
__VA_ARGS__ \
break; \
} \
case 2: { \
constexpr size_t ALIGNED_VEC_SIZE = 2; \
__VA_ARGS__ \
break; \
} \
case 1: { \
constexpr size_t ALIGNED_VEC_SIZE = 1; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
throw std::invalid_argument(err_msg.str()); \
} \
}
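// Example: DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { run<VEC_SIZE>(); })
// expands to a switch whose cases each declare VEC_SIZE as a compile-time
// constant before running the body, so templates can be instantiated from a
// runtime value. run<>() here is only a placeholder for the caller's body.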
/******************* vec_t<float> *******************/
#define SAMPLING_INLINE inline __attribute__((always_inline)) __device__
template <typename float_t, size_t vec_size> struct vec_t {
SAMPLING_INLINE float_t &operator[](size_t i);
SAMPLING_INLINE const float_t &operator[](size_t i) const;
SAMPLING_INLINE void fill(float_t val);
SAMPLING_INLINE void load(const float_t *ptr);
SAMPLING_INLINE void store(float_t *ptr) const;
template <typename T>
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src);
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr);
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const;
SAMPLING_INLINE static void memcpy(float_t *dst, const float_t *src);
SAMPLING_INLINE float_t *ptr();
};
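// The primary template above is declaration-only; the specializations below
// supply storage and definitions for one float, two floats (a float2), and
// four or more floats (an array of float4).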
// float x 1
template <> struct vec_t<float, 1> {
float data;
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
SAMPLING_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
}
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
SAMPLING_INLINE void fill(float val);
SAMPLING_INLINE void load(const float *ptr);
SAMPLING_INLINE void store(float *ptr) const;
template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 1> &src) {
cast_from_impl(*this, src);
}
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
cast_load_impl(*this, ptr);
}
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
cast_store_impl(ptr, *this);
}
SAMPLING_INLINE static void memcpy(float *dst, const float *src);
};
SAMPLING_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
SAMPLING_INLINE void vec_t<float, 1>::load(const float *ptr) { data = *ptr; }
SAMPLING_INLINE void vec_t<float, 1>::store(float *ptr) const { *ptr = data; }
SAMPLING_INLINE void vec_t<float, 1>::memcpy(float *dst, const float *src) {
*dst = *src;
}
// float x 2
template <> struct vec_t<float, 2> {
float2 data;
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
SAMPLING_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
}
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
SAMPLING_INLINE void fill(float val);
SAMPLING_INLINE void load(const float *ptr);
SAMPLING_INLINE void store(float *ptr) const;
template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 2> &src) {
cast_from_impl(*this, src);
}
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
cast_load_impl(*this, ptr);
}
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
cast_store_impl(ptr, *this);
}
SAMPLING_INLINE static void memcpy(float *dst, const float *src);
};
SAMPLING_INLINE void vec_t<float, 2>::fill(float val) {
data = make_float2(val, val);
}
SAMPLING_INLINE void vec_t<float, 2>::load(const float *ptr) {
data = *((float2 *)ptr);
}
SAMPLING_INLINE void vec_t<float, 2>::store(float *ptr) const {
*((float2 *)ptr) = data;
}
SAMPLING_INLINE void vec_t<float, 2>::memcpy(float *dst, const float *src) {
*((float2 *)dst) = *((float2 *)src);
}
// float x 4 or more
template <size_t vec_size> struct vec_t<float, vec_size> {
float4 data[vec_size / 4];
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; }
SAMPLING_INLINE const float &operator[](size_t i) const {
return ((const float *)(data))[i];
}
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
SAMPLING_INLINE void fill(float val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = make_float4(val, val, val, val);
}
}
SAMPLING_INLINE void load(const float *ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = ((float4 *)ptr)[i];
}
}
SAMPLING_INLINE void store(float *ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)ptr)[i] = data[i];
}
}
template <typename T>
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src) {
cast_from_impl(*this, src);
}
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
cast_load_impl(*this, ptr);
}
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
cast_store_impl(ptr, *this);
}
SAMPLING_INLINE static void memcpy(float *dst, const float *src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)dst)[i] = ((float4 *)src)[i];
}
}
};
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
SAMPLING_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
const src_float_t* src_ptr) {
if constexpr (std::is_same_v<src_float_t, tgt_float_t>) {
dst.load(src_ptr);
} else {
vec_t<src_float_t, vec_size> tmp;
tmp.load(src_ptr);
dst.cast_from(tmp);
}
}
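// The vec_t methods above also call cast_from_impl and cast_store_impl.
// Minimal sketches are given here, assuming they mirror cast_load_impl:
// element-wise conversion, with a direct store when the types already match.
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
SAMPLING_INLINE void cast_from_impl(vec_t<tgt_float_t, vec_size> &dst,
                                    const vec_t<src_float_t, vec_size> &src) {
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    dst[i] = static_cast<tgt_float_t>(src[i]);
  }
}
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
SAMPLING_INLINE void cast_store_impl(tgt_float_t *dst_ptr,
                                     const vec_t<src_float_t, vec_size> &src) {
  if constexpr (std::is_same_v<src_float_t, tgt_float_t>) {
    src.store(dst_ptr);
  } else {
    vec_t<tgt_float_t, vec_size> tmp;
    tmp.cast_from(src);
    tmp.store(dst_ptr);
  }
}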
inline std::pair<int, int> GetCudaComputeCapability() {
int device_id = 0;
cudaGetDevice(&device_id);
int major = 0, minor = 0;
cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id);
cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id);
return std::make_pair(major, minor);
}
/******************* math *******************/
__forceinline__ __device__ float ptx_rcp(float x) {
float y;
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
}
template <typename T1, typename T2>
__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {
return (x + y - 1) / y;
}