[Metax] refactor cutlass moe and optimize flash attention (#5361)

* [Metax] refactor moe and flash attention backend
---------

Co-authored-by: zhangchenyi_dl <16219492+zhangchenyidl@user.noreply.gitee.com>
This commit is contained in:
Neil Zhu
2025-12-10 17:15:17 +08:00
committed by GitHub
parent fbc9bce1e9
commit 4403a21d4b
19 changed files with 3087 additions and 1727 deletions

View File

@@ -48,14 +48,15 @@ constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED
#endif
template <typename T> struct Pair {
template <typename T>
struct Pair {
T value;
int count;
__device__ Pair operator+(const Pair &other) const {
__device__ Pair operator+(const Pair& other) const {
return {value + other.value, count + other.count};
}
__device__ Pair &operator+=(const Pair &other) {
__device__ Pair& operator+=(const Pair& other) {
value += other.value;
count += other.count;
return *this;
@@ -78,22 +79,25 @@ struct ValueCount {
};
struct BoolDiffOp {
__device__ __forceinline__ bool operator()(const bool &lhs,
const bool &rhs) const {
__device__ __forceinline__ bool operator()(const bool& lhs,
const bool& rhs) const {
return lhs != rhs;
}
};
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
template <uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM>
struct SamplingTempStorage {
union {
float deterministic_scan[BLOCK_THREADS / 32];
typename BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce_value_count;
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce;
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce_int;
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::
TempStorage reduce_value_count;
typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
} block_prim;
struct {
@@ -112,14 +116,17 @@ struct SamplingTempStorage {
* algorithm. \note This implementation is slower than the cub::BlockScan, but
* it is deterministic.
*/
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
template <uint32_t VEC_SIZE,
uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
__device__ __forceinline__ void
DeterministicInclusiveSum(const T *in_data, T *out_data,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM,
REDUCE_ALGORITHM> *temp_storage) {
T *smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
BlockReduceAlgorithm REDUCE_ALGORITHM,
typename T>
__device__ __forceinline__ void DeterministicInclusiveSum(
const T* in_data,
T* out_data,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
temp_storage) {
T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
T thread_data[VEC_SIZE];
T thread_sum = 0;
#pragma unroll
@@ -138,8 +145,8 @@ DeterministicInclusiveSum(const T *in_data, T *out_data,
}
}
T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum,
threadIdx.x | 0xffffffff);
T warp_sum = __shfl_sync(
0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff);
if (threadIdx.x % 32 == 31) {
thread_exclusive_prefix_sum = 0;
}
@@ -197,12 +204,21 @@ DeterministicInclusiveSum(const T *in_data, T *out_data,
}
}
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
template <uint32_t VEC_SIZE,
uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM,
bool DETERMINISTIC,
typename Predicate>
__device__ __forceinline__ void DeviceSamplingFromProb(
uint32_t i, uint32_t d, Predicate pred, float u, vec_t<float, VEC_SIZE> prob_vec,
uint32_t i,
uint32_t d,
Predicate pred,
float u,
vec_t<float, VEC_SIZE> prob_vec,
float& aggregate,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>* temp_storage) {
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
temp_storage) {
const uint32_t tx = threadIdx.x;
float prob_greater_than_threshold[VEC_SIZE];
float inclusive_cdf[VEC_SIZE];
@@ -212,14 +228,14 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
}
#ifdef PADDLE_WITH_COREX
float aggregate_local =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
.Sum(prob_greater_than_threshold);
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
float aggregate_local = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage->block_prim.reduce)
.Sum(prob_greater_than_threshold);
#else
float aggregate_local =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
.Sum<VEC_SIZE>(prob_greater_than_threshold);
float aggregate_local = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage->block_prim.reduce)
.Sum<VEC_SIZE>(prob_greater_than_threshold);
#endif
if (tx == 0) {
temp_storage->block_aggregate.value = aggregate_local;
@@ -229,14 +245,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
if (aggregate + aggregate_local > u) {
if constexpr (DETERMINISTIC) {
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
DeterministicInclusiveSum<VEC_SIZE,
BLOCK_THREADS,
SCAN_ALGORITHM,
REDUCE_ALGORITHM>(
prob_greater_than_threshold, inclusive_cdf, temp_storage);
} else {
#ifdef PADDLE_WITH_COREX
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(
temp_storage->block_prim.scan)
.InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
#else
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(
temp_storage->block_prim.scan)
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
#endif
@@ -250,28 +271,35 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
bool greater_than_u_diff[VEC_SIZE];
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
#ifdef PADDLE_WITH_COREX
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
#endif
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
BlockAdjacentDifference<bool, BLOCK_THREADS>(
temp_storage->block_prim.adj_diff)
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
#else
#ifdef PADDLE_WITH_COREX
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#endif
BlockAdjacentDifference<bool, BLOCK_THREADS>(
temp_storage->block_prim.adj_diff)
.SubtractLeft<VEC_SIZE>(
greater_than_u, greater_than_u_diff, BoolDiffOp());
#endif
#else
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
BlockAdjacentDifference<bool, BLOCK_THREADS>(
temp_storage->block_prim.adj_diff)
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#else
BlockAdjacentDifference<bool, BLOCK_THREADS>(
temp_storage->block_prim.adj_diff)
.FlagHeads<VEC_SIZE>(
greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
#endif
#endif
__syncthreads();
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
if (greater_than_u_diff[j]) {
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
atomicMin(&(temp_storage->sampled_id),
(i * BLOCK_THREADS + tx) * VEC_SIZE + j);
}
}
__syncthreads();
@@ -287,9 +315,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
valid_index[j] = -1;
}
}
int max_valid_index =
BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce_int)
.Reduce(valid_index, cub::Max());
int max_valid_index = BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage->block_prim.reduce_int)
.Reduce(valid_index, cub::Max());
if (tx == 0 && max_valid_index != -1) {
temp_storage->last_valid_id = max_valid_index;
}
@@ -297,15 +325,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
aggregate += aggregate_local;
}
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
float* top_p_arr, IdType* top_k_arr,
uint32_t d, uint64_t philox_seed,
template <uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE,
bool DETERMINISTIC,
typename DType,
typename IdType>
__global__ void TopKTopPSamplingFromProbKernel(DType* probs,
IdType* output,
float* top_p_arr,
IdType* top_k_arr,
uint32_t d,
uint64_t philox_seed,
uint64_t philox_offset) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
@@ -315,12 +347,12 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
const float p = top_p_arr[row_idx];
extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
extern __shared__ __align__(alignof(
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
@@ -336,12 +368,22 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
probs_vec.cast_load(probs + row_idx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DeviceSamplingFromProb<VEC_SIZE,
BLOCK_THREADS,
SCAN_ALGORITHM,
REDUCE_ALGORITHM,
DETERMINISTIC>(
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
i,
d,
[&](float x) { return x > low; },
u,
probs_vec,
aggregate,
&temp_storage);
if (aggregate > u) {
break;
}
@@ -362,28 +404,29 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
probs_vec.cast_load(probs + row_idx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
}
ValueCount<float> probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0[j] = {
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_1[j] = {
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_0[j] = {(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_0 &&
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_1[j] = {(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_1 &&
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
}
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_0 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_0);
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_0);
#else
aggregate_gt_pivot_0 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
#endif
if (tx == 0) {
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
@@ -391,14 +434,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_1 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_1);
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_1);
#else
aggregate_gt_pivot_1 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
#endif
if (tx == 0) {
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
@@ -427,14 +470,19 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
}
}
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
bool DETERMINISTIC, typename DType, typename IdType>
__global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
float* top_p_arr, uint32_t d,
uint64_t philox_seed, uint64_t philox_offset) {
template <uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE,
bool DETERMINISTIC,
typename DType,
typename IdType>
__global__ void TopPSamplingFromProbKernel(DType* probs,
IdType* output,
float* top_p_arr,
uint32_t d,
uint64_t philox_seed,
uint64_t philox_offset) {
const uint32_t batch_size = gridDim.x;
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
curandStatePhilox4_32_10_t state;
@@ -442,12 +490,12 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
const uint32_t row_idx = bx;
float top_p = top_p_arr[row_idx];
extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
extern __shared__ __align__(alignof(
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
vec_t<float, VEC_SIZE> probs_vec;
float aggregate;
@@ -463,12 +511,22 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
probs_vec.cast_load(probs + row_idx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
DeviceSamplingFromProb<VEC_SIZE,
BLOCK_THREADS,
SCAN_ALGORITHM,
REDUCE_ALGORITHM,
DETERMINISTIC>(
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
i,
d,
[&](float x) { return x > low; },
u,
probs_vec,
aggregate,
&temp_storage);
if (aggregate > u) {
break;
}
@@ -489,7 +547,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
probs_vec.cast_load(probs + row_idx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
}
float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
@@ -499,12 +558,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
}
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum(probs_gt_pivot_0);
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
aggregate_gt_pivot_0 +=
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum(probs_gt_pivot_0);
#else
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
aggregate_gt_pivot_0 +=
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_0);
#endif
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
@@ -512,12 +573,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
__syncthreads();
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum(probs_gt_pivot_1);
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
aggregate_gt_pivot_1 +=
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum(probs_gt_pivot_1);
#else
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
aggregate_gt_pivot_1 +=
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(probs_gt_pivot_1);
#endif
if (tx == 0) {
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
@@ -546,9 +609,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
}
}
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
template <uint32_t VEC_SIZE,
uint32_t BLOCK_THREADS,
BlockReduceAlgorithm REDUCE_ALGORITHM,
typename TempStorage>
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
__device__ __forceinline__ float GetMaxValue(float* in_data,
uint32_t row_idx,
uint32_t d,
TempStorage& temp_storage) {
const uint32_t tx = threadIdx.x;
vec_t<float, VEC_SIZE> in_data_vec;
@@ -557,21 +624,24 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
in_data_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
in_data_vec.cast_load(in_data + row_idx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
}
float in_data_[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
in_data_[j] = in_data_vec[j];
}
#ifdef PADDLE_WITH_COREX
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(in_data_, cub::Max()));
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
max_val = max(max_val,
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce)
.Reduce(in_data_, cub::Max()));
#else
max_val = max(
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
max_val = max(max_val,
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce)
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
#endif
__syncthreads();
}
@@ -585,10 +655,12 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
struct RenormTempStorage {
union {
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce_value_count;
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce;
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
reduce_int;
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::
TempStorage reduce_value_count;
} block_prim;
struct {
float max_val;
@@ -607,24 +679,33 @@ struct RenormTempStorage {
};
};
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType,typename IdType>
__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
DType* renormed_prob,uint32_t d) {
template <uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE,
bool DETERMINISTIC,
typename DType,
typename IdType>
__global__ void MinPSamplingFromProbKernel(DType* probs,
const float* min_p_arr,
DType* renormed_prob,
uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
const uint32_t row_idx = bx;
extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
extern __shared__ __align__(alignof(
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
auto& temp_storage = reinterpret_cast<
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
float max_val = GetMaxValue<
VEC_SIZE,
BLOCK_THREADS,
REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);
float pivot = max_val * p;
@@ -633,7 +714,8 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
probs_vec.cast_load(probs + row_idx * d +
(i * BLOCK_THREADS + tx) * VEC_SIZE);
}
#pragma unroll
@@ -641,42 +723,51 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
}
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
probs_vec.store(renormed_prob + row_idx * d +
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
}
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
template <uint32_t BLOCK_THREADS,
BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE,
typename DType,
typename IdType>
__global__ void TopKRenormProbKernel(DType* probs,
DType* renormed_prob,
IdType* top_k_arr,
uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
const uint32_t row_idx = bx;
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
#ifdef PADDLE_WITH_COREX
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
double pivot = std::numeric_limits<float>::infinity(), normalizer = 1;
#else
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
#endif
vec_t<float, VEC_SIZE> probs_vec;
if (k < d) {
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
uint8_t smem_renorm[];
extern __shared__ __align__(alignof(
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>)) uint8_t smem_renorm[];
auto& temp_storage =
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(
smem_renorm);
temp_storage.max_val = 0;
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);
float max_val =
GetMaxValue<VEC_SIZE,
BLOCK_THREADS,
REDUCE_ALGORITHM,
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);
double low = 0, high = max_val;
float min_gt_low, max_le_high;
float sum_low = 1;
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
// loop invariant:
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs |
// p <= high} loop invariant:
// - f(low) >= k, f(high) < k
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
// stopping condition: min_gt_low == max_le_high
@@ -692,55 +783,65 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
probs_vec.cast_load(probs + row_idx * d +
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE],
probs_gt_pivot_1_pair[VEC_SIZE];
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_gt_pivot_0_pair[j] = {
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
(probs_vec[j] > pivot_0 &&
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
probs_gt_pivot_1_pair[j] = {
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
(probs_vec[j] > pivot_1 &&
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
if (probs_vec[j] > low &&
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
min_gt_low = min(min_gt_low, probs_vec[j]);
}
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
if (probs_vec[j] <= high &&
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
max_le_high = max(max_le_high, probs_vec[j]);
}
}
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_0_pair);
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
aggregate_gt_pivot_0 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_0_pair);
#else
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
aggregate_gt_pivot_0 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
#endif
__syncthreads();
#ifdef PADDLE_WITH_COREX
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_1_pair);
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
aggregate_gt_pivot_1 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum(probs_gt_pivot_1_pair);
#else
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
aggregate_gt_pivot_1 +=
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce_value_count)
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
#endif
__syncthreads();
}
min_gt_low =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(min_gt_low, cub::Min());
min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce)
.Reduce(min_gt_low, cub::Min());
__syncthreads();
max_le_high =
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Reduce(max_le_high, cub::Max());
max_le_high = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
temp_storage.block_prim.reduce)
.Reduce(max_le_high, cub::Max());
if (tx == 0) {
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
@@ -774,23 +875,29 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
tx * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
}
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
probs_vec.store(renormed_prob + row_idx * d +
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
}
template <typename T, typename IdType>
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val,
uint32_t d, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset,
cudaError_t TopPSamplingFromProb(T* probs,
IdType* output,
uint32_t batch_size,
const T* top_p_val,
uint32_t d,
bool deterministic,
uint64_t philox_seed,
uint64_t philox_offset,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
@@ -799,99 +906,139 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &output, &top_p_val,
&d, &philox_seed, &philox_offset};
void* args[] = {
&probs, &output, &top_p_val, &d, &philox_seed, &philox_offset};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE,
vec_size,
VEC_SIZE,
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel =
TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
VEC_SIZE, DETERMINISTIC, T, IdType>;
auto kernel = TopPSamplingFromProbKernel<BLOCK_THREADS,
SCAN_ALGO,
REDUCE_ALGO,
VEC_SIZE,
DETERMINISTIC,
T,
IdType>;
CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
smem_size, stream));
CUDA_CALL(cudaLaunchKernel(
(void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
return cudaSuccess;
}
template <typename T,typename IdType>
cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,T *renormed_prob,
template <typename T, typename IdType>
cudaError_t MinPSamplingFromProb(T* probs,
const T* min_p_arr,
T* renormed_prob,
uint32_t batch_size,
uint32_t d, bool deterministic,
cudaStream_t stream = 0){
uint32_t d,
bool deterministic,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
const uint32_t smem_size =
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &min_p_arr,&renormed_prob,&d};
void* args[] = {&probs, &min_p_arr, &renormed_prob, &d};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE,
vec_size,
VEC_SIZE,
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel =
MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
VEC_SIZE, DETERMINISTIC, T,IdType>;
auto kernel = MinPSamplingFromProbKernel<BLOCK_THREADS,
SCAN_ALGO,
REDUCE_ALGO,
VEC_SIZE,
DETERMINISTIC,
T,
IdType>;
CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
smem_size, stream));
CUDA_CALL(cudaLaunchKernel(
(void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
return cudaSuccess;
}
template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
uint32_t d, bool deterministic,
uint64_t philox_seed, uint64_t philox_offset,
cudaError_t TopKTopPSamplingFromProb(T* probs,
IdType* output,
uint32_t batch_size,
const T* top_p_val,
const IdType* top_k_val,
uint32_t d,
bool deterministic,
uint64_t philox_seed,
uint64_t philox_offset,
cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
const uint32_t smem_size =
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &output, &top_p_val, &top_k_val,
&d, &philox_seed, &philox_offset};
void* args[] = {&probs,
&output,
&top_p_val,
&top_k_val,
&d,
&philox_seed,
&philox_offset};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
VEC_SIZE, DETERMINISTIC, T, IdType>;
CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
vec_size,
VEC_SIZE,
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS,
SCAN_ALGO,
REDUCE_ALGO,
VEC_SIZE,
DETERMINISTIC,
T,
IdType>;
CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel(
(void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
return cudaSuccess;
});
}
template <typename DType, typename IdType>
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
uint32_t batch_size, uint32_t d,
cudaError_t TopKRenormProb(DType* probs,
DType* renormed_prob,
IdType* top_k_arr,
uint32_t batch_size,
uint32_t d,
cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
const uint32_t smem_size =
sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
auto kernel = TopKRenormProbKernel<BLOCK_THREADS,
REDUCE_ALGO,
VEC_SIZE,
DType,
IdType>;
CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel(
(void*)kernel, nblks, nthrs, args, smem_size, stream));
});
return cudaSuccess;
});
}
} // namespace sampling
} // namespace sampling

View File

@@ -23,221 +23,235 @@
#include <cuda_device_runtime_api.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
/******************* utils *******************/
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
#ifndef NDEBUG
#define CUDA_CALL(func, ...) \
{ \
cudaError_t e = (func); \
if (e != cudaSuccess) { \
std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
<< ") " << __FILE__ << ": line " << __LINE__ \
<< " at function " << STR(func) << std::endl; \
return e; \
} \
#define CUDA_CALL(func, ...) \
{ \
cudaError_t e = (func); \
if (e != cudaSuccess) { \
std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
<< ") " << __FILE__ << ": line " << __LINE__ \
<< " at function " << STR(func) << std::endl; \
return e; \
} \
}
#else
#define CUDA_CALL(func, ...) \
{ \
cudaError_t e = (func); \
if (e != cudaSuccess) { \
return e; \
} \
#define CUDA_CALL(func, ...) \
{ \
cudaError_t e = (func); \
if (e != cudaSuccess) { \
return e; \
} \
}
#endif
#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
if (deterministic) { \
constexpr bool DETERMINISTIC = true; \
__VA_ARGS__ \
} else { \
constexpr bool DETERMINISTIC = false; \
__VA_ARGS__ \
#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
if (deterministic) { \
constexpr bool DETERMINISTIC = true; \
__VA_ARGS__ \
} else { \
constexpr bool DETERMINISTIC = false; \
__VA_ARGS__ \
}
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
switch (aligned_vec_size) { \
case 16: { \
constexpr size_t ALIGNED_VEC_SIZE = 16; \
__VA_ARGS__ \
break; \
} \
case 8: { \
constexpr size_t ALIGNED_VEC_SIZE = 8; \
__VA_ARGS__ \
break; \
} \
case 4: { \
constexpr size_t ALIGNED_VEC_SIZE = 4; \
__VA_ARGS__ \
break; \
} \
case 2: { \
constexpr size_t ALIGNED_VEC_SIZE = 2; \
__VA_ARGS__ \
break; \
} \
case 1: { \
constexpr size_t ALIGNED_VEC_SIZE = 1; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
throw std::invalid_argument(err_msg.str()); \
} \
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
switch (aligned_vec_size) { \
case 16: { \
constexpr size_t ALIGNED_VEC_SIZE = 16; \
__VA_ARGS__ \
break; \
} \
case 8: { \
constexpr size_t ALIGNED_VEC_SIZE = 8; \
__VA_ARGS__ \
break; \
} \
case 4: { \
constexpr size_t ALIGNED_VEC_SIZE = 4; \
__VA_ARGS__ \
break; \
} \
case 2: { \
constexpr size_t ALIGNED_VEC_SIZE = 2; \
__VA_ARGS__ \
break; \
} \
case 1: { \
constexpr size_t ALIGNED_VEC_SIZE = 1; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
throw std::invalid_argument(err_msg.str()); \
} \
}
/******************* vec_t<float> *******************/
#define SAMPLING_INLINE inline __attribute__((always_inline)) __device__
template <typename float_t, size_t vec_size> struct vec_t {
SAMPLING_INLINE float_t &operator[](size_t i);
SAMPLING_INLINE const float_t &operator[](size_t i) const;
template <typename float_t, size_t vec_size>
struct vec_t {
SAMPLING_INLINE float_t& operator[](size_t i);
SAMPLING_INLINE const float_t& operator[](size_t i) const;
SAMPLING_INLINE void fill(float_t val);
SAMPLING_INLINE void load(const float_t *ptr);
SAMPLING_INLINE void store(float_t *ptr) const;
SAMPLING_INLINE void load(const float_t* ptr);
SAMPLING_INLINE void store(float_t* ptr) const;
template <typename T>
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src);
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr);
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const;
SAMPLING_INLINE static void memcpy(float_t *dst, const float_t *src);
SAMPLING_INLINE float_t *ptr();
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size>& src);
template <typename T>
SAMPLING_INLINE void cast_load(const T* ptr);
template <typename T>
SAMPLING_INLINE void cast_store(T* ptr) const;
SAMPLING_INLINE static void memcpy(float_t* dst, const float_t* src);
SAMPLING_INLINE float_t* ptr();
};
// float x 1
template <> struct vec_t<float, 1> {
template <>
struct vec_t<float, 1> {
float data;
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
SAMPLING_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
SAMPLING_INLINE const float& operator[](size_t i) const {
return ((const float*)(&data))[i];
}
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
SAMPLING_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
SAMPLING_INLINE void fill(float val);
SAMPLING_INLINE void load(const float *ptr);
SAMPLING_INLINE void store(float *ptr) const;
template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 1> &src) {
SAMPLING_INLINE void load(const float* ptr);
SAMPLING_INLINE void store(float* ptr) const;
template <typename T>
SAMPLING_INLINE void cast_from(const vec_t<T, 1>& src) {
cast_from_impl(*this, src);
}
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
template <typename T>
SAMPLING_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
template <typename T>
SAMPLING_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
SAMPLING_INLINE static void memcpy(float *dst, const float *src);
SAMPLING_INLINE static void memcpy(float* dst, const float* src);
};
SAMPLING_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
SAMPLING_INLINE void vec_t<float, 1>::load(const float *ptr) { data = *ptr; }
SAMPLING_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }
SAMPLING_INLINE void vec_t<float, 1>::store(float *ptr) const { *ptr = data; }
SAMPLING_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }
SAMPLING_INLINE void vec_t<float, 1>::memcpy(float *dst, const float *src) {
SAMPLING_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) {
*dst = *src;
}
// float x 2
template <> struct vec_t<float, 2> {
template <>
struct vec_t<float, 2> {
float2 data;
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
SAMPLING_INLINE const float &operator[](size_t i) const {
return ((const float *)(&data))[i];
SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
SAMPLING_INLINE const float& operator[](size_t i) const {
return ((const float*)(&data))[i];
}
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
SAMPLING_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
SAMPLING_INLINE void fill(float val);
SAMPLING_INLINE void load(const float *ptr);
SAMPLING_INLINE void store(float *ptr) const;
template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 2> &src) {
SAMPLING_INLINE void load(const float* ptr);
SAMPLING_INLINE void store(float* ptr) const;
template <typename T>
SAMPLING_INLINE void cast_from(const vec_t<T, 2>& src) {
cast_from_impl(*this, src);
}
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
template <typename T>
SAMPLING_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
template <typename T>
SAMPLING_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
SAMPLING_INLINE static void memcpy(float *dst, const float *src);
SAMPLING_INLINE static void memcpy(float* dst, const float* src);
};
SAMPLING_INLINE void vec_t<float, 2>::fill(float val) {
data = make_float2(val, val);
}
SAMPLING_INLINE void vec_t<float, 2>::load(const float *ptr) {
data = *((float2 *)ptr);
SAMPLING_INLINE void vec_t<float, 2>::load(const float* ptr) {
data = *((float2*)ptr);
}
SAMPLING_INLINE void vec_t<float, 2>::store(float *ptr) const {
*((float2 *)ptr) = data;
SAMPLING_INLINE void vec_t<float, 2>::store(float* ptr) const {
*((float2*)ptr) = data;
}
SAMPLING_INLINE void vec_t<float, 2>::memcpy(float *dst, const float *src) {
*((float2 *)dst) = *((float2 *)src);
SAMPLING_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
*((float2*)dst) = *((float2*)src);
}
// float x 4 or more
template <size_t vec_size> struct vec_t<float, vec_size> {
template <size_t vec_size>
struct vec_t<float, vec_size> {
float4 data[vec_size / 4];
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; }
SAMPLING_INLINE const float &operator[](size_t i) const {
return ((const float *)(data))[i];
SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
SAMPLING_INLINE const float& operator[](size_t i) const {
return ((const float*)(data))[i];
}
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
SAMPLING_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
SAMPLING_INLINE void fill(float val) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = make_float4(val, val, val, val);
}
}
SAMPLING_INLINE void load(const float *ptr) {
SAMPLING_INLINE void load(const float* ptr) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
data[i] = ((float4 *)ptr)[i];
data[i] = ((float4*)ptr)[i];
}
}
SAMPLING_INLINE void store(float *ptr) const {
SAMPLING_INLINE void store(float* ptr) const {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)ptr)[i] = data[i];
((float4*)ptr)[i] = data[i];
}
}
template <typename T>
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src) {
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size>& src) {
cast_from_impl(*this, src);
}
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
template <typename T>
SAMPLING_INLINE void cast_load(const T* ptr) {
cast_load_impl(*this, ptr);
}
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
template <typename T>
SAMPLING_INLINE void cast_store(T* ptr) const {
cast_store_impl(ptr, *this);
}
SAMPLING_INLINE static void memcpy(float *dst, const float *src) {
SAMPLING_INLINE static void memcpy(float* dst, const float* src) {
#pragma unroll
for (size_t i = 0; i < vec_size / 4; ++i) {
((float4 *)dst)[i] = ((float4 *)src)[i];
((float4*)dst)[i] = ((float4*)src)[i];
}
}
};
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
SAMPLING_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
const src_float_t* src_ptr) {
const src_float_t* src_ptr) {
if constexpr (std::is_same_v<src_float_t, tgt_float_t>) {
dst.load(src_ptr);
} else {
@@ -260,11 +274,16 @@ inline std::pair<int, int> GetCudaComputeCapability() {
__forceinline__ __device__ float ptx_rcp(float x) {
#ifdef PADDLE_WITH_COREX
return __ivcorex_rcpf(x);
#else
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
return __frcp_rn(x);
#else
float y;
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
return y;
#endif
#endif
}
template <typename T1, typename T2>

View File

@@ -1,291 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/extension.h>
#include <algorithm>
#include "helper.h"
#define THREADS_PER_BLOCK 128
template <typename T>
struct Converter;
// Scalar conversion helpers between __half and host arithmetic types.
// All float/int -> __half conversions round to nearest even.
template <>
struct Converter<__half> {
  // __half -> float
  __device__ static float to_float(__half val) { return __half2float(val); }
  // float -> __half
  __device__ static __half from_float(float val) {
    return __float2half_rn(val);
  }
  // int -> __half. Fix: the parameter was declared `float` although the
  // name and the __int2half_rn intrinsic both take an int, which forced a
  // silent float->int conversion (the bfloat16 specialization already
  // uses `int`).
  __device__ static __half from_int(int val) { return __int2half_rn(val); }
};
// Scalar conversion helpers between __nv_bfloat16 and host arithmetic types.
// float/int -> bf16 conversions round to nearest even (_rn intrinsics).
template <>
struct Converter<__nv_bfloat16> {
  // __nv_bfloat16 -> float
  __device__ static float to_float(__nv_bfloat16 val) {
    return __bfloat162float(val);
  }
  // float -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_float(float val) {
    return __float2bfloat16_rn(val);
  }
  // int -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_int(int val) {
    return __int2bfloat16_rn(val);
  }
};
// Applies rotary position embedding (RoPE) to 4 consecutive elements of
// q or k; overload for cos/sin tables stored in the same dtype T.
// Rotation pairs adjacent elements inside the vec4:
//   out[2i]   = x[2i]   * cos[2i]   - x[2i+1] * sin[2i]
//   out[2i+1] = x[2i+1] * cos[2i+1] + x[2i]   * sin[2i+1]
// NOTE(review): head_num is unused here; kept for a uniform call signature.
// Loads a full AlignedVector<T, 4>, so base_idx must have 4 valid elements
// remaining in the head — presumably head_dim % 4 == 0; confirm at caller.
template <typename T>
__device__ void RotateQKVec4(const T* qk_ptr,
                             const T* rot_cos_ptr,
                             const T* rot_sin_ptr,
                             const int head_num,
                             const int base_idx,
                             const int rot_base_idx,
                             T* out) {
  using VecT = AlignedVector<T, 4>;
  VecT qk_vec;
  Load(qk_ptr + base_idx, &qk_vec);
  // (-x1, x0, -x3, x2): the "rotate half" companion of qk_vec.
  VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]};
  VecT cos_vec, sin_vec;
  Load(rot_cos_ptr + rot_base_idx, &cos_vec);
  Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    *(out + base_idx + i) =
        qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
  }
}
// RoPE overload for float32 cos/sin tables with half/bf16 q/k data.
// Widens each element to float via Converter<T>, rotates in float for
// accuracy, then narrows the result back to T.
// NOTE(review): head_num is unused; kept for signature symmetry with the
// same-dtype overload.
template <typename T>
__device__ void RotateQKVec4(const T* qk_ptr,
                             const float* rot_cos_ptr,
                             const float* rot_sin_ptr,
                             const int head_num,
                             const int base_idx,
                             const int rot_base_idx,
                             T* out) {
  using VecT = AlignedVector<T, 4>;
  using VecF = AlignedVector<float, 4>;
  auto to_float = [] __device__(T val) -> float {
    return Converter<T>::to_float(val);
  };
  auto from_float = [] __device__(float val) -> T {
    return Converter<T>::from_float(val);
  };
  VecT qk_vec;
  Load(qk_ptr + base_idx, &qk_vec);
  // (-x1, x0, -x3, x2) in float: the "rotate half" companion of qk_vec.
  VecF rot_half_vec = {-to_float(qk_vec[1]),
                       to_float(qk_vec[0]),
                       -to_float(qk_vec[3]),
                       to_float(qk_vec[2])};
  VecF cos_vec, sin_vec;
  Load(rot_cos_ptr + rot_base_idx, &cos_vec);
  Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    // Accumulate in float, convert back to T on store.
    *(out + base_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] +
                                       rot_half_vec[i] * sin_vec[i]);
  }
}
// qk and rope have a same type
// Applies RoPE to q and k in one launch; overload for cos/sin tables in the
// same dtype T as q/k.
// Launch: 1-D grid; each thread owns 4 consecutive elements. idx is a flat
// element offset valid for both q and k (guarded separately, since q and k
// may have different head counts and hence different element counts).
// rot_idx maps the flat element to its (token, head_dim) entry: the cos/sin
// tables hold head_dim values per token, shared across heads.
template <typename T>
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
                                            const T* k,
                                            const T* rot_cos,
                                            const T* rot_sin,
                                            const int q_num_elements,
                                            const int k_num_elements,
                                            const int q_head_num,
                                            const int k_head_num,
                                            const int head_dim,
                                            T* q_out,
                                            T* k_out) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  int head_dim_idx = idx % head_dim;
  if (idx < q_num_elements) {
    // idx / (q_head_num * head_dim) == token index.
    int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
    RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
  }
  if (idx < k_num_elements) {
    int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
    RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
  }
}
// rope dtype is float32
// Same as the T-dtype overload above, but for float32 cos/sin tables; the
// float overload of RotateQKVec4 is selected, which rotates in float and
// narrows back to T.
template <typename T>
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
                                            const T* k,
                                            const float* rot_cos,
                                            const float* rot_sin,
                                            const int q_num_elements,
                                            const int k_num_elements,
                                            const int q_head_num,
                                            const int k_head_num,
                                            const int head_dim,
                                            T* q_out,
                                            T* k_out) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  int head_dim_idx = idx % head_dim;
  if (idx < q_num_elements) {
    // idx / (q_head_num * head_dim) == token index.
    int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
    RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
  }
  if (idx < k_num_elements) {
    int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
    RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
  }
}
// Host-side launcher for the separate-q/k RoPE kernel, specialized on the
// paddle dtype D of q/k.
// Derives head counts from the second-to-last dim and head_dim from the
// last dim; sizes the 1-D grid so every vec4 chunk of the larger of q/k is
// covered (each thread processes 4 elements).
// Dispatches on the cos/sin dtype: same-as-qk vs float32; anything else
// throws.
template <paddle::DataType D>
void ApplyRopeKernel(const paddle::Tensor& q,
                     const paddle::Tensor& k,
                     const paddle::Tensor& rot_cos,
                     const paddle::Tensor& rot_sin,
                     paddle::Tensor& q_out,
                     paddle::Tensor& k_out) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  const auto q_num_elements = q.numel();
  const auto k_num_elements = k.numel();
  const auto q_shape = q.shape();
  const auto k_shape = k.shape();
  const auto dims = q_shape.size();
  const auto q_head_num = q_shape[dims - 2];
  const auto k_head_num = k_shape[dims - 2];
  const auto head_dim = q_shape.back();
  // Ceil-div by THREADS_PER_BLOCK * 4: each thread handles a vec4 chunk.
  int block_num =
      (std::max(q_num_elements, k_num_elements) + (THREADS_PER_BLOCK * 4) - 1) /
      (THREADS_PER_BLOCK * 4);
  auto stream = q.stream();
  if (q.dtype() == rot_cos.dtype()) {
    // cos/sin share the qk dtype: T-typed rotation path.
    DispatchApplyRopeVec4Kernel<DataType_>
        <<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
            reinterpret_cast<const DataType_*>(q.data<data_t>()),
            reinterpret_cast<const DataType_*>(k.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()),
            q_num_elements,
            k_num_elements,
            q_head_num,
            k_head_num,
            head_dim,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()));
  } else if (rot_cos.dtype() == paddle::DataType::FLOAT32) {
    // float32 cos/sin: rotation computed in float, narrowed back to T.
    DispatchApplyRopeVec4Kernel<DataType_>
        <<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
            reinterpret_cast<const DataType_*>(q.data<data_t>()),
            reinterpret_cast<const DataType_*>(k.data<data_t>()),
            reinterpret_cast<const float*>(rot_cos.data<float>()),
            reinterpret_cast<const float*>(rot_sin.data<float>()),
            q_num_elements,
            k_num_elements,
            q_head_num,
            k_head_num,
            head_dim,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()));
  } else {
    PD_THROW("Unsupported qk dtype and rope dtype.");
  }
}
// Custom-op entry point: applies RoPE to q and k (using cos/sin tables) and
// returns {q_out, k_out}. Supports BF16/FP16 q/k; empty inputs short-circuit
// with uninitialized outputs of matching shape.
// NOTE(review): only head_dim % 2 == 0 is enforced, but the kernel processes
// vec4 chunks — confirm head_dim % 4 == 0 is guaranteed by callers.
std::vector<paddle::Tensor> ApplyRope(const paddle::Tensor& q,
                                      const paddle::Tensor& k,
                                      const paddle::Tensor& rot_cos,
                                      const paddle::Tensor& rot_sin) {
  auto q_shape = q.shape();
  auto cos_shape = rot_cos.shape();
  auto q_out = paddle::empty_like(q);
  auto k_out = paddle::empty_like(k);
  if (q.numel() == 0 || k.numel() == 0) {
    return {q_out, k_out};
  }
  PADDLE_ENFORCE_EQ(
      q_shape.back() % 2,
      0,
      "The last dimension (head_dim) of qk must be an even number "
      "for RoPE, but got %d",
      q_shape.back());
  PADDLE_ENFORCE_EQ(q_shape.size(),
                    cos_shape.size(),
                    "The shape size of cos mismatches the shape size of q, "
                    "expect %d but got %d",
                    q_shape.size(),
                    cos_shape.size());
  PADDLE_ENFORCE_EQ(q_shape.back(),
                    cos_shape.back(),
                    "The shape.back() of cos mismatches the shape.back() of q, "
                    "expect %d but got %d",
                    q_shape.back(),
                    cos_shape.back());
  auto input_type = q.dtype();
  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      ApplyRopeKernel<paddle::DataType::BFLOAT16>(
          q, k, rot_cos, rot_sin, q_out, k_out);
      break;
    case paddle::DataType::FLOAT16:
      ApplyRopeKernel<paddle::DataType::FLOAT16>(
          q, k, rot_cos, rot_sin, q_out, k_out);
      break;
    default:
      PD_THROW("Only support qk dtype of BF16 and F16");
  }
  return {q_out, k_out};
}
// Pass-through shape inference for apply_rope: the op declares its outputs
// with the input shapes unchanged.
std::vector<std::vector<int64_t>> ApplyRopeInferShape(
    const std::vector<int64_t>& q_shape,
    const std::vector<int64_t>& k_shape,
    const std::vector<int64_t>& cos_shape,
    const std::vector<int64_t>& sin_shape) {
  std::vector<std::vector<int64_t>> shapes;
  shapes.reserve(4);
  shapes.push_back(q_shape);
  shapes.push_back(k_shape);
  shapes.push_back(cos_shape);
  shapes.push_back(sin_shape);
  return shapes;
}
// Pass-through dtype inference for apply_rope: each output inherits the
// dtype of the corresponding input.
std::vector<paddle::DataType> ApplyRopeInferDtype(
    const paddle::DataType& q_dtype,
    const paddle::DataType& k_dtype,
    const paddle::DataType& cos_dtype,
    const paddle::DataType& sin_dtype) {
  std::vector<paddle::DataType> dtypes;
  dtypes.reserve(4);
  dtypes.push_back(q_dtype);
  dtypes.push_back(k_dtype);
  dtypes.push_back(cos_dtype);
  dtypes.push_back(sin_dtype);
  return dtypes;
}
// Registers the apply_rope custom op with Paddle: inputs q/k plus cos/sin
// tables, outputs the rotated q_out/k_out.
PD_BUILD_OP(apply_rope)
    .Inputs({"q", "k", "rot_cos", "rot_sin"})
    .Outputs({"q_out", "k_out"})
    .SetKernelFn(PD_KERNEL(ApplyRope))
    .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeInferDtype));

View File

@@ -0,0 +1,329 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/extension.h>
#include <algorithm>
#include "helper.h"
template <typename T>
struct Converter;
// Scalar conversion helpers between __half and host arithmetic types.
// All float/int -> __half conversions round to nearest even.
template <>
struct Converter<__half> {
  // __half -> float
  __device__ static float to_float(__half val) { return __half2float(val); }
  // float -> __half
  __device__ static __half from_float(float val) {
    return __float2half_rn(val);
  }
  // int -> __half. Fix: the parameter was declared `float` although the
  // name and the __int2half_rn intrinsic both take an int, which forced a
  // silent float->int conversion (the bfloat16 specialization already
  // uses `int`).
  __device__ static __half from_int(int val) { return __int2half_rn(val); }
};
// Scalar conversion helpers between __nv_bfloat16 and host arithmetic
// types; all narrowing conversions round to nearest even.
template <>
struct Converter<__nv_bfloat16> {
  // __nv_bfloat16 -> float
  __device__ static float to_float(__nv_bfloat16 v) {
    return __bfloat162float(v);
  }
  // float -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_float(float v) {
    return __float2bfloat16_rn(v);
  }
  // int -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_int(int v) {
    return __int2bfloat16_rn(v);
  }
};
// Index/stride bookkeeping for the fused-QKV RoPE kernel. All values are
// element (not byte) counts; they are filled in ApplyRopeQKVKernel.
struct ApplyRopeQKVParams {
  int head_dim;       // feature size of one head
  int token_stride;   // elements per token in the packed qkv buffer
  int head_stride;    // elements per head within a token (set to head_dim)
  int q_stride;       // elements per token in the q output buffer
  int kv_stride;      // elements per token in each of the k/v output buffers
  int q_head_offset;  // index of the first q head in the packed layout (0)
  int k_head_offset;  // index of the first k head (q_head_num)
  int v_head_offset;  // index of the first v head (q_head_num + kv_head_num)
  int q_head_num;     // number of query heads
  int kv_head_num;    // number of key/value heads
};
// Applies RoPE to 4 consecutive elements of one head, reading from the
// packed qkv buffer and writing to the split output; overload for cos/sin
// tables stored in the same dtype T.
// Rotation pairs adjacent elements inside the vec4:
//   out[2i]   = x[2i]   * cos[2i]   - x[2i+1] * sin[2i]
//   out[2i+1] = x[2i+1] * cos[2i+1] + x[2i]   * sin[2i+1]
// Loads a full AlignedVector<T, 4>, so the chunk must lie entirely inside
// one head — requires head_dim % 4 == 0 at the caller.
template <typename T>
__device__ __forceinline__ void RotateQKVec4(const T* qkv_ptr,
                                             const T* rot_cos_ptr,
                                             const T* rot_sin_ptr,
                                             const int load_idx,
                                             const int store_idx,
                                             const int rot_base_idx,
                                             T* out) {
  using VecT = AlignedVector<T, 4>;
  VecT qk_vec;
  Load(qkv_ptr + load_idx, &qk_vec);
  // (-x1, x0, -x3, x2): the "rotate half" companion of qk_vec.
  VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]};
  VecT cos_vec, sin_vec;
  Load(rot_cos_ptr + rot_base_idx, &cos_vec);
  Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    *(out + store_idx + i) =
        qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
  }
}
// RoPE overload for float32 cos/sin tables with half/bf16 qkv data.
// Widens each element to float via Converter<T>, rotates in float for
// accuracy, then narrows the result back to T on store.
template <typename T>
__device__ __forceinline__ void RotateQKVec4(const T* qkv_ptr,
                                             const float* rot_cos_ptr,
                                             const float* rot_sin_ptr,
                                             const int load_idx,
                                             const int store_idx,
                                             const int rot_base_idx,
                                             T* out) {
  using VecT = AlignedVector<T, 4>;
  using VecF = AlignedVector<float, 4>;
  auto to_float = [] __device__(T val) -> float {
    return Converter<T>::to_float(val);
  };
  auto from_float = [] __device__(float val) -> T {
    return Converter<T>::from_float(val);
  };
  VecT qk_vec;
  Load(qkv_ptr + load_idx, &qk_vec);
  // (-x1, x0, -x3, x2) in float: the "rotate half" companion of qk_vec.
  VecF rot_half_vec = {-to_float(qk_vec[1]),
                       to_float(qk_vec[0]),
                       -to_float(qk_vec[3]),
                       to_float(qk_vec[2])};
  VecF cos_vec, sin_vec;
  Load(rot_cos_ptr + rot_base_idx, &cos_vec);
  Load(rot_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    // Accumulate in float, convert back to T on store.
    *(out + store_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] +
                                        rot_half_vec[i] * sin_vec[i]);
  }
}
// Copies 4 consecutive elements from qkv_ptr[load_idx..] to out[store_idx..]
// via one vectorized load/store pair. Used for the V part of the packed qkv
// tensor, which is split out without rotation.
template <typename T>
__device__ __forceinline__ void StoreValue(const T* qkv_ptr,
                                           const int load_idx,
                                           const int store_idx,
                                           T* out) {
  using Vec4 = AlignedVector<T, 4>;
  Vec4 chunk;
  Load(qkv_ptr + load_idx, &chunk);
  Store(chunk, out + store_idx);
}
// Splits a packed qkv tensor into q/k/v, applying RoPE to q and k and
// copying v unchanged.
// Launch mapping (set up in ApplyRopeQKVKernel):
//   x -> token (grid.x equals the token count exactly, so no token guard),
//   y -> head index (guarded against q_head_num / kv_head_num),
//   z -> vec4 chunk of head_dim (guarded against head_dim).
// cos/sin tables hold head_dim values per token, shared across heads.
template <typename T, typename WeightType>
__global__ void DispatchApplyRopeQKVVec4Kernel(const T* qkv,
                                               const WeightType* rot_cos,
                                               const WeightType* rot_sin,
                                               ApplyRopeQKVParams param,
                                               T* q_out,
                                               T* k_out,
                                               T* v_out) {
  const int token_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int head_idx = blockIdx.y * blockDim.y + threadIdx.y;
  const int head_dim_idx = (blockIdx.z * blockDim.z + threadIdx.z) * 4;
  // Offset into the per-token cos/sin tables.
  int rot_idx = token_idx * param.head_dim + head_dim_idx;
  int load_idx, store_idx;
  if (head_idx < param.q_head_num && head_dim_idx < param.head_dim) {  // q
    load_idx = token_idx * param.token_stride +
               (head_idx + param.q_head_offset) * param.head_stride +
               head_dim_idx;
    store_idx =
        token_idx * param.q_stride + head_idx * param.head_dim + head_dim_idx;
    RotateQKVec4(qkv, rot_cos, rot_sin, load_idx, store_idx, rot_idx, q_out);
  }
  if (head_idx < param.kv_head_num && head_dim_idx < param.head_dim) {  // kv
    // k: rotated like q, read from the k slice of the packed layout.
    load_idx = token_idx * param.token_stride +
               (head_idx + param.k_head_offset) * param.head_stride +
               head_dim_idx;
    store_idx =
        token_idx * param.kv_stride + head_idx * param.head_dim + head_dim_idx;
    RotateQKVec4(qkv, rot_cos, rot_sin, load_idx, store_idx, rot_idx, k_out);
    // v: same output offset as k (separate buffer), copied without rotation.
    load_idx = token_idx * param.token_stride +
               (head_idx + param.v_head_offset) * param.head_stride +
               head_dim_idx;
    StoreValue(qkv, load_idx, store_idx, v_out);
  }
}
// Host-side launcher for the fused-QKV RoPE kernel, specialized on the
// paddle dtype D of qkv.
// Grid mapping: x -> token (exact count), y -> head (covering
// max(q_head_num, kv_head_num)), z -> vec4 chunk of head_dim.
// Dispatches on the cos/sin dtype: same-as-qkv vs float32; anything else
// throws.
template <paddle::DataType D>
void ApplyRopeQKVKernel(const paddle::Tensor& qkv,
                        const paddle::Tensor& rot_cos,
                        const paddle::Tensor& rot_sin,
                        const int q_head_num,
                        const int kv_head_num,
                        const int head_dim,
                        paddle::Tensor& q_out,
                        paddle::Tensor& k_out,
                        paddle::Tensor& v_out) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  // numel() is 64-bit; keep the arithmetic 64-bit so the token count does
  // not overflow an int before the division (the previous code narrowed
  // numel() to int first).
  const int64_t all_num_elements = qkv.numel();
  const int64_t all_num_head = q_head_num + 2 * kv_head_num;
  const int64_t token_num = all_num_elements / (all_num_head * head_dim);
  auto stream = qkv.stream();
  // One thread per (token, head, vec4-chunk) triple.
  dim3 block_dims(1, 4, 32);
  dim3 grid_dims(static_cast<unsigned int>(token_num),  // token
                 (std::max(q_head_num, kv_head_num) + block_dims.y - 1) /
                     block_dims.y,  // head
                 (head_dim + (block_dims.z * 4) - 1) /
                     (block_dims.z * 4)  // dim: load vec4 at a time
  );
  ApplyRopeQKVParams param;
  param.head_dim = head_dim;
  param.token_stride = all_num_head * head_dim;
  param.head_stride = head_dim;
  param.q_stride = q_head_num * head_dim;
  param.kv_stride = kv_head_num * head_dim;
  param.q_head_offset = 0;
  param.k_head_offset = q_head_num;
  param.v_head_offset = q_head_num + kv_head_num;
  param.q_head_num = q_head_num;
  param.kv_head_num = kv_head_num;
  if (qkv.dtype() == rot_cos.dtype()) {
    // cos/sin share the qkv dtype: T-typed rotation path.
    DispatchApplyRopeQKVVec4Kernel<DataType_, DataType_>
        <<<grid_dims, block_dims, 0, stream>>>(
            reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()),
            param,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()),
            reinterpret_cast<DataType_*>(v_out.data<data_t>()));
  } else if (rot_cos.dtype() == paddle::DataType::FLOAT32) {
    // float32 cos/sin: rotation computed in float, narrowed back to T.
    DispatchApplyRopeQKVVec4Kernel<DataType_, float>
        <<<grid_dims, block_dims, 0, stream>>>(
            reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
            reinterpret_cast<const float*>(rot_cos.data<float>()),
            reinterpret_cast<const float*>(rot_sin.data<float>()),
            param,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()),
            reinterpret_cast<DataType_*>(v_out.data<data_t>()));
  } else {
    PD_THROW("Unsupported qk dtype and rope dtype.");
  }
}
// Custom-op entry point: splits a packed qkv tensor (last dim =
// (q_head_num + 2*kv_head_num) * head_dim) into q/k/v, applying RoPE to
// q and k. Returns {q_out, k_out, v_out}; BF16/FP16 qkv only.
// Output rank follows rot_cos: rank-3 cos -> [token, head, dim] outputs,
// otherwise [token, 1, head, dim].
std::vector<paddle::Tensor> ApplyRopeQKV(const paddle::Tensor& qkv,
                                         const paddle::Tensor& rot_cos,
                                         const paddle::Tensor& rot_sin,
                                         const int q_head_num,
                                         const int kv_head_num,
                                         const int head_dim) {
  auto qkv_shape = qkv.shape();
  auto token_num = qkv_shape[0];
  auto place = qkv.place();
  auto dtype = qkv.dtype();
  common::DDim q_out_shape, kv_out_shape;
  if (rot_cos.shape().size() == 3) {
    q_out_shape = {token_num, q_head_num, head_dim};
    kv_out_shape = {token_num, kv_head_num, head_dim};
  } else {
    q_out_shape = {token_num, 1, q_head_num, head_dim};
    kv_out_shape = {token_num, 1, kv_head_num, head_dim};
  }
  auto q_out = GetEmptyTensor(q_out_shape, dtype, place);
  auto k_out = GetEmptyTensor(kv_out_shape, dtype, place);
  auto v_out = GetEmptyTensor(kv_out_shape, dtype, place);
  if (token_num == 0) {
    // Nothing to launch; return empty outputs of the declared shapes.
    return {q_out, k_out, v_out};
  }
  PADDLE_ENFORCE_EQ(qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim),
                    "The last dimension of qkv [%d] must equal to {(q_head_num "
                    "+ 2 * kv_head_num) * head_dim [%d].",
                    qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim));
  // The kernel loads/stores AlignedVector<T, 4> chunks guarded only by
  // head_dim_idx < head_dim, so head_dim must be a multiple of 4; an
  // even-but-not-multiple-of-4 head_dim (e.g. 6) would make the last chunk
  // read and write past the head boundary. The previous check (% 2) did not
  // catch this.
  PADDLE_ENFORCE_EQ(
      head_dim % 4,
      0,
      "The last dimension (head_dim) of qkv must be a multiple of 4 "
      "for the vectorized RoPE kernel, but got %d",
      head_dim);
  PADDLE_ENFORCE_EQ(q_out.shape().back(),
                    rot_cos.shape().back(),
                    "The last dimension of cos mismatches that of q, "
                    "expect %d but got %d",
                    q_out.shape().back(),
                    rot_cos.shape().back());
  switch (dtype) {
    case paddle::DataType::BFLOAT16:
      ApplyRopeQKVKernel<paddle::DataType::BFLOAT16>(qkv,
                                                     rot_cos,
                                                     rot_sin,
                                                     q_head_num,
                                                     kv_head_num,
                                                     head_dim,
                                                     q_out,
                                                     k_out,
                                                     v_out);
      break;
    case paddle::DataType::FLOAT16:
      ApplyRopeQKVKernel<paddle::DataType::FLOAT16>(qkv,
                                                    rot_cos,
                                                    rot_sin,
                                                    q_head_num,
                                                    kv_head_num,
                                                    head_dim,
                                                    q_out,
                                                    k_out,
                                                    v_out);
      break;
    default:
      PD_THROW("Only support qk dtype of BF16 and F16");
  }
  return {q_out, k_out, v_out};
}
// Static shape inference for apply_rope_qkv: one entry per declared output
// (q_out, k_out, v_out). The real output shapes depend on runtime attrs, so
// this is a placeholder passthrough of the corresponding input shapes.
std::vector<std::vector<int64_t>> ApplyRopeQKVInferShape(
    const std::vector<int64_t>& qkv_shape,
    const std::vector<int64_t>& cos_shape,
    const std::vector<int64_t>& sin_shape) {
  std::vector<std::vector<int64_t>> out_shapes;
  out_shapes.reserve(3);
  out_shapes.push_back(qkv_shape);
  out_shapes.push_back(cos_shape);
  out_shapes.push_back(sin_shape);
  return out_shapes;
}
// Dtype inference for apply_rope_qkv: placeholder passthrough of the three
// input dtypes, one per declared output.
std::vector<paddle::DataType> ApplyRopeQKVInferDtype(
    const paddle::DataType& qkv_dtype,
    const paddle::DataType& cos_dtype,
    const paddle::DataType& sin_dtype) {
  std::vector<paddle::DataType> out_dtypes;
  out_dtypes.reserve(3);
  out_dtypes.push_back(qkv_dtype);
  out_dtypes.push_back(cos_dtype);
  out_dtypes.push_back(sin_dtype);
  return out_dtypes;
}
// Register the apply_rope_qkv custom op: applies rotary position embedding to
// the fused qkv activations and emits separate q/k/v tensors.
PD_BUILD_OP(apply_rope_qkv)
    .Inputs({"qkv", "rot_cos", "rot_sin"})
    .Outputs({"q_out", "k_out", "v_out"})
    .Attrs({"q_head_num:int", "kv_head_num:int", "head_dim:int"})
    .SetKernelFn(PD_KERNEL(ApplyRopeQKV))
    .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeQKVInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeQKVInferDtype));

View File

@@ -0,0 +1,477 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <paddle/extension.h>
#include <algorithm>
#include "helper.h"
// Dtype conversion shims so templated device code can convert between T,
// float and int without spelling out the per-type CUDA intrinsics.
template <typename T>
struct Converter;

// Specialization for FP16.
template <>
struct Converter<__half> {
  // __half -> float
  __device__ static float to_float(__half val) { return __half2float(val); }
  // float -> __half (round-to-nearest-even)
  __device__ static __half from_float(float val) {
    return __float2half_rn(val);
  }
  // int -> __half. Fix: the parameter was mistakenly declared `float`, which
  // forced an int -> float -> int round-trip before __int2half_rn; take `int`
  // directly (matches the __nv_bfloat16 specialization).
  __device__ static __half from_int(int val) { return __int2half_rn(val); }
};
// Specialization for BF16 using CUDA's native conversion intrinsics.
template <>
struct Converter<__nv_bfloat16> {
  // __nv_bfloat16 -> float (widening, exact)
  __device__ static float to_float(__nv_bfloat16 val) {
    return __bfloat162float(val);
  }
  // float -> __nv_bfloat16 (round-to-nearest-even)
  __device__ static __nv_bfloat16 from_float(float val) {
    return __float2bfloat16_rn(val);
  }
  // int -> __nv_bfloat16 (round-to-nearest-even)
  __device__ static __nv_bfloat16 from_int(int val) {
    return __int2bfloat16_rn(val);
  }
};
// Geometry/stride bundle shared by the fused rope + cache-write kernels.
// All strides are in elements, not bytes (values set in CacheKVWithRopeKernel).
struct CacheKVWithRopeParams {
  int head_dim;       // size of one attention head
  int block_size;     // tokens per paged-KV-cache block
  int block_num;      // block-table entries per batch row
  int cache_stride;   // elements per cache block: block_size*kv_head_num*head_dim
  int token_stride;   // elements per token in fused qkv: all_heads*head_dim
  int head_stride;    // elements per head inside one token (== head_dim)
  int q_stride;       // elements per token in q_out: q_head_num*head_dim
  int kv_stride;      // elements per token in k_out/v_out: kv_head_num*head_dim
  int q_head_offset;  // first q head index inside fused qkv (0)
  int k_head_offset;  // first k head index (q_head_num)
  int v_head_offset;  // first v head index (q_head_num + kv_head_num)
  int q_head_num;     // number of query heads
  int kv_head_num;    // number of key/value heads
};
// Applies interleaved ("rotate adjacent pairs") RoPE to one VecSize-wide slice
// of q or k, with cos/sin tables stored in the same dtype T. Writes the
// rotated values to `out` and, when WriteCache is set, also into the paged KV
// cache at `cache_store_idx`. All arithmetic stays in T precision.
// NOTE(review): assumes VecSize is even so each (even, odd) pair stays inside
// the vector — confirm at call sites.
template <typename T, int VecSize = 4, bool WriteCache = true>
__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr,
                                            const T* rotary_cos_ptr,
                                            const T* rotary_sin_ptr,
                                            const int load_idx,
                                            const int store_idx,
                                            const int cache_store_idx,
                                            const int rot_base_idx,
                                            T* caches,
                                            T* out) {
  using VecT = AlignedVector<T, VecSize>;
  VecT qk_vec;
  Load(qkv_ptr + load_idx, &qk_vec);
  VecT rot_half_vec;
  int flag;
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    // flag = +1 for even i, -1 for odd i; for each pair (x0, x1) this yields
    // (-x1, x0): even lanes take their right neighbour negated, odd lanes
    // take their left neighbour with the sign flipped back.
    flag = 1 - 2 * (i % 2);
    rot_half_vec[i] = -qk_vec[i + flag] * Converter<T>::from_int(flag);
  }
  VecT cos_vec, sin_vec;
  Load(rotary_cos_ptr + rot_base_idx, &cos_vec);
  Load(rotary_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    // result = x * cos + rotate_half(x) * sin
    T result = qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
    *(out + store_idx + i) = result;
    if (WriteCache) {
      *(caches + cache_store_idx + i) = result;
    }
  }
}
// Overload of RotateQKVec for float32 cos/sin tables with T (half/bf16)
// activations: operands are promoted to float, the rotation is computed in
// float, and the result is converted back to T once at the end.
template <typename T, int VecSize = 4, bool WriteCache = true>
__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr,
                                            const float* rotary_cos_ptr,
                                            const float* rotary_sin_ptr,
                                            const int load_idx,
                                            const int store_idx,
                                            const int cache_store_idx,
                                            const int rot_base_idx,
                                            T* caches,
                                            T* out) {
  using VecT = AlignedVector<T, VecSize>;
  using VecF = AlignedVector<float, VecSize>;
  auto to_float = [] __device__(T val) -> float {
    return Converter<T>::to_float(val);
  };
  auto from_float = [] __device__(float val) -> T {
    return Converter<T>::from_float(val);
  };
  VecT qk_vec;
  Load(qkv_ptr + load_idx, &qk_vec);
  VecF rot_half_vec;
  int flag;
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    // flag = +1 for even i, -1 for odd i; produces (-x1, x0) per pair, same
    // interleaved rotate-half as the all-T overload, but in float.
    flag = 1 - 2 * (i % 2);
    rot_half_vec[i] = -to_float(qk_vec[i + flag]) * static_cast<float>(flag);
  }
  VecF cos_vec, sin_vec;
  Load(rotary_cos_ptr + rot_base_idx, &cos_vec);
  Load(rotary_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    // result = x * cos + rotate_half(x) * sin, accumulated in float, then
    // rounded back to T for storage.
    T result = from_float(to_float(qk_vec[i]) * cos_vec[i] +
                          rot_half_vec[i] * sin_vec[i]);
    *(out + store_idx + i) = result;
    if (WriteCache) {
      *(caches + cache_store_idx + i) = result;
    }
  }
}
// Copies one VecSize-wide slice of V from the fused qkv buffer into both the
// paged KV cache and the contiguous output tensor (V is not rotated).
template <typename T, int VecSize = 4>
__device__ __forceinline__ void StoreValue(const T* qkv_ptr,
                                           const int load_idx,
                                           const int store_idx,
                                           const int cache_store_idx,
                                           T* caches,
                                           T* out) {
  using VecT = AlignedVector<T, VecSize>;
  VecT value_vec;
  Load(qkv_ptr + load_idx, &value_vec);
  // The two destinations are independent; write the output tensor and the
  // cache from the same register vector.
  Store(value_vec, caches + cache_store_idx);
  Store(value_vec, out + store_idx);
}
// Fused kernel: applies RoPE to q/k, scatters the rotated k and raw v into the
// paged KV cache, and writes contiguous q/k/v outputs.
// Thread mapping (see the host launcher): grid.x covers tokens, the y
// dimension covers max(q_head_num, kv_head_num) heads, and the z dimension
// covers head_dim in VecSize-element slices.
template <typename T, typename WeightType, int VecSize>
__global__ void DispatchCacheKVWithRopeVecKernel(const T* qkv,
                                                 T* caches_k,
                                                 T* caches_v,
                                                 const int* block_tables,
                                                 const WeightType* rotary_cos,
                                                 const WeightType* rotary_sin,
                                                 const int* cu_seqlens_q,
                                                 const int* batch_ids_q,
                                                 CacheKVWithRopeParams param,
                                                 T* q_out,
                                                 T* k_out,
                                                 T* v_out) {
  const int token_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int head_idx = blockIdx.y * blockDim.y + threadIdx.y;
  const int head_dim_idx = (blockIdx.z * blockDim.z + threadIdx.z) * VecSize;
  int load_idx, store_idx, cache_store_idx;
  // cos/sin tables are indexed per token and head-dim element, shared by all
  // heads of that token.
  int rot_idx = token_idx * param.head_dim + head_dim_idx;
  // Locate this token's cache slot: its position inside its batch row selects
  // a block-table entry and an offset within that block.
  const int batch_idx = *(batch_ids_q + token_idx);
  const int inter_batch_token_offset = token_idx - *(cu_seqlens_q + batch_idx);
  const int inter_batch_block_idx = inter_batch_token_offset / param.block_size;
  const int inter_block_offset = inter_batch_token_offset % param.block_size;
  const int block_idx =
      *(block_tables + batch_idx * param.block_num + inter_batch_block_idx);
  assert(block_idx != -1);  // -1 marks an unallocated cache block
  if (head_dim_idx < param.head_dim) {
    if (head_idx < param.q_head_num) {  // q: rotate only, no cache write
      load_idx = token_idx * param.token_stride +
                 (head_idx + param.q_head_offset) * param.head_stride +
                 head_dim_idx;
      store_idx =
          token_idx * param.q_stride + head_idx * param.head_dim + head_dim_idx;
      RotateQKVec<T, VecSize, false>(qkv,
                                     rotary_cos,
                                     rotary_sin,
                                     load_idx,
                                     store_idx,
                                     -1,
                                     rot_idx,
                                     static_cast<T*>(nullptr),
                                     q_out);
    }
    if (head_idx < param.kv_head_num) {  // k: rotate + cache; v: copy + cache
      load_idx = token_idx * param.token_stride +
                 (head_idx + param.k_head_offset) * param.head_stride +
                 head_dim_idx;
      store_idx = token_idx * param.kv_stride + head_idx * param.head_dim +
                  head_dim_idx;
      cache_store_idx = block_idx * param.cache_stride +
                        inter_block_offset * param.kv_stride +
                        head_idx * param.head_dim + head_dim_idx;
      RotateQKVec<T, VecSize, true>(qkv,
                                    rotary_cos,
                                    rotary_sin,
                                    load_idx,
                                    store_idx,
                                    cache_store_idx,
                                    rot_idx,
                                    caches_k,
                                    k_out);
      // v reuses the same output/cache indices as k; only the source offset
      // inside the fused qkv buffer changes.
      load_idx = token_idx * param.token_stride +
                 (head_idx + param.v_head_offset) * param.head_stride +
                 head_dim_idx;
      StoreValue<T, VecSize>(
          qkv, load_idx, store_idx, cache_store_idx, caches_v, v_out);
    }
  }
}
// Host-side launcher for the fused RoPE + paged-KV-cache write.
// One thread handles one (token, head, VecSize-slice-of-head_dim) tuple:
// grid.x = token count (blockDim.x == 1), the y dimension tiles
// max(q_head_num, kv_head_num), the z dimension tiles head_dim / VecSize.
// Dispatches on the cos/sin dtype: same-as-activation or float32.
template <paddle::DataType D, int VecSize = 4>
void CacheKVWithRopeKernel(
    const paddle::Tensor& qkv,  // token_num, head_num * head_dim
    paddle::Tensor&
        caches_k,  // max_block_num, block_size, kv_head_num, head_dim
    paddle::Tensor&
        caches_v,  // max_block_num, block_size, kv_head_num, head_dim
    const paddle::Tensor& block_tables,  // bs, block_num
    const paddle::Tensor& rotary_cos,
    const paddle::Tensor& rotary_sin,
    const paddle::Tensor& cu_seqlens_q,  // bs + 1
    const paddle::Tensor& batch_ids_q,   // token_num
    const int q_head_num,
    const int kv_head_num,
    const int head_dim,
    const int block_size,
    paddle::Tensor& q_out,
    paddle::Tensor& k_out,
    paddle::Tensor& v_out) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;
  // NOTE(review): numel() is narrowed to int here — could overflow for very
  // large tensors; confirm upstream bounds.
  const int all_num_elements = qkv.numel();
  const int all_num_heads = q_head_num + 2 * kv_head_num;
  auto stream = qkv.stream();
  dim3 block_dims(1, 4, (head_dim + VecSize - 1) / VecSize);
  dim3 grid_dims(all_num_elements / (all_num_heads * head_dim),  // token
                 (std::max(q_head_num, kv_head_num) + block_dims.y - 1) /
                     block_dims.y,  // head
                 (head_dim + (block_dims.z * VecSize) - 1) /
                     (block_dims.z * VecSize)  // dim: load Vec at a time
  );
  CacheKVWithRopeParams param;
  param.head_dim = head_dim;
  param.block_size = block_size;
  param.block_num = static_cast<int>(block_tables.shape().back());
  param.cache_stride = block_size * kv_head_num * head_dim;
  param.token_stride = all_num_heads * head_dim;
  param.head_stride = head_dim;
  param.q_stride = q_head_num * head_dim;
  param.kv_stride = kv_head_num * head_dim;
  param.q_head_offset = 0;
  param.k_head_offset = q_head_num;
  param.v_head_offset = q_head_num + kv_head_num;
  param.q_head_num = q_head_num;
  param.kv_head_num = kv_head_num;
  if (qkv.dtype() == rotary_cos.dtype()) {
    // cos/sin share the activation dtype (BF16/FP16).
    DispatchCacheKVWithRopeVecKernel<DataType_, DataType_, VecSize>
        <<<grid_dims, block_dims, 0, stream>>>(
            reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_k.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_v.data<data_t>()),
            reinterpret_cast<const int*>(block_tables.data<int>()),
            reinterpret_cast<const DataType_*>(rotary_cos.data<data_t>()),
            reinterpret_cast<const DataType_*>(rotary_sin.data<data_t>()),
            reinterpret_cast<const int*>(cu_seqlens_q.data<int>()),
            reinterpret_cast<const int*>(batch_ids_q.data<int>()),
            param,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()),
            reinterpret_cast<DataType_*>(v_out.data<data_t>()));
  } else if (rotary_cos.dtype() == paddle::DataType::FLOAT32) {
    // float32 cos/sin path: rotation is computed in float.
    DispatchCacheKVWithRopeVecKernel<DataType_, float, VecSize>
        <<<grid_dims, block_dims, 0, stream>>>(
            reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_k.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_v.data<data_t>()),
            reinterpret_cast<const int*>(block_tables.data<int>()),
            reinterpret_cast<const float*>(rotary_cos.data<float>()),
            reinterpret_cast<const float*>(rotary_sin.data<float>()),
            reinterpret_cast<const int*>(cu_seqlens_q.data<int>()),
            reinterpret_cast<const int*>(batch_ids_q.data<int>()),
            param,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()),
            reinterpret_cast<DataType_*>(v_out.data<data_t>()));
  } else {
    PD_THROW("Unsupported qk dtype and rope dtype.");
  }
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    // NOTE(review): launch errors are only printed, not propagated — consider
    // surfacing them via PD_THROW so callers can react.
    printf("CUDA Error: %s\n", cudaGetErrorString(err));
  }
}
// Fused custom-op entry: applies RoPE to fused qkv activations, writes the
// rotated k and raw v into the paged KV cache (caches_k/caches_v are updated
// in place), and returns contiguous {q_out, k_out, v_out}.
// Only BF16/FP16 activations are supported; cos/sin may share the activation
// dtype or be float32.
std::vector<paddle::Tensor> CacheKVWithRope(
    const paddle::Tensor& qkv,  // token_num, head_num * head_dim
    paddle::Tensor&
        caches_k,  // max_block_num, block_size, kv_head_num, head_dim
    paddle::Tensor&
        caches_v,  // max_block_num, block_size, kv_head_num, head_dim
    const paddle::Tensor& block_tables,  // bs, block_num
    const paddle::Tensor& rotary_cos,
    const paddle::Tensor& rotary_sin,
    const paddle::Tensor& cu_seqlens_q,  // bs + 1
    const paddle::Tensor& batch_ids_q,   // token_num
    const int q_head_num,
    const int kv_head_num,
    const int head_dim,
    const int block_size) {
  auto qkv_shape = qkv.shape();
  auto token_num = qkv_shape[0];
  auto place = qkv.place();
  auto dtype = qkv.dtype();
  common::DDim q_out_shape, kv_out_shape;
  if (rotary_cos.shape().size() == 3) {
    q_out_shape = {token_num, q_head_num, head_dim};
    kv_out_shape = {token_num, kv_head_num, head_dim};
  } else {
    q_out_shape = {token_num, 1, q_head_num, head_dim};
    kv_out_shape = {token_num, 1, kv_head_num, head_dim};
  }
  auto q_out = GetEmptyTensor(q_out_shape, dtype, place);
  auto k_out = GetEmptyTensor(kv_out_shape, dtype, place);
  auto v_out = GetEmptyTensor(kv_out_shape, dtype, place);
  if (token_num == 0) {
    // Nothing to rotate or cache; return the empty outputs early.
    return {q_out, k_out, v_out};
  }
  // Fixed error message: it previously read "{(q_head_num ... " with a stray
  // unmatched "{".
  PADDLE_ENFORCE_EQ(qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim),
                    "The last dimension of qkv [%d] must equal to (q_head_num "
                    "+ 2 * kv_head_num) * head_dim [%d].",
                    qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim));
  // Interleaved pair rotation requires an even head_dim.
  PADDLE_ENFORCE_EQ(
      head_dim % 2,
      0,
      "The last dimension (head_dim) of qkv must be an even number "
      "for RoPE, but got %d",
      head_dim);
  PADDLE_ENFORCE_EQ(q_out.shape().back(),
                    rotary_cos.shape().back(),
                    "The last dimension of cos mismatches that of q, "
                    "expect %d but got %d",
                    q_out.shape().back(),
                    rotary_cos.shape().back());
  switch (dtype) {
    case paddle::DataType::BFLOAT16:
      CacheKVWithRopeKernel<paddle::DataType::BFLOAT16>(qkv,
                                                        caches_k,
                                                        caches_v,
                                                        block_tables,
                                                        rotary_cos,
                                                        rotary_sin,
                                                        cu_seqlens_q,
                                                        batch_ids_q,
                                                        q_head_num,
                                                        kv_head_num,
                                                        head_dim,
                                                        block_size,
                                                        q_out,
                                                        k_out,
                                                        v_out);
      break;
    case paddle::DataType::FLOAT16:
      CacheKVWithRopeKernel<paddle::DataType::FLOAT16>(qkv,
                                                       caches_k,
                                                       caches_v,
                                                       block_tables,
                                                       rotary_cos,
                                                       rotary_sin,
                                                       cu_seqlens_q,
                                                       batch_ids_q,
                                                       q_head_num,
                                                       kv_head_num,
                                                       head_dim,
                                                       block_size,
                                                       q_out,
                                                       k_out,
                                                       v_out);
      break;
    default:
      PD_THROW("Only support qk dtype of BF16 and F16");
  }
  return {q_out, k_out, v_out};
}
// Static shape inference for cache_kv_with_rope.
// Fix: the op registers exactly three outputs (q_out, k_out, v_out), so shape
// inference must return exactly three entries; the previous version returned
// one entry per input (eight), which does not match the output list. The real
// output shapes depend on runtime attrs (head counts, head_dim) resolved in
// the kernel, so this mirrors ApplyRopeQKVInferShape's placeholder scheme.
std::vector<std::vector<int64_t>> CacheKVWithRopeInferShape(
    const std::vector<int64_t>& qkv_shape,
    const std::vector<int64_t>& caches_k_shape,
    const std::vector<int64_t>& caches_v_shape,
    const std::vector<int64_t>& block_tables_shape,
    const std::vector<int64_t>& cos_shape,
    const std::vector<int64_t>& sin_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& batch_ids_q_shape) {
  return {qkv_shape, cos_shape, sin_shape};
}
// Dtype inference for cache_kv_with_rope.
// Fix: the op registers exactly three outputs (q_out, k_out, v_out), so
// exactly three dtypes must be returned; the previous version returned one
// entry per input (eight). All three outputs are allocated with qkv's dtype
// (see CacheKVWithRope), so that dtype is propagated to each.
std::vector<paddle::DataType> CacheKVWithRopeInferDtype(
    const paddle::DataType& qkv_dtype,
    const paddle::DataType& caches_k_dtype,
    const paddle::DataType& caches_v_dtype,
    const paddle::DataType& block_tables_dtype,
    const paddle::DataType& cos_dtype,
    const paddle::DataType& sin_dtype,
    const paddle::DataType& cu_seqlens_q_dtype,
    const paddle::DataType& batch_ids_q_dtype) {
  return {qkv_dtype, qkv_dtype, qkv_dtype};
}
// Register the cache_kv_with_rope custom op: fused RoPE + scatter of k/v into
// the paged KV cache; caches_k/caches_v input buffers are written in place.
// NOTE(review): the input is registered as "cu_seqlen_q" while the kernel-side
// name is cu_seqlens_q — likely a typo, but this string is the op's public
// interface; confirm against the Python-side binding before renaming.
PD_BUILD_OP(cache_kv_with_rope)
    .Inputs({"qkv",
             "caches_k",
             "caches_v",
             "block_tables",
             "rotary_cos",
             "rotary_sin",
             "cu_seqlen_q",
             "batch_ids_q"})
    .Outputs({"q_out", "k_out", "v_out"})
    .Attrs(
        {"q_head_num:int", "kv_head_num:int", "head_dim:int", "block_size:int"})
    .SetKernelFn(PD_KERNEL(CacheKVWithRope))
    .SetInferShapeFn(PD_INFER_SHAPE(CacheKVWithRopeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(CacheKVWithRopeInferDtype));

View File

@@ -14,9 +14,10 @@
#pragma once
#include "fused_moe_op.h"
#include "fused_moe_helper.h"
#include "helper.h"
#include "mc_fused_moe_helper.h"
namespace phi {
__global__ void compute_total_rows_before_expert_kernel(
int* sorted_experts,
@@ -42,58 +43,61 @@ void compute_total_rows_before_expert(int* sorted_indices,
sorted_indices, total_indices, num_experts, total_rows_before_expert);
}
template <paddle::DataType T,
typename ElementA,
typename ElementB,
typename ElementC>
} // namespace phi
template <paddle::DataType T>
void FusedMoeKernel(const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::Tensor& up_gate_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const std::string& quant_method,
const int moe_topk,
const bool group_moe,
const bool norm_topk_prob,
paddle::Tensor* output) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
auto* output_data = output->data<data_t>();
auto moe_compute =
McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
auto int8_moe_gemm_runner = McMoeGemmRunner<DataType_, int8_t>();
moe_compute.computeFFN(&input,
&gate_weight,
&ffn1_weight,
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
&ffn2_weight,
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
group_moe,
norm_topk_prob,
1.0, // ComputeFFN
"ffn",
output);
auto moe_compute =
McMoeHelper<data_t, DataType_>(quant_method, &int8_moe_gemm_runner);
moe_compute.computeFFN(
&input,
&gate_weight,
&up_gate_proj_weight,
up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
&down_proj_weight,
down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
nullptr,
moe_topk,
group_moe,
norm_topk_prob,
1.0, // ComputeFFN
"ffn",
output);
}
std::vector<paddle::Tensor> FusedExpertMoe(
const paddle::Tensor& input,
const paddle::Tensor& gate_weight,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const std::string& quant_method,
const int moe_topk,
const bool norm_topk_prob,
@@ -107,40 +111,22 @@ std::vector<paddle::Tensor> FusedExpertMoe(
switch (input_type) {
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16,
maca_bfloat16,
int8_t,
maca_bfloat16>(input,
gate_weight,
ffn1_weight,
ffn1_scale,
ffn1_bias,
ffn2_weight,
ffn2_scale,
ffn2_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
gate_weight,
up_gate_proj_weight,
up_gate_proj_scale,
up_gate_proj_bias,
down_proj_weight,
down_proj_scale,
down_proj_bias,
quant_method,
moe_topk,
group_moe,
norm_topk_prob,
&output);
break;
// case paddle::DataType::FLOAT16:
// FusedMoeKernel<paddle::DataType::FLOAT16>(input,
// gate_weight,
// ffn1_weight,
// ffn1_scale,
// ffn1_bias,
// ffn2_weight,
// ffn2_scale,
// ffn2_bias,
// quant_method,
// moe_topk,
// group_moe,
// norm_topk_prob,
// &output);
// break;
default:
PD_THROW("Only support bf16 for FusedMoeKernel");
PD_THROW("Unsupported data type for FusedMoeKernel");
}
return {output};
}
@@ -148,36 +134,36 @@ std::vector<paddle::Tensor> FusedExpertMoe(
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gate_weight_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
return {input_shape};
}
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& gate_weight_dtype,
const paddle::DataType& ffn1_weight_dtype,
const paddle::DataType& ffn2_weight_dtype,
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
const paddle::DataType& up_gate_proj_weight_dtype,
const paddle::DataType& down_proj_weight_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
return {input_dtype};
}
PD_BUILD_OP(fused_expert_moe)
PD_BUILD_STATIC_OP(fused_expert_moe)
.Inputs({"input",
"gate_weight",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_bias"),
paddle::Optional("ffn2_scale")})
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_bias"),
paddle::Optional("down_proj_scale")})
.Outputs({"output"})
.Attrs({"quant_method:std::string",
"moe_topk:int",

View File

@@ -0,0 +1,199 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mctlass/numeric_conversion.h"
#include "mctlassEx/mctlassEx.h"
namespace phi {
// Maps C++ element types to the mctlassEx runtime data-type enum; only the
// types listed below are supported.
template <typename T>
struct mctlassExDataTraits;
template <>
struct mctlassExDataTraits<maca_bfloat16> {
  static constexpr mctlassExDataType type =
      mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
};
template <>
struct mctlassExDataTraits<int8_t> {
  static constexpr mctlassExDataType type =
      mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
};
// Thin wrapper around mctlassEx contiguous grouped GEMM for MoE expert FFNs.
// T is the activation/output element type, WeightType the (possibly
// quantized) weight element type.
template <typename T, typename WeightType>
class McMoeGemmRunner {
 public:
  McMoeGemmRunner() {}

  // Runs one contiguous grouped GEMM over all experts:
  //   A: (m, k) activations, rows grouped per expert by ptrSegInd
  //   B: (numExperts, n, k) weights, dequantized via per-column ptrScale
  //   C: (m, n) output in the activation dtype
  // ptrBias, when non-null, enables the per-expert bias epilogue.
  // ptrMNumTilesInd is integer scratch filled by ComputeMNumTilesIndptr.
  // NOTE(review): mctlassEx status codes are ignored throughout, and the
  // handle/descriptors are created and destroyed on every call — consider
  // caching the handle and checking return values.
  void mc_grouped_gemm_basic_kernel(const T* ptrA,
                                    mctlassExOrder_t majorA,
                                    const WeightType* ptrB,
                                    mctlassExOrder_t majorB,
                                    const T* ptrScale,
                                    const T* ptrBias,
                                    T* ptrC,
                                    mctlassExOrder_t majorC,
                                    const int* ptrSegInd,
                                    int* ptrMNumTilesInd,
                                    int numExperts,
                                    int m,  // expanded_active_expert_rows
                                    int n,  // inter_dim
                                    int k,  // hidden_size
                                    mcStream_t stream) {
    mctlassExHandle_t handle;
    mctlassExHandleCreate(&handle);
    mctlassExDataType DataType_ = mctlassExDataTraits<T>::type;
    mctlassExDataType WeightType_ = mctlassExDataTraits<WeightType>::type;
    mctlassExMatrixLayout_t matLayoutA;
    mctlassExMatrixLayout_t matLayoutB;
    mctlassExMatrixLayout_t matLayoutC;
    // mat A: (m, k), batched over experts
    mctlassExMatrixLayoutCreate(&matLayoutA, DataType_, m, k, k);
    mctlassExMatrixLayoutSetAttribute(
        matLayoutA,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
        &majorA,
        sizeof(mctlassExOrder_t));
    mctlassExMatrixLayoutSetAttribute(
        matLayoutA,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
        &numExperts,
        sizeof(int));
    // mat B: (num_experts, n, k)
    mctlassExMatrixLayoutCreate(&matLayoutB, WeightType_, k, n, k);
    mctlassExMatrixLayoutSetAttribute(
        matLayoutB,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
        &majorB,
        sizeof(mctlassExOrder_t));
    mctlassExMatrixLayoutSetAttribute(
        matLayoutB,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
        &numExperts,
        sizeof(int));
    // mat C: (m, n)
    mctlassExMatrixLayoutCreate(&matLayoutC, DataType_, m, n, n);
    mctlassExMatrixLayoutSetAttribute(
        matLayoutC,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
        &majorC,
        sizeof(mctlassExOrder_t));
    mctlassExMatrixLayoutSetAttribute(
        matLayoutC,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
        &numExperts,
        sizeof(int));
    // bias: (num_experts, n)
    // scale: (num, n)
    mctlassExDesc_t mctlass_desc;
    mctlassExCreateDesc(&mctlass_desc);
    mctlassExDataType input_type = DataType_;
    mctlassExDataType compute_type =
        mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
    mctlassExEpilogueType epilogue_type =
        mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
    if (ptrBias) {
      epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
    }
    // set scale: ptrScale entries share the activation dtype T, so the
    // B-scale type attribute is the input type. (An unused WeightType_-typed
    // local that previously suggested otherwise has been removed.)
    mctlassExDescSetAttribute(
        mctlass_desc,
        mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
        &ptrScale,
        sizeof(ptrScale));
    mctlassExDescSetAttribute(
        mctlass_desc,
        mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
        &input_type,
        sizeof(mctlassExDataType));
    // set bias
    if (ptrBias) {
      mctlassExDescSetAttribute(
          mctlass_desc,
          mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
          &ptrBias,
          sizeof(ptrBias));
    }
    // set compute type (accumulation in FP32)
    mctlassExDescSetAttribute(
        mctlass_desc,
        mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
        &compute_type,
        sizeof(mctlassExDataType));
    // set epilogue type
    mctlassExDescSetAttribute(
        mctlass_desc,
        mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
        &epilogue_type,
        sizeof(mctlassExEpilogueType));
    const mctlassExContiguousGroupedGemmAlgo_t algo =
        mctlassExContiguousGroupedGemmAlgo_t::
            MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
    mctlassExContiguousGroupedDesc_t contiguous_group_desc;
    mctlassExContiguousGroupedDescCreate(
        &contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);
    // Query the M blocking, build the per-expert tile index table on the
    // stream, then launch the grouped GEMM itself.
    int blocksizeM;
    mctlassExContiguousGroupedGemmGetBlocksizeM(handle,
                                                mctlass_desc,
                                                matLayoutA,
                                                matLayoutB,
                                                matLayoutC,
                                                &algo,
                                                &blocksizeM);
    mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle,
                                                         mctlass_desc,
                                                         matLayoutA,
                                                         matLayoutB,
                                                         matLayoutC,
                                                         &algo,
                                                         contiguous_group_desc,
                                                         numExperts,
                                                         blocksizeM,
                                                         stream);
    mctlassExContiguousGroupedGemmBasic(handle,
                                        mctlass_desc,
                                        ptrA,
                                        matLayoutA,
                                        ptrB,
                                        matLayoutB,
                                        ptrC,
                                        matLayoutC,
                                        contiguous_group_desc,
                                        &algo,
                                        nullptr,
                                        0,
                                        stream);
    mctlassExHandleDestroy(handle);
    mctlassExMatrixLayoutDestroy(matLayoutA);
    mctlassExMatrixLayoutDestroy(matLayoutB);
    mctlassExMatrixLayoutDestroy(matLayoutC);
    mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
    mctlassExDestroyDesc(mctlass_desc);
  }
};
template class McMoeGemmRunner<maca_bfloat16, int8_t>;
} // namespace phi

View File

@@ -14,14 +14,17 @@
#pragma once
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "fused_moe_gemm_kernels.h"
#include "fused_moe_imp_op.h"
#include "fused_moe_op.h"
#include "mctlass/numeric_conversion.h"
#include "mctlassEx/mctlassEx.h"
using namespace phi;
namespace phi {
template <typename T, int VecSize>
__global__ void moe_token_type_ids_kernel(T *gating_output,
const int *moe_token_type_ids_out,
__global__ void moe_token_type_ids_kernel(T* gating_output,
const int* moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k) {
@@ -40,8 +43,8 @@ __global__ void moe_token_type_ids_kernel(T *gating_output,
}
template <typename T>
void moe_token_type_ids_kernelLauncher(T *gating_output,
const int *moe_token_type_ids_out,
void moe_token_type_ids_kernelLauncher(T* gating_output,
const int* moe_token_type_ids_out,
const int num_rows,
const int num_experts,
const int k,
@@ -51,3 +54,338 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
}
template <typename T, typename MacaType>
class McMoeHelper {
public:
  // Stores the quantization-method string and a pointer to the externally
  // owned int8 grouped-GEMM runner (not freed by this class).
  McMoeHelper(const std::string gemm_method,
              McMoeGemmRunner<MacaType, int8_t>* int8_moe_gemm_runner)
      : gemm_method_(gemm_method),
        int8_moe_gemm_runner_(int8_moe_gemm_runner) {}
// -------- getWorkspaceSize -------- //
  // Total scratch bytes computeFFN needs for one forward pass. The layout
  // computed here must stay in sync with the pointer carving in computeFFN:
  // five int arrays, permuted activations, per-expert row offsets, fc1
  // output / sort workspace, and an optional softmax buffer.
  template <typename KeyT>
  size_t getWorkspaceSize(const int64_t num_rows,
                          const int64_t hidden_size,
                          const int64_t inter_size,
                          const int64_t num_experts,
                          const int64_t k) {
    const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
    const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
    const size_t padded_experts = AlignTo16(num_experts);
    const size_t num_moe_inputs = AlignTo16(k * num_rows);
    // softmax output, permuted_rows and permuted_experts have moved to outside
    // of moe kernel, allocate them in Encoder or Decoder before invoking
    // FfnLayer forward.
    size_t total_ws_bytes =
        5 * num_moe_inputs *
        sizeof(int);  // expert_for_source_row, source_rows_, permuted_rows_,
                      // permuted_experts_, expanded row map (see computeFFN)
    total_ws_bytes += buf_size * sizeof(KeyT);  // permuted_data
    total_ws_bytes +=
        padded_experts * sizeof(int32_t);  // Hold total_rows_before_expert_
    const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
    const size_t sorter_ws_size_bytes =
        AlignTo16(sorter_.getWorkspaceSize(num_rows));
    sorter_.update_num_experts(num_experts);
    // The fc1 output buffer doubles as the sort workspace; grow it only if
    // sorting needs more room than fc1 does.
    int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
    if (sorter_ws_size_bytes > bytes_for_fc1_result) {
      int64_t remaining_bytes =
          AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
      bytes_for_intermediate_and_sorting += remaining_bytes;
    }
    total_ws_bytes +=
        bytes_for_intermediate_and_sorting;  // intermediate (fc1) output + cub
                                             // sorting workspace
    int64_t num_softmax_outs = 0;
    // presumably the fused softmax path only supports power-of-two expert
    // counts <= 256, needing a separate float buffer otherwise — TODO confirm
    const bool is_pow_2 =
        (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    if (!is_pow_2 || num_experts > 256) {
      num_softmax_outs = AlignTo16(num_rows * num_experts);
    }
    total_ws_bytes += num_softmax_outs * sizeof(float);
    return total_ws_bytes;
  }
// Runs the full fused-MoE feed-forward pass for one batch of tokens:
// router gating -> top-k selection -> token permutation -> grouped GEMM
// (up_gate_proj / fc1) -> swiglu -> grouped GEMM (down_proj / fc2) ->
// un-permute with per-token weighting.
//
// input:               [tokens, hidden] (or [b, s, hidden]) activations, dtype T.
// gate_weight:         router weight, multiplied against the fp32-cast input.
// up_gate_proj_*:      fc1 weight / dequant scale / optional bias.
// down_proj_*:         fc2 weight / dequant scale / optional bias.
// moe_token_type_ids:  optional per-token ids used to mask the gating logits.
// moe_topk (k):        experts selected per token.
// moe_type:            "ffn" runs both GEMMs; any other value (e.g. "qkv")
//                      finalizes directly from the fc1 output.
// output:              pre-allocated result tensor, dtype T.
//
// Only gemm_method_ == "weight_only_int8" is supported; other methods throw.
void computeFFN(const paddle::Tensor* input,
                const paddle::Tensor* gate_weight,
                const paddle::Tensor* up_gate_proj_weight,
                const paddle::Tensor* up_gate_proj_scale,
                const paddle::Tensor* up_gate_proj_bias,
                const paddle::Tensor* down_proj_weight,
                const paddle::Tensor* down_proj_scale,
                const paddle::Tensor* down_proj_bias,
                const paddle::Tensor* moe_token_type_ids,
                const int moe_topk,
                const bool group_moe,
                const bool norm_topk_prob,
                const float routed_scaling_factor,
                const std::string moe_type,
                paddle::Tensor* output) {
  auto* input_activations = input->data<T>();
  // Optional expert biases; null when the corresponding tensor is absent.
  const T* fc1_expert_biases =
      up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
  const T* fc2_expert_biases =
      down_proj_bias ? down_proj_bias->data<T>() : nullptr;
  auto* output_ = output->data<T>();
  auto stream = input->stream();
  auto place = input->place();
  auto input_type = input->dtype();
  auto input_dims = input->dims();
  auto up_gate_proj_dims = up_gate_proj_weight->dims();
  // Flatten an optional leading [batch, seq] pair into a single token axis.
  int64_t token_num = 0;
  if (input_dims.size() == 3) {
    token_num = input_dims[0] * input_dims[1];
  } else {
    token_num = input_dims[0];
  }
  const int64_t num_rows = token_num;
  const int64_t hidden_size = up_gate_proj_dims[2];
  int64_t inter_dim = 0;
  if (moe_type == "qkv") {
    // qkv mode: weight is [experts, hidden, nh, 3, dh]; fold the head dims.
    inter_dim =
        up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
  } else {
    inter_dim = up_gate_proj_dims[1];
  }
  // if (gemm_method_ == "weight_only_int4") {
  //   inter_dim = inter_dim * 2;
  // }
  const int64_t inter_size = inter_dim;
  const int64_t num_experts = up_gate_proj_dims[0];
  const int64_t k = moe_topk;
  int64_t bytes =
      getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
  // Pointers carved out of the single workspace allocation below. The layout
  // must stay in lock-step with getWorkspaceSize.
  int* expert_for_source_row;
  int* source_rows_;
  int* permuted_rows_;
  int* permuted_experts_;
  int* expanded_source_row_to_expanded_dest_row;
  T* permuted_data_;
  int32_t* total_rows_before_expert_;
  T* fc1_result_;
  float* softmax_out_;
  paddle::Tensor ws_ptr_tensor =
      GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
  int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
  const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
  const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
  const int64_t padded_experts = AlignTo16(num_experts);
  const int64_t num_moe_inputs = AlignTo16(k * num_rows);
  expert_for_source_row = reinterpret_cast<int*>(ws_ptr);
  source_rows_ = expert_for_source_row + num_moe_inputs;
  permuted_rows_ = source_rows_ + num_moe_inputs;
  permuted_experts_ = permuted_rows_ + num_moe_inputs;
  expanded_source_row_to_expanded_dest_row =
      permuted_experts_ + num_moe_inputs;
  permuted_data_ = reinterpret_cast<T*>(
      expanded_source_row_to_expanded_dest_row + num_moe_inputs);
  total_rows_before_expert_ =
      reinterpret_cast<int32_t*>(permuted_data_ + buf_size);
  fc1_result_ =
      reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
  // The fused top-k softmax kernel only handles power-of-two expert counts
  // up to 256; otherwise a separate softmax output buffer is required.
  const bool is_pow_2 =
      (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
  if (!is_pow_2 || num_experts > 256) {
    softmax_out_ = reinterpret_cast<float*>(fc1_result_ + interbuf_size);
  } else {
    softmax_out_ = nullptr;
  }
  paddle::Tensor expert_scales_float_tensor =
      GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
  float* expert_scales_float = expert_scales_float_tensor.data<float>();
  float* softmax_max_prob = nullptr;
  if (group_moe) {
    paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
        {num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
    // (TODO: check fill success ?)
    paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
    softmax_max_prob = softmax_max_prob_tensor.data<float>();
  }
  paddle::Tensor fc1_out_tensor =
      GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
  T* fc1_out = fc1_out_tensor.data<T>();
  // Router logits are computed in fp32 regardless of the activation dtype.
  auto input_cast_tensor =
      paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
  auto gate_tensor =
      paddle::experimental::matmul(input_cast_tensor, *gate_weight);
  float* gating_output = gate_tensor.data<float>();
  if (moe_token_type_ids) {
    auto* moe_token_type_ids_out = moe_token_type_ids->data<int>();
    moe_token_type_ids_kernelLauncher<float>(gating_output,
                                             moe_token_type_ids_out,
                                             num_rows,
                                             num_experts,
                                             k,
                                             stream);
  }
  topk_gating_softmax_kernelLauncher<float, int>(gating_output,
                                                 nullptr,
                                                 expert_scales_float,
                                                 softmax_out_,
                                                 expert_for_source_row,
                                                 source_rows_,
                                                 softmax_max_prob,
                                                 num_rows,
                                                 num_experts,
                                                 k,
                                                 group_moe,
                                                 stream);
  // Sort (expert, source_row) pairs by expert id so rows belonging to the
  // same expert become contiguous. fc1_result_ doubles as sort scratch here.
  // NOTE(review): getWorkspaceSize reserves the sorter scratch based on
  // sorter_.getWorkspaceSize(num_rows), while this sorts k * num_rows pairs —
  // confirm the reserved region is large enough for k > 1.
  const int64_t sorter_ws_size_bytes =
      AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
  sorter_.run(fc1_result_,
              sorter_ws_size_bytes,
              expert_for_source_row,
              permuted_experts_,
              source_rows_,
              permuted_rows_,
              k * num_rows,
              false,
              stream);
  // Gather input rows into expert-contiguous order.
  initialize_moe_routing_kernelLauncher(
      input_activations,
      permuted_data_,
      permuted_rows_,
      nullptr,
      nullptr,
      expanded_source_row_to_expanded_dest_row,
      num_rows,
      num_rows,
      hidden_size,
      k,
      stream);
  const int64_t expanded_active_expert_rows = k * num_rows;
  // Prefix sum of rows per expert; this is the segment index the grouped
  // GEMM consumes.
  compute_total_rows_before_expert(permuted_experts_,
                                   expanded_active_expert_rows,
                                   num_experts,
                                   total_rows_before_expert_,
                                   stream);
  mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
  mctlassExOrder_t column_major =
      mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
  auto m_num_tile =
      GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
  int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());
  // fc1: [m, hidden] x [experts, inter, hidden]^T -> [m, inter]
  if (gemm_method_ == "weight_only_int8") {
    int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel(
        reinterpret_cast<const MacaType*>(permuted_data_),
        row_major,
        reinterpret_cast<const int8_t*>(up_gate_proj_weight->data<int8_t>()),
        column_major,
        reinterpret_cast<const MacaType*>(up_gate_proj_scale->data<T>()),
        reinterpret_cast<const MacaType*>(fc1_expert_biases),
        reinterpret_cast<MacaType*>(fc1_out),
        row_major,
        total_rows_before_expert_,
        m_num_tile_ptr,
        num_experts,
        expanded_active_expert_rows,
        inter_size,
        hidden_size,
        stream);
  } else {
    throw std::runtime_error("Unsupported gemm method: " + gemm_method_);
  }
  if (moe_type == "ffn") {
    // swiglu halves the inner dimension: [m, inter] -> [m, inter / 2].
    auto act_out_tensor =
        paddle::experimental::swiglu(fc1_out_tensor, nullptr);
    auto act_out = act_out_tensor.data<T>();
    paddle::Tensor fc2_output_tensor =
        GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
    T* fc2_result = fc2_output_tensor.data<T>();
    if (gemm_method_ == "weight_only_int8") {
      int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel(
          reinterpret_cast<const MacaType*>(act_out),
          row_major,
          reinterpret_cast<const int8_t*>(down_proj_weight->data<int8_t>()),
          column_major,
          reinterpret_cast<const MacaType*>(down_proj_scale->data<T>()),
          nullptr,
          reinterpret_cast<MacaType*>(fc2_result),
          row_major,
          total_rows_before_expert_,
          m_num_tile_ptr,
          num_experts,
          expanded_active_expert_rows,
          hidden_size,
          inter_size / 2,
          stream);
    } else {
      throw std::runtime_error("Unsupported gemm method: " + gemm_method_);
    }
    // Scatter rows back to their original token order, applying the per-token
    // expert weights (and fc2 bias) during the reduction.
    finalize_moe_routing_kernelLauncher(
        fc2_result,
        output_,
        fc2_expert_biases,
        reinterpret_cast<float*>(expert_scales_float),
        expanded_source_row_to_expanded_dest_row,
        expert_for_source_row,
        num_rows,
        hidden_size,
        k,
        static_cast<int>(1),
        norm_topk_prob,
        routed_scaling_factor,
        stream);
  } else {
    // Non-"ffn" mode (e.g. "qkv"): finalize directly from the fc1 output.
    finalize_moe_routing_kernelLauncher(
        // fc2_result,
        fc1_out,
        output_,
        fc1_expert_biases,  // fc2_expert_biases,
        reinterpret_cast<float*>(expert_scales_float),
        expanded_source_row_to_expanded_dest_row,
        expert_for_source_row,
        num_rows,
        inter_size,
        k,
        static_cast<int>(0),
        norm_topk_prob,
        routed_scaling_factor,
        stream);
  }
}
private:
McMoeGemmRunner<MacaType, int8_t>* int8_moe_gemm_runner_;
std::string gemm_method_;
CubKeyValueSorter sorter_;
};
} // namespace phi

View File

@@ -20,6 +20,8 @@
#include <string>
#include "cub/cub.cuh"
namespace phi {
static const float HALF_FLT_MAX = 65504.F;
static const float HALF_FLT_MIN = -65504.F;
static inline size_t AlignTo16(const size_t& input) {
@@ -121,3 +123,5 @@ class CubKeyValueSorter {
int num_experts_;
int num_bits_;
};
} // namespace phi

File diff suppressed because it is too large Load Diff

View File

@@ -1,486 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fused_moe_helper.h"
#include "mctlass/numeric_conversion.h"
#include "mctlassEx/mctlassEx.h"
// Launches one mctlassEx "contiguous grouped" GEMM:
//   C = A * B (+ bias), with per-expert row segments of A/C described by
// ptrSegInd (prefix sum of rows handled by each expert).
//
//   ptrA:           [m, k] activations (BF16 layout attributes below).
//   ptrB:           [numExperts, n, k] quantized weights (INT8 attributes).
//   ptrScale:       per-expert dequant scales applied to B.
//   ptrBias:        optional per-expert bias; enables the BIAS epilogue.
//   ptrC:           [m, n] output.
//   ptrSegInd:      segment indptr (rows-before-expert), length numExperts.
//   ptrMNumTilesInd: scratch for the per-expert M-tile indptr, filled on
//                    device by ComputeMNumTilesIndptr before the GEMM.
//
// NOTE(review): this creates and destroys an mctlassExHandle_t on every call;
// if this sits on a hot path, consider caching the handle — confirm against
// the mctlassEx threading/ownership rules.
template <typename ElementA, typename ElementB, typename ElementC>
void mc_grouped_gemm_basic_kernel(const ElementA* ptrA,
                                  mctlassExOrder_t majorA,
                                  const ElementB* ptrB,
                                  mctlassExOrder_t majorB,
                                  const ElementA* ptrScale,
                                  const ElementA* ptrBias,
                                  ElementC* ptrC,
                                  mctlassExOrder_t majorC,
                                  const int* ptrSegInd,
                                  int* ptrMNumTilesInd,
                                  int numExperts,
                                  int m,  // expanded_active_expert_rows
                                  int n,  // inter_dim
                                  int k,  // hidden_size
                                  mcStream_t stream) {
  mctlassExHandle_t handle;
  mctlassExHandleCreate(&handle);
  mctlassExMatrixLayout_t matLayoutA;
  mctlassExMatrixLayout_t matLayoutB;
  mctlassExMatrixLayout_t matLayoutC;
  // mat A: (m, k)
  mctlassExMatrixLayoutCreate(
      &matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
  mctlassExMatrixLayoutSetAttribute(
      matLayoutA,
      mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
      &majorA,
      sizeof(mctlassExOrder_t));
  mctlassExMatrixLayoutSetAttribute(
      matLayoutA,
      mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
      &numExperts,
      sizeof(int));
  // mat B: (num_experts, n, k)
  mctlassExMatrixLayoutCreate(
      &matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
  mctlassExMatrixLayoutSetAttribute(
      matLayoutB,
      mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
      &majorB,
      sizeof(mctlassExOrder_t));
  mctlassExMatrixLayoutSetAttribute(
      matLayoutB,
      mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
      &numExperts,
      sizeof(int));
  // mat C: (m, n)
  mctlassExMatrixLayoutCreate(
      &matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
  mctlassExMatrixLayoutSetAttribute(
      matLayoutC,
      mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
      &majorC,
      sizeof(mctlassExOrder_t));
  mctlassExMatrixLayoutSetAttribute(
      matLayoutC,
      mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
      &numExperts,
      sizeof(int));
  // bias: (num_experts, n)
  // scale: (num, n)
  mctlassExDesc_t mctlass_desc;
  mctlassExCreateDesc(&mctlass_desc);
  mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
  mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
  mctlassExEpilogueType epilogue_type =
      mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
  if (ptrBias) {
    epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
  }
  // set scale (scale values are stored as the activation type, BF16)
  mctlassExDescSetAttribute(
      mctlass_desc,
      mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
      &ptrScale,
      sizeof(ptrScale));
  mctlassExDescSetAttribute(
      mctlass_desc,
      mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
      &input_type,
      sizeof(mctlassExDataType));
  // set bias
  if (ptrBias) {
    mctlassExDescSetAttribute(
        mctlass_desc,
        mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
        &ptrBias,
        sizeof(ptrBias));
  }
  // set compute type
  mctlassExDescSetAttribute(
      mctlass_desc,
      mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
      &compute_type,
      sizeof(mctlassExDataType));
  // set epilogue type
  mctlassExDescSetAttribute(
      mctlass_desc,
      mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
      &epilogue_type,
      sizeof(mctlassExEpilogueType));
  const mctlassExContiguousGroupedGemmAlgo_t algo =
      mctlassExContiguousGroupedGemmAlgo_t::
          MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
  mctlassExContiguousGroupedDesc_t contiguous_group_desc;
  mctlassExContiguousGroupedDescCreate(
      &contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);
  // Query the M blocking the algo uses, then precompute per-expert tile
  // offsets on device before launching the actual GEMM.
  int blocksizeM;
  mctlassExContiguousGroupedGemmGetBlocksizeM(handle,
                                              mctlass_desc,
                                              matLayoutA,
                                              matLayoutB,
                                              matLayoutC,
                                              &algo,
                                              &blocksizeM);
  mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle,
                                                       mctlass_desc,
                                                       matLayoutA,
                                                       matLayoutB,
                                                       matLayoutC,
                                                       &algo,
                                                       contiguous_group_desc,
                                                       numExperts,
                                                       blocksizeM,
                                                       stream);
  mctlassExContiguousGroupedGemmBasic(handle,
                                      mctlass_desc,
                                      ptrA,
                                      matLayoutA,
                                      ptrB,
                                      matLayoutB,
                                      ptrC,
                                      matLayoutC,
                                      contiguous_group_desc,
                                      &algo,
                                      nullptr,
                                      0,
                                      stream);
  mctlassExHandleDestroy(handle);
  mctlassExMatrixLayoutDestroy(matLayoutA);
  mctlassExMatrixLayoutDestroy(matLayoutB);
  mctlassExMatrixLayoutDestroy(matLayoutC);
  mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
  mctlassExDestroyDesc(mctlass_desc);
}
// Drives the fused-MoE forward pass via mctlassEx contiguous grouped GEMMs.
//   T        - paddle-visible element type of activations.
//   ElementA - device activation / scale element type fed to the GEMM.
//   ElementB - quantized weight element type (e.g. int8_t).
//   ElementC - GEMM output element type.
template <typename T, typename ElementA, typename ElementB, typename ElementC>
class McMoeHelper {
 public:
  McMoeHelper(const std::string gemm_method) : gemm_method_(gemm_method) {}

  // -------- getWorkspaceSize -------- //
  // Returns the byte size of the single workspace blob that computeFFN
  // carves into routing indices, permuted activations, the per-expert row
  // prefix sum, the fc1 scratch (shared with the sorter), and an optional
  // softmax output buffer. Layout must stay in lock-step with computeFFN.
  template <typename KeyT>
  size_t getWorkspaceSize(const int64_t num_rows,
                          const int64_t hidden_size,
                          const int64_t inter_size,
                          const int64_t num_experts,
                          const int64_t k) {
    const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
    const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
    const size_t padded_experts = AlignTo16(num_experts);
    const size_t num_moe_inputs = AlignTo16(k * num_rows);
    // softmax output, permuted_rows and permuted_experts have moved to outside
    // of moe kernel, allocate them in Encoder or Decoder before invoking
    // FfnLayer forward.
    size_t total_ws_bytes =
        5 * num_moe_inputs *
        sizeof(int);  // source_rows_, permuted_rows_, permuted_experts_
    total_ws_bytes += buf_size * sizeof(KeyT);  // permuted_data
    total_ws_bytes +=
        padded_experts * sizeof(int32_t);  // Hold total_rows_before_expert_
    const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
    // BUGFIX: computeFFN sorts k * num_rows key/value pairs (see
    // sorter_.run), so the sorter scratch must be sized for k * num_rows
    // items; previously this passed num_rows, undersizing it for k > 1.
    const size_t sorter_ws_size_bytes =
        AlignTo16(sorter_.getWorkspaceSize(k * num_rows));
    sorter_.update_num_experts(num_experts);
    int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
    if (sorter_ws_size_bytes > bytes_for_fc1_result) {
      int64_t remaining_bytes =
          AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
      bytes_for_intermediate_and_sorting += remaining_bytes;
    }
    total_ws_bytes +=
        bytes_for_intermediate_and_sorting;  // intermediate (fc1) output + cub
                                             // sorting workspace
    // The fused top-k softmax kernel only covers power-of-two expert counts
    // up to 256; otherwise a separate softmax output buffer is needed.
    int64_t num_softmax_outs = 0;
    const bool is_pow_2 =
        (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    if (!is_pow_2 || num_experts > 256) {
      num_softmax_outs = AlignTo16(num_rows * num_experts);
    }
    total_ws_bytes += num_softmax_outs * sizeof(float);
    return total_ws_bytes;
  }

  // Runs the fused-MoE forward pass: gating -> top-k routing -> permutation
  // -> grouped GEMM (ffn1) -> swiglu -> grouped GEMM (ffn2) -> un-permute
  // with per-token weighting. moe_type == "ffn" runs both GEMMs; any other
  // value (e.g. "qkv") finalizes directly from the ffn1 output.
  void computeFFN(const paddle::Tensor* input,
                  const paddle::Tensor* gate_weight,
                  const paddle::Tensor* ffn1_weight,
                  const paddle::Tensor* ffn1_scale,
                  const paddle::Tensor* ffn1_bias,
                  const paddle::Tensor* ffn2_weight,
                  const paddle::Tensor* ffn2_scale,
                  const paddle::Tensor* ffn2_bias,
                  const paddle::Tensor* moe_token_type_ids,
                  const int moe_topk,
                  const bool group_moe,
                  const bool norm_topk_prob,
                  const float routed_scaling_factor,
                  const std::string moe_type,
                  paddle::Tensor* output) {
    auto* input_activations = input->data<T>();
    // Optional expert biases; null when the corresponding tensor is absent.
    const T* fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
    const T* fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
    auto* output_ = output->data<T>();
    auto stream = input->stream();
    auto place = input->place();
    auto input_type = input->dtype();
    auto input_dims = input->dims();
    auto ffn1_dims = ffn1_weight->dims();
    // Flatten an optional leading [batch, seq] pair into one token axis.
    int64_t token_num = 0;
    if (input_dims.size() == 3) {
      token_num = input_dims[0] * input_dims[1];
    } else {
      token_num = input_dims[0];
    }
    const int64_t num_rows = token_num;
    const int64_t hidden_size = ffn1_dims[2];
    int64_t inter_dim = 0;
    if (moe_type == "qkv") {
      inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
    } else {
      inter_dim = ffn1_dims[1];
    }
    // if (gemm_method == "weight_only_int4") {
    //   inter_dim = inter_dim * 2;
    // }
    const int64_t inter_size = inter_dim;
    const int64_t num_experts = ffn1_dims[0];
    const int64_t k = moe_topk;
    int64_t bytes =
        getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
    // Pointers carved out of the workspace; layout mirrors getWorkspaceSize.
    int* expert_for_source_row;
    int* source_rows_;
    int* permuted_rows_;
    int* permuted_experts_;
    int* expanded_source_row_to_expanded_dest_row;
    T* permuted_data_;
    int32_t* total_rows_before_expert_;
    T* fc1_result_;
    float* softmax_out_;
    paddle::Tensor ws_ptr_tensor =
        GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
    int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
    const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
    const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
    const int64_t padded_experts = AlignTo16(num_experts);
    const int64_t num_moe_inputs = AlignTo16(k * num_rows);
    expert_for_source_row = reinterpret_cast<int*>(ws_ptr);
    source_rows_ = expert_for_source_row + num_moe_inputs;
    permuted_rows_ = source_rows_ + num_moe_inputs;
    permuted_experts_ = permuted_rows_ + num_moe_inputs;
    expanded_source_row_to_expanded_dest_row =
        permuted_experts_ + num_moe_inputs;
    permuted_data_ = reinterpret_cast<T*>(
        expanded_source_row_to_expanded_dest_row + num_moe_inputs);
    total_rows_before_expert_ =
        reinterpret_cast<int32_t*>(permuted_data_ + buf_size);
    fc1_result_ =
        reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
    const bool is_pow_2 =
        (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    if (!is_pow_2 || num_experts > 256) {
      softmax_out_ = reinterpret_cast<float*>(fc1_result_ + interbuf_size);
    } else {
      softmax_out_ = nullptr;
    }
    paddle::Tensor expert_scales_float_tensor =
        GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
    float* expert_scales_float = expert_scales_float_tensor.data<float>();
    float* softmax_max_prob = nullptr;
    if (group_moe) {
      paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
          {num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
      // (TODO: check fill success ?)
      paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
      softmax_max_prob = softmax_max_prob_tensor.data<float>();
    }
    paddle::Tensor fc1_out_tensor =
        GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
    T* fc1_out = fc1_out_tensor.data<T>();
    // Router logits are computed in fp32 regardless of activation dtype.
    auto input_cast_tensor =
        paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
    auto gate_tensor =
        paddle::experimental::matmul(input_cast_tensor, *gate_weight);
    float* gating_output = gate_tensor.data<float>();
    if (moe_token_type_ids) {
      auto* moe_token_type_ids_out = moe_token_type_ids->data<int>();
      moe_token_type_ids_kernelLauncher<float>(gating_output,
                                               moe_token_type_ids_out,
                                               num_rows,
                                               num_experts,
                                               k,
                                               stream);
    }
    topk_gating_softmax_kernelLauncher<float>(gating_output,
                                              expert_scales_float,
                                              softmax_out_,
                                              expert_for_source_row,
                                              source_rows_,
                                              softmax_max_prob,
                                              num_rows,
                                              num_experts,
                                              k,
                                              group_moe,
                                              stream);
    // Sort (expert, source_row) pairs by expert id; fc1_result_ doubles as
    // the sorter's scratch buffer.
    const int64_t sorter_ws_size_bytes =
        AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
    sorter_.run(fc1_result_,
                sorter_ws_size_bytes,
                expert_for_source_row,
                permuted_experts_,
                source_rows_,
                permuted_rows_,
                k * num_rows,
                false,
                stream);
    // Gather input rows into expert-contiguous order.
    initialize_moe_routing_kernelLauncher(
        input_activations,
        permuted_data_,
        permuted_rows_,
        expanded_source_row_to_expanded_dest_row,
        num_rows,
        num_rows,
        hidden_size,
        k,
        stream);
    const int64_t expanded_active_expert_rows = k * num_rows;
    // Prefix sum of rows per expert — the grouped GEMM's segment index.
    compute_total_rows_before_expert(permuted_experts_,
                                     expanded_active_expert_rows,
                                     num_experts,
                                     total_rows_before_expert_,
                                     stream);
    mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
    mctlassExOrder_t column_major =
        mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
    auto m_num_tile =
        GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
    int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());
    // ffn1: [m, hidden] x [experts, inter, hidden]^T -> [m, inter]
    mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
        reinterpret_cast<const ElementA*>(permuted_data_),
        row_major,
        reinterpret_cast<const ElementB*>(ffn1_weight->data<ElementB>()),
        column_major,
        reinterpret_cast<const ElementA*>(ffn1_scale->data<T>()),
        reinterpret_cast<const ElementA*>(fc1_expert_biases),
        reinterpret_cast<ElementC*>(fc1_out),
        row_major,
        total_rows_before_expert_,
        m_num_tile_ptr,
        num_experts,
        expanded_active_expert_rows,
        inter_size,
        hidden_size,
        stream);
    if (moe_type == "ffn") {
      // swiglu halves the inner dimension: [m, inter] -> [m, inter / 2].
      auto act_out_tensor =
          paddle::experimental::swiglu(fc1_out_tensor, nullptr);
      auto act_out = act_out_tensor.data<T>();
      paddle::Tensor fc2_output_tensor =
          GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
      T* fc2_result = fc2_output_tensor.data<T>();
      mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
          reinterpret_cast<const ElementA*>(act_out),
          row_major,
          reinterpret_cast<const ElementB*>(ffn2_weight->data<ElementB>()),
          column_major,
          reinterpret_cast<const ElementA*>(ffn2_scale->data<T>()),
          nullptr,
          reinterpret_cast<ElementC*>(fc2_result),
          row_major,
          total_rows_before_expert_,
          m_num_tile_ptr,
          num_experts,
          expanded_active_expert_rows,
          hidden_size,
          inter_size / 2,
          stream);
      // Scatter rows back to original token order, applying per-token expert
      // weights (and fc2 bias) during the reduction.
      finalize_moe_routing_kernelLauncher(
          fc2_result,
          output_,
          fc2_expert_biases,
          reinterpret_cast<float*>(expert_scales_float),
          expanded_source_row_to_expanded_dest_row,
          expert_for_source_row,
          num_rows,
          hidden_size,
          k,
          static_cast<int>(1),
          norm_topk_prob,
          routed_scaling_factor,
          stream);
    } else {
      // Non-"ffn" mode (e.g. "qkv"): finalize directly from the ffn1 output.
      finalize_moe_routing_kernelLauncher(
          // fc2_result,
          fc1_out,
          output_,
          fc1_expert_biases,  // fc2_expert_biases,
          reinterpret_cast<float*>(expert_scales_float),
          expanded_source_row_to_expanded_dest_row,
          expert_for_source_row,
          num_rows,
          inter_size,
          k,
          static_cast<int>(0),
          norm_topk_prob,
          routed_scaling_factor,
          stream);
    }
  }

 private:
  std::string gemm_method_;  // quantization scheme, e.g. "weight_only_int8"
  CubKeyValueSorter sorter_;  // radix-sorts (expert id, row) routing pairs
};

View File

@@ -17,26 +17,35 @@
#pragma GCC diagnostic ignored "-Wunused-function"
#pragma once
#include "fused_moe_helper.h"
#include "fused_moe_imp_op.h"
#include "fused_moe_op.h"
#pragma GCC diagnostic pop
#include "helper.h"
template <paddle::DataType T>
void MoeDispatchKernel(const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const int moe_topk,
const bool group_moe,
const bool topk_only_mode,
const int num_rows,
const int hidden_size,
const int expert_num,
paddle::Tensor* permute_input,
paddle::Tensor* tokens_expert_prefix_sum,
paddle::Tensor* permute_indices_per_token,
paddle::Tensor* top_k_weight,
paddle::Tensor* top_k_indices) {
void MoeDispatchKernel(
const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const paddle::optional<paddle::Tensor>& gating_correction_bias,
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
const int moe_topk,
const bool group_moe,
const bool topk_only_mode,
const int num_rows,
const int hidden_size,
const int expert_num,
paddle::Tensor* permute_input,
paddle::Tensor* tokens_expert_prefix_sum,
paddle::Tensor* permute_indices_per_token,
paddle::Tensor* topk_weight,
paddle::Tensor* topk_idx,
paddle::Tensor* expert_idx_per_token) {
using namespace phi;
if (num_rows == 0) {
return;
}
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -78,7 +87,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
reinterpret_cast<int*>(sorter_ws_ptr + sorter_ws_size_bytes);
int* permuted_rows_ = permuted_experts_ + num_moe_inputs;
int* expert_for_source_row = top_k_indices->data<int>();
int* topk_idx_ptr = topk_idx->data<int>();
float* softmax_max_prob = nullptr;
if (group_moe) {
@@ -103,23 +112,25 @@ void MoeDispatchKernel(const paddle::Tensor& input,
softmax_out_ = nullptr;
}
topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
top_k_weight->data<float>(),
softmax_out_,
expert_for_source_row,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
group_moe,
stream,
topk_only_mode);
topk_gating_softmax_kernelLauncher(
gating_output.data<float>(),
static_cast<const float*>(nullptr), // no gating_correction_bias
topk_weight->data<float>(),
softmax_out_,
topk_idx_ptr,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
group_moe,
stream,
topk_only_mode);
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
sorter_ws_size_bytes,
expert_for_source_row,
permuted_experts_,
topk_idx_ptr,
expert_idx_per_token->data<int32_t>(),
source_rows_,
permuted_rows_,
moe_topk * num_rows,
@@ -130,6 +141,8 @@ void MoeDispatchKernel(const paddle::Tensor& input,
input.data<data_t>(),
permute_input->data<data_t>(),
permuted_rows_,
expert_idx_per_token->data<int32_t>(),
nullptr,
permute_indices_per_token->data<int32_t>(),
num_rows,
num_rows,
@@ -137,7 +150,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
moe_topk,
stream);
compute_total_rows_before_expert(permuted_experts_,
compute_total_rows_before_expert(expert_idx_per_token->data<int32_t>(),
moe_topk * num_rows,
expert_num,
tokens_expert_prefix_sum->data<int32_t>(),
@@ -147,8 +160,11 @@ void MoeDispatchKernel(const paddle::Tensor& input,
std::vector<paddle::Tensor> MoeExpertDispatch(
const paddle::Tensor& input,
const paddle::Tensor& gating_output,
const paddle::optional<paddle::Tensor>& gating_correction_bias,
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
const int moe_topk,
const bool group_moe,
const std::string& moe_quant_type,
const bool topk_only_mode) {
const auto input_type = input.dtype();
auto place = input.place();
@@ -168,9 +184,9 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto permute_input =
GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place);
// correspond to the weighted coefficients of the results from each expert.
auto top_k_weight =
auto topk_weight =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
auto top_k_indices =
auto topk_idx =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);
auto tokens_expert_prefix_sum =
@@ -178,18 +194,24 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto permute_indices_per_token =
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
auto expert_idx_per_token =
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
if (token_rows == 0) {
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
top_k_weight,
top_k_indices};
topk_weight,
topk_idx,
expert_idx_per_token};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
gating_output,
gating_correction_bias,
w4a8_in_scale,
moe_topk,
group_moe,
topk_only_mode,
@@ -199,37 +221,25 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
&permute_input,
&tokens_expert_prefix_sum,
&permute_indices_per_token,
&top_k_weight,
&top_k_indices);
&topk_weight,
&topk_idx,
&expert_idx_per_token);
break;
// case paddle::DataType::FLOAT16:
// MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
// gating_output,
// moe_topk,
// group_moe,
// topk_only_mode,
// num_rows,
// hidden_size,
// expert_num,
// &permute_input,
// &tokens_expert_prefix_sum,
// &permute_indices_per_token,
// &top_k_weight,
// &top_k_indices);
// break;
default:
PD_THROW("Only support bf16 for MoeDispatchKernel");
PD_THROW("Unsupported data type for MoeDispatchKernel");
}
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
top_k_weight,
top_k_indices};
topk_weight,
topk_idx,
expert_idx_per_token};
}
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& gating_output_shape,
const paddle::optional<std::vector<int64_t>>& bias_shape,
const int moe_topk) {
int token_rows = -1;
@@ -241,33 +251,44 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
const int num_rows = token_rows;
const int hidden_size = input_shape[input_shape.size() - 1];
const int permuted_rows = num_rows == -1 ? -1 : moe_topk * num_rows;
return {{moe_topk * num_rows, hidden_size},
return {{permuted_rows, hidden_size},
{expert_num},
{moe_topk, num_rows},
{num_rows, moe_topk},
{num_rows, moe_topk}};
{num_rows, moe_topk},
{permuted_rows}};
}
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
const paddle::DataType& input_dtype,
const paddle::DataType& gating_output_dtype,
const paddle::optional<paddle::DataType>& bias_type,
const int moe_topk) {
return {input_dtype,
paddle::DataType::INT64,
paddle::DataType::INT32,
paddle::DataType::FLOAT32,
paddle::DataType::INT32,
paddle::DataType::INT32};
}
PD_BUILD_OP(moe_expert_dispatch)
.Inputs({"input", "gating_output"})
PD_BUILD_STATIC_OP(moe_expert_dispatch)
.Inputs({"input",
"gating_output",
paddle::Optional("gating_correction_bias"),
paddle::Optional("w4a8_in_scale")})
.Outputs({"permute_input",
"tokens_expert_prefix_sum",
"permute_indices_per_token",
"top_k_weight",
"top_k_indices"})
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
"topk_weight",
"topk_idx",
"expert_idx_per_token"})
.Attrs({"moe_topk:int",
"group_moe:bool",
"moe_quant_type:std::string",
"topk_only_mode:bool"})
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));

View File

@@ -12,23 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// BUILD_MARK
#pragma once
#include "fused_moe_helper.h"
#include "helper.h"
#include "mc_fused_moe_helper.h"
template <paddle::DataType T,
typename ElementA,
typename ElementB,
typename ElementC>
void McMoeFFNKernel(paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const std::string& quant_method) {
template <paddle::DataType T>
void MoeFFNKernel(paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -38,11 +36,13 @@ void McMoeFFNKernel(paddle::Tensor& permute_input,
auto input_type = permute_input.dtype();
auto stream = permute_input.stream();
auto int8_moe_gemm_runner = McMoeGemmRunner<DataType_, int8_t>();
const int expanded_active_expert_rows =
permute_input.dims()[0]; // permute_input.dims(): m, k
const int num_experts = ffn1_weight.dims()[0]; // batchsize
const int hidden_size = ffn1_weight.dims()[2]; // n
int inter_dim = ffn1_weight.dims()[1]; // k
permute_input.dims()[0]; // permute_input.dims(): m, k
const int num_experts = up_gate_proj_weight.dims()[0]; // batchsize
const int hidden_size = up_gate_proj_weight.dims()[2]; // n
int inter_dim = up_gate_proj_weight.dims()[1]; // k
const int64_t inter_size = inter_dim; // since weight_only_int_8
paddle::Tensor fc1_out_tensor = GetEmptyTensor(
@@ -58,60 +58,71 @@ void McMoeFFNKernel(paddle::Tensor& permute_input,
// ffn1
auto fc1_expert_biases =
ffn1_bias
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
up_gate_proj_bias
? const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr())
->data<data_t>()
: nullptr;
auto fc1_expert_scales =
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA*>(permuted_input_ptr),
row_major,
reinterpret_cast<const ElementB*>(ffn1_weight.data<ElementB>()),
column_major,
reinterpret_cast<const ElementA*>(fc1_expert_scales),
reinterpret_cast<const ElementA*>(fc1_expert_biases),
reinterpret_cast<ElementC*>(fc1_out_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
m_num_tile_ptr,
num_experts,
expanded_active_expert_rows,
inter_dim,
hidden_size,
stream);
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())->data<data_t>();
if (quant_method == "weight_only_int8") {
int8_moe_gemm_runner.mc_grouped_gemm_basic_kernel(
reinterpret_cast<const DataType_*>(permuted_input_ptr),
row_major,
reinterpret_cast<const int8_t*>(up_gate_proj_weight.data<int8_t>()),
column_major,
reinterpret_cast<const DataType_*>(fc1_expert_scales),
reinterpret_cast<const DataType_*>(fc1_expert_biases),
reinterpret_cast<DataType_*>(fc1_out_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
m_num_tile_ptr,
num_experts,
expanded_active_expert_rows,
inter_dim,
hidden_size,
stream);
} else {
throw std::runtime_error("Unsupported gemm method: " + quant_method);
}
// swiglu
auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
auto act_out = act_out_tensor.data<data_t>();
auto fc2_expert_scales =
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
reinterpret_cast<const ElementA*>(act_out),
row_major,
reinterpret_cast<const ElementB*>(ffn2_weight.data<ElementB>()),
column_major,
reinterpret_cast<const ElementA*>(fc2_expert_scales),
nullptr,
reinterpret_cast<ElementC*>(permuted_input_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
m_num_tile_ptr,
num_experts,
expanded_active_expert_rows,
hidden_size,
inter_dim / 2,
stream);
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<data_t>();
if (quant_method == "weight_only_int8") {
int8_moe_gemm_runner.mc_grouped_gemm_basic_kernel(
reinterpret_cast<const DataType_*>(act_out),
row_major,
reinterpret_cast<const int8_t*>(down_proj_weight.data<int8_t>()),
column_major,
reinterpret_cast<const DataType_*>(fc2_expert_scales),
nullptr,
reinterpret_cast<DataType_*>(permuted_input_ptr),
row_major,
tokens_expert_prefix_sum.data<int>(),
m_num_tile_ptr,
num_experts,
expanded_active_expert_rows,
hidden_size,
inter_dim / 2,
stream);
} else {
throw std::runtime_error("Unsupported gemm method: " + quant_method);
}
}
std::vector<paddle::Tensor> MoeExpertFFN(
paddle::Tensor& permute_input,
const paddle::Tensor& tokens_expert_prefix_sum,
const paddle::Tensor& ffn1_weight,
const paddle::Tensor& ffn2_weight,
const paddle::optional<paddle::Tensor>& ffn1_bias,
const paddle::optional<paddle::Tensor>& ffn1_scale,
const paddle::optional<paddle::Tensor>& ffn2_scale,
const paddle::Tensor& up_gate_proj_weight,
const paddle::Tensor& down_proj_weight,
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
const paddle::optional<paddle::Tensor>& down_proj_scale,
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
const std::string& quant_method) {
assert(quant_method == "weight_only_int8");
const auto input_type = permute_input.dtype();
@@ -122,31 +133,18 @@ std::vector<paddle::Tensor> MoeExpertFFN(
switch (input_type) {
case paddle::DataType::BFLOAT16:
McMoeFFNKernel<paddle::DataType::BFLOAT16,
maca_bfloat16,
int8_t,
maca_bfloat16>(permute_input,
tokens_expert_prefix_sum,
ffn1_weight,
ffn2_weight,
ffn1_bias,
ffn1_scale,
ffn2_scale,
quant_method);
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
tokens_expert_prefix_sum,
up_gate_proj_weight,
down_proj_weight,
up_gate_proj_bias,
up_gate_proj_scale,
down_proj_scale,
expert_idx_per_token,
quant_method);
break;
// case paddle::DataType::FLOAT16:
// MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
// tokens_expert_prefix_sum,
// ffn1_weight,
// ffn2_weight,
// ffn1_bias,
// ffn1_scale,
// ffn2_scale,
// quant_method,
// ffn_out);
// break;
default:
PD_THROW("Unsupported data type for MoeExpertFFN");
PD_THROW("Unsupported data type for MoeFFNhKernel");
}
return {permute_input};
}
@@ -154,33 +152,37 @@ std::vector<paddle::Tensor> MoeExpertFFN(
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
const std::vector<int64_t>& permute_input_shape,
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
const std::vector<int64_t>& ffn1_weight_shape,
const std::vector<int64_t>& ffn2_weight_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
const std::vector<int64_t>& up_gate_proj_weight_shape,
const std::vector<int64_t>& down_proj_weight_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
const std::string& quant_method) {
return {permute_input_shape};
}
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
const paddle::DataType& permute_input_dtype,
const paddle::DataType& tokens_expert_prefix_sum_dtype,
const paddle::DataType& ffn1_weight_dtype,
const paddle::DataType& ffn2_weight_dtype,
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
const paddle::DataType& up_gate_proj_weight_dtype,
const paddle::DataType& down_proj_weight_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
const paddle::optional<paddle::DataType>& expert_idx_per_token_dtype) {
return {permute_input_dtype};
}
PD_BUILD_OP(moe_expert_ffn)
.Inputs({"permute_input",
"tokens_expert_prefix_sum",
"ffn1_weight",
"ffn2_weight",
paddle::Optional("ffn1_bias"),
paddle::Optional("ffn1_scale"),
paddle::Optional("ffn2_scale")})
"up_gate_proj_weight",
"down_proj_weight",
paddle::Optional("up_gate_proj_bias"),
paddle::Optional("up_gate_proj_scale"),
paddle::Optional("down_proj_scale"),
paddle::Optional("expert_idx_per_token")})
.Outputs({"output_tensor"})
.Attrs({"quant_method:std::string"})
.SetKernelFn(PD_KERNEL(MoeExpertFFN))

View File

@@ -14,7 +14,6 @@
#pragma once
#include "fused_moe_helper.h"
#include "fused_moe_op.h"
#include "helper.h"
@@ -23,13 +22,14 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor,
const int num_rows,
const int hidden_size,
const int topk,
paddle::Tensor* output) {
using namespace phi;
typedef PDTraits<T> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -38,7 +38,7 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
finalize_moe_routing_kernelLauncher(
ffn_out.data<data_t>(),
output->data<data_t>(),
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
top_k_weight.data<float>(),
permute_indices_per_token.data<int32_t>(),
top_k_indices.data<int>(),
@@ -56,7 +56,7 @@ std::vector<paddle::Tensor> MoeExpertReduce(
const paddle::Tensor& top_k_weight,
const paddle::Tensor& permute_indices_per_token,
const paddle::Tensor& top_k_indices,
const paddle::optional<paddle::Tensor>& ffn2_bias,
const paddle::optional<paddle::Tensor>& down_proj_bias,
const bool norm_topk_prob,
const float routed_scaling_factor) {
const auto input_type = ffn_out.dtype();
@@ -69,7 +69,7 @@ std::vector<paddle::Tensor> MoeExpertReduce(
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
// Avoids invalid configuration argument when we launch the kernel.
if (ffn_out.dims()[0] == 0) return {output};
if (num_rows == 0) return {output};
switch (input_type) {
case paddle::DataType::BFLOAT16:
@@ -77,7 +77,7 @@ std::vector<paddle::Tensor> MoeExpertReduce(
top_k_weight,
permute_indices_per_token,
top_k_indices,
ffn2_bias,
down_proj_bias,
norm_topk_prob,
routed_scaling_factor,
num_rows,
@@ -85,21 +85,8 @@ std::vector<paddle::Tensor> MoeExpertReduce(
topk,
&output);
break;
// case paddle::DataType::FLOAT16:
// MoeReduceKernel<paddle::DataType::FLOAT16>(ffn_out,
// top_k_weight,
// permute_indices_per_token,
// top_k_indices,
// ffn2_bias,
// norm_topk_prob,
// routed_scaling_factor,
// num_rows,
// hidden_size,
// topk,
// &output);
// break;
default:
PD_THROW("Only support bf16 for MoeDispatchKernel");
PD_THROW("Unsupported data type for MoeReduceKernel");
}
return {output};
}
@@ -109,7 +96,7 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
const std::vector<int64_t>& top_k_weight_shape,
const std::vector<int64_t>& permute_indices_per_token_shape,
const std::vector<int64_t>& top_k_indices_shape,
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape) {
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape) {
const int topk = top_k_indices_shape[1];
std::vector<int64_t> fused_moe_out_shape = {ffn_out_shape[0] / topk,
ffn_out_shape[1]};
@@ -122,7 +109,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
const paddle::DataType& top_k_weight_dtype,
const paddle::DataType& permute_indices_per_token_dtype,
const paddle::DataType& top_k_indices_dtype,
const paddle::optional<paddle::DataType>& ffn2_bias_dtype) {
const paddle::optional<paddle::DataType>& down_proj_bias_dtype) {
return {ffn_out_dtype};
}
@@ -131,7 +118,7 @@ PD_BUILD_OP(moe_expert_reduce)
"top_k_weight",
"permute_indices_per_token",
"top_k_indices",
paddle::Optional("ffn2_bias")})
paddle::Optional("down_proj_bias")})
.Outputs({"output"})
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
.SetKernelFn(PD_KERNEL(MoeExpertReduce))

View File

@@ -627,11 +627,17 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
"gpu_ops/moe/tritonmoe_preprocess.cu",
"gpu_ops/moe/moe_topk_select.cu",
"gpu_ops/get_img_boundaries.cc",
"gpu_ops/remote_cache_kv_ipc.cc",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
"metax_ops/moe_dispatch.cu",
"metax_ops/moe_ffn.cu",
"metax_ops/moe_reduce.cu",
"metax_ops/fused_moe.cu",
"metax_ops/apply_rope.cu",
"metax_ops/apply_rope_qkv.cu",
"metax_ops/cache_kv_with_rope.cu",
]
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
@@ -657,6 +663,11 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
os.path.join(maca_path, "include"),
os.path.join(maca_path, "include/mcr"),
os.path.join(maca_path, "include/common"),
os.path.join(maca_path, "include/mcfft"),
os.path.join(maca_path, "include/mcrand"),
os.path.join(maca_path, "include/mcsparse"),
os.path.join(maca_path, "include/mcblas"),
os.path.join(maca_path, "include/mcsolver"),
],
),
)