mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Metax] refactor cutlass moe and optimize flash attention (#5361)
* [Metax] refactor moe and flash attention backend

---------

Co-authored-by: zhangchenyi_dl <16219492+zhangchenyidl@user.noreply.gitee.com>
This commit is contained in:
@@ -48,14 +48,15 @@ constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
|
||||
#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
||||
#endif
|
||||
|
||||
template <typename T> struct Pair {
|
||||
template <typename T>
|
||||
struct Pair {
|
||||
T value;
|
||||
int count;
|
||||
|
||||
__device__ Pair operator+(const Pair &other) const {
|
||||
__device__ Pair operator+(const Pair& other) const {
|
||||
return {value + other.value, count + other.count};
|
||||
}
|
||||
__device__ Pair &operator+=(const Pair &other) {
|
||||
__device__ Pair& operator+=(const Pair& other) {
|
||||
value += other.value;
|
||||
count += other.count;
|
||||
return *this;
|
||||
@@ -78,22 +79,25 @@ struct ValueCount {
|
||||
};
|
||||
|
||||
struct BoolDiffOp {
|
||||
__device__ __forceinline__ bool operator()(const bool &lhs,
|
||||
const bool &rhs) const {
|
||||
__device__ __forceinline__ bool operator()(const bool& lhs,
|
||||
const bool& rhs) const {
|
||||
return lhs != rhs;
|
||||
}
|
||||
};
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM>
|
||||
struct SamplingTempStorage {
|
||||
union {
|
||||
float deterministic_scan[BLOCK_THREADS / 32];
|
||||
typename BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
|
||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
|
||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
|
||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce_value_count;
|
||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce;
|
||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce_int;
|
||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::
|
||||
TempStorage reduce_value_count;
|
||||
typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
|
||||
} block_prim;
|
||||
struct {
|
||||
@@ -112,14 +116,17 @@ struct SamplingTempStorage {
|
||||
* algorithm. \note This implementation is slower than the cub::BlockScan, but
|
||||
* it is deterministic.
|
||||
*/
|
||||
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
|
||||
template <uint32_t VEC_SIZE,
|
||||
uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
|
||||
__device__ __forceinline__ void
|
||||
DeterministicInclusiveSum(const T *in_data, T *out_data,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM,
|
||||
REDUCE_ALGORITHM> *temp_storage) {
|
||||
T *smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
typename T>
|
||||
__device__ __forceinline__ void DeterministicInclusiveSum(
|
||||
const T* in_data,
|
||||
T* out_data,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
|
||||
temp_storage) {
|
||||
T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
|
||||
T thread_data[VEC_SIZE];
|
||||
T thread_sum = 0;
|
||||
#pragma unroll
|
||||
@@ -138,8 +145,8 @@ DeterministicInclusiveSum(const T *in_data, T *out_data,
|
||||
}
|
||||
}
|
||||
|
||||
T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum,
|
||||
threadIdx.x | 0xffffffff);
|
||||
T warp_sum = __shfl_sync(
|
||||
0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff);
|
||||
if (threadIdx.x % 32 == 31) {
|
||||
thread_exclusive_prefix_sum = 0;
|
||||
}
|
||||
@@ -197,12 +204,21 @@ DeterministicInclusiveSum(const T *in_data, T *out_data,
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
|
||||
template <uint32_t VEC_SIZE,
|
||||
uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
bool DETERMINISTIC,
|
||||
typename Predicate>
|
||||
__device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
uint32_t i, uint32_t d, Predicate pred, float u, vec_t<float, VEC_SIZE> prob_vec,
|
||||
uint32_t i,
|
||||
uint32_t d,
|
||||
Predicate pred,
|
||||
float u,
|
||||
vec_t<float, VEC_SIZE> prob_vec,
|
||||
float& aggregate,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>* temp_storage) {
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
|
||||
temp_storage) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
float prob_greater_than_threshold[VEC_SIZE];
|
||||
float inclusive_cdf[VEC_SIZE];
|
||||
@@ -212,14 +228,14 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
|
||||
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
|
||||
}
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
float aggregate_local =
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
|
||||
.Sum(prob_greater_than_threshold);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
float aggregate_local = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage->block_prim.reduce)
|
||||
.Sum(prob_greater_than_threshold);
|
||||
#else
|
||||
float aggregate_local =
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(prob_greater_than_threshold);
|
||||
float aggregate_local = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage->block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(prob_greater_than_threshold);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage->block_aggregate.value = aggregate_local;
|
||||
@@ -229,14 +245,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
|
||||
if (aggregate + aggregate_local > u) {
|
||||
if constexpr (DETERMINISTIC) {
|
||||
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
|
||||
DeterministicInclusiveSum<VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
SCAN_ALGORITHM,
|
||||
REDUCE_ALGORITHM>(
|
||||
prob_greater_than_threshold, inclusive_cdf, temp_storage);
|
||||
} else {
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(
|
||||
temp_storage->block_prim.scan)
|
||||
.InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
|
||||
#else
|
||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
|
||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(
|
||||
temp_storage->block_prim.scan)
|
||||
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
|
||||
#endif
|
||||
|
||||
@@ -250,28 +271,35 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
|
||||
bool greater_than_u_diff[VEC_SIZE];
|
||||
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||
#else
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||
#endif
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||
temp_storage->block_prim.adj_diff)
|
||||
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||
#else
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||
#else
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
||||
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||
#endif
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||
temp_storage->block_prim.adj_diff)
|
||||
.SubtractLeft<VEC_SIZE>(
|
||||
greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||
#endif
|
||||
#else
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||
temp_storage->block_prim.adj_diff)
|
||||
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||
#else
|
||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||
temp_storage->block_prim.adj_diff)
|
||||
.FlagHeads<VEC_SIZE>(
|
||||
greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||
#endif
|
||||
#endif
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
if (greater_than_u_diff[j]) {
|
||||
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
|
||||
atomicMin(&(temp_storage->sampled_id),
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -287,9 +315,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
valid_index[j] = -1;
|
||||
}
|
||||
}
|
||||
int max_valid_index =
|
||||
BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce_int)
|
||||
.Reduce(valid_index, cub::Max());
|
||||
int max_valid_index = BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage->block_prim.reduce_int)
|
||||
.Reduce(valid_index, cub::Max());
|
||||
if (tx == 0 && max_valid_index != -1) {
|
||||
temp_storage->last_valid_id = max_valid_index;
|
||||
}
|
||||
@@ -297,15 +325,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
||||
aggregate += aggregate_local;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
|
||||
typename DType, typename IdType>
|
||||
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
float* top_p_arr, IdType* top_k_arr,
|
||||
uint32_t d, uint64_t philox_seed,
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void TopKTopPSamplingFromProbKernel(DType* probs,
|
||||
IdType* output,
|
||||
float* top_p_arr,
|
||||
IdType* top_k_arr,
|
||||
uint32_t d,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset) {
|
||||
const uint32_t batch_size = gridDim.x;
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
@@ -315,12 +347,12 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
||||
const float p = top_p_arr[row_idx];
|
||||
|
||||
extern __shared__ __align__(
|
||||
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
extern __shared__ __align__(alignof(
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
uint8_t smem_sampling[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
auto& temp_storage = reinterpret_cast<
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
|
||||
vec_t<float, VEC_SIZE> probs_vec;
|
||||
float aggregate;
|
||||
@@ -336,12 +368,22 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
|
||||
DeviceSamplingFromProb<VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
SCAN_ALGORITHM,
|
||||
REDUCE_ALGORITHM,
|
||||
DETERMINISTIC>(
|
||||
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
|
||||
i,
|
||||
d,
|
||||
[&](float x) { return x > low; },
|
||||
u,
|
||||
probs_vec,
|
||||
aggregate,
|
||||
&temp_storage);
|
||||
if (aggregate > u) {
|
||||
break;
|
||||
}
|
||||
@@ -362,28 +404,29 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
ValueCount<float> probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
probs_gt_pivot_0[j] = {
|
||||
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
probs_gt_pivot_1[j] = {
|
||||
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
probs_gt_pivot_0[j] = {(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_0 &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
probs_gt_pivot_1[j] = {(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_1 &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_0);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_0);
|
||||
#else
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
|
||||
@@ -391,14 +434,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
__syncthreads();
|
||||
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_1);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_1);
|
||||
#else
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
|
||||
@@ -427,14 +470,19 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC, typename DType, typename IdType>
|
||||
__global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
float* top_p_arr, uint32_t d,
|
||||
uint64_t philox_seed, uint64_t philox_offset) {
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void TopPSamplingFromProbKernel(DType* probs,
|
||||
IdType* output,
|
||||
float* top_p_arr,
|
||||
uint32_t d,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset) {
|
||||
const uint32_t batch_size = gridDim.x;
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
curandStatePhilox4_32_10_t state;
|
||||
@@ -442,12 +490,12 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
const uint32_t row_idx = bx;
|
||||
float top_p = top_p_arr[row_idx];
|
||||
|
||||
extern __shared__ __align__(
|
||||
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
extern __shared__ __align__(alignof(
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
uint8_t smem_sampling[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
auto& temp_storage = reinterpret_cast<
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
|
||||
vec_t<float, VEC_SIZE> probs_vec;
|
||||
float aggregate;
|
||||
@@ -463,12 +511,22 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
|
||||
DeviceSamplingFromProb<VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
SCAN_ALGORITHM,
|
||||
REDUCE_ALGORITHM,
|
||||
DETERMINISTIC>(
|
||||
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
|
||||
i,
|
||||
d,
|
||||
[&](float x) { return x > low; },
|
||||
u,
|
||||
probs_vec,
|
||||
aggregate,
|
||||
&temp_storage);
|
||||
if (aggregate > u) {
|
||||
break;
|
||||
}
|
||||
@@ -489,7 +547,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
|
||||
@@ -499,12 +558,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum(probs_gt_pivot_0);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum(probs_gt_pivot_0);
|
||||
#else
|
||||
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
|
||||
@@ -512,12 +573,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
__syncthreads();
|
||||
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum(probs_gt_pivot_1);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum(probs_gt_pivot_1);
|
||||
#else
|
||||
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||
#endif
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
|
||||
@@ -546,9 +609,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
template <uint32_t VEC_SIZE,
|
||||
uint32_t BLOCK_THREADS,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
typename TempStorage>
|
||||
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
|
||||
__device__ __forceinline__ float GetMaxValue(float* in_data,
|
||||
uint32_t row_idx,
|
||||
uint32_t d,
|
||||
TempStorage& temp_storage) {
|
||||
const uint32_t tx = threadIdx.x;
|
||||
vec_t<float, VEC_SIZE> in_data_vec;
|
||||
@@ -557,21 +624,24 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
in_data_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
in_data_vec.cast_load(in_data + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
float in_data_[VEC_SIZE];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
in_data_[j] = in_data_vec[j];
|
||||
}
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
max_val = max(
|
||||
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||
.Reduce(in_data_, cub::Max()));
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
max_val = max(max_val,
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce(in_data_, cub::Max()));
|
||||
#else
|
||||
max_val = max(
|
||||
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
||||
max_val = max(max_val,
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
||||
#endif
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -585,10 +655,12 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
|
||||
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
|
||||
struct RenormTempStorage {
|
||||
union {
|
||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
|
||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
|
||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce_value_count;
|
||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce;
|
||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||
reduce_int;
|
||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::
|
||||
TempStorage reduce_value_count;
|
||||
} block_prim;
|
||||
struct {
|
||||
float max_val;
|
||||
@@ -607,24 +679,33 @@ struct RenormTempStorage {
|
||||
};
|
||||
};
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
|
||||
typename DType,typename IdType>
|
||||
__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
|
||||
DType* renormed_prob,uint32_t d) {
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void MinPSamplingFromProbKernel(DType* probs,
|
||||
const float* min_p_arr,
|
||||
DType* renormed_prob,
|
||||
uint32_t d) {
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
|
||||
const uint32_t row_idx = bx;
|
||||
|
||||
extern __shared__ __align__(
|
||||
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
extern __shared__ __align__(alignof(
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
uint8_t smem_sampling[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
auto& temp_storage = reinterpret_cast<
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||
smem_sampling);
|
||||
|
||||
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
|
||||
float max_val = GetMaxValue<
|
||||
VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
REDUCE_ALGORITHM,
|
||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
|
||||
probs, row_idx, d, temp_storage);
|
||||
float pivot = max_val * p;
|
||||
|
||||
@@ -633,7 +714,8 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@@ -641,42 +723,51 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
|
||||
probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
|
||||
}
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
probs_vec.store(renormed_prob + row_idx * d +
|
||||
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
|
||||
typename DType, typename IdType>
|
||||
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
|
||||
template <uint32_t BLOCK_THREADS,
|
||||
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||
uint32_t VEC_SIZE,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
__global__ void TopKRenormProbKernel(DType* probs,
|
||||
DType* renormed_prob,
|
||||
IdType* top_k_arr,
|
||||
uint32_t d) {
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
const uint32_t row_idx = bx;
|
||||
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
double pivot = std::numeric_limits<float>::infinity(), normalizer = 1;
|
||||
#else
|
||||
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
|
||||
#endif
|
||||
vec_t<float, VEC_SIZE> probs_vec;
|
||||
if (k < d) {
|
||||
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
|
||||
uint8_t smem_renorm[];
|
||||
extern __shared__ __align__(alignof(
|
||||
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>)) uint8_t smem_renorm[];
|
||||
auto& temp_storage =
|
||||
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
|
||||
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(
|
||||
smem_renorm);
|
||||
temp_storage.max_val = 0;
|
||||
|
||||
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
|
||||
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
|
||||
probs, row_idx, d, temp_storage);
|
||||
float max_val =
|
||||
GetMaxValue<VEC_SIZE,
|
||||
BLOCK_THREADS,
|
||||
REDUCE_ALGORITHM,
|
||||
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
|
||||
probs, row_idx, d, temp_storage);
|
||||
|
||||
double low = 0, high = max_val;
|
||||
float min_gt_low, max_le_high;
|
||||
float sum_low = 1;
|
||||
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
|
||||
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
|
||||
// loop invariant:
|
||||
// min_gt_low = min{p \in probs | p > low},
|
||||
// max_le_high = max{p \in probs | p <= high}
|
||||
// loop invariant:
|
||||
// - f(low) >= k, f(high) < k
|
||||
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
|
||||
// stopping condition: min_gt_low == max_le_high
|
||||
@@ -692,55 +783,65 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d +
|
||||
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
}
|
||||
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
|
||||
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE],
|
||||
probs_gt_pivot_1_pair[VEC_SIZE];
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
probs_gt_pivot_0_pair[j] = {
|
||||
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
(probs_vec[j] > pivot_0 &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
probs_gt_pivot_1_pair[j] = {
|
||||
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
||||
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
(probs_vec[j] > pivot_1 &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||
|
||||
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||
if (probs_vec[j] > low &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||
min_gt_low = min(min_gt_low, probs_vec[j]);
|
||||
}
|
||||
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||
if (probs_vec[j] <= high &&
|
||||
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||
max_le_high = max(max_le_high, probs_vec[j]);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_0_pair);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_0_pair);
|
||||
#else
|
||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
|
||||
aggregate_gt_pivot_0 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
|
||||
#endif
|
||||
__syncthreads();
|
||||
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_1_pair);
|
||||
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum(probs_gt_pivot_1_pair);
|
||||
#else
|
||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
|
||||
aggregate_gt_pivot_1 +=
|
||||
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce_value_count)
|
||||
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
|
||||
#endif
|
||||
__syncthreads();
|
||||
}
|
||||
min_gt_low =
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||
.Reduce(min_gt_low, cub::Min());
|
||||
min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce(min_gt_low, cub::Min());
|
||||
__syncthreads();
|
||||
max_le_high =
|
||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
||||
.Reduce(max_le_high, cub::Max());
|
||||
max_le_high = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||
temp_storage.block_prim.reduce)
|
||||
.Reduce(max_le_high, cub::Max());
|
||||
if (tx == 0) {
|
||||
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
|
||||
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
|
||||
@@ -774,23 +875,29 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
|
||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||
probs_vec.fill(0);
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
|
||||
tx * VEC_SIZE);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
|
||||
}
|
||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
probs_vec.store(renormed_prob + row_idx * d +
|
||||
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename IdType>
|
||||
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
|
||||
uint32_t batch_size, const T *top_p_val,
|
||||
uint32_t d, bool deterministic,
|
||||
uint64_t philox_seed, uint64_t philox_offset,
|
||||
cudaError_t TopPSamplingFromProb(T* probs,
|
||||
IdType* output,
|
||||
uint32_t batch_size,
|
||||
const T* top_p_val,
|
||||
uint32_t d,
|
||||
bool deterministic,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset,
|
||||
cudaStream_t stream = 0) {
|
||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||
@@ -799,99 +906,139 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
|
||||
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
void* args[] = {&probs, &output, &top_p_val,
|
||||
&d, &philox_seed, &philox_offset};
|
||||
void* args[] = {
|
||||
&probs, &output, &top_p_val, &d, &philox_seed, &philox_offset};
|
||||
|
||||
DISPATCH_ALIGNED_VEC_SIZE(
|
||||
vec_size, VEC_SIZE,
|
||||
vec_size,
|
||||
VEC_SIZE,
|
||||
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel =
|
||||
TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
|
||||
VEC_SIZE, DETERMINISTIC, T, IdType>;
|
||||
auto kernel = TopPSamplingFromProbKernel<BLOCK_THREADS,
|
||||
SCAN_ALGO,
|
||||
REDUCE_ALGO,
|
||||
VEC_SIZE,
|
||||
DETERMINISTIC,
|
||||
T,
|
||||
IdType>;
|
||||
CUDA_CALL(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
|
||||
smem_size, stream));
|
||||
CUDA_CALL(cudaLaunchKernel(
|
||||
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
})});
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
template <typename T,typename IdType>
|
||||
cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,T *renormed_prob,
|
||||
template <typename T, typename IdType>
|
||||
cudaError_t MinPSamplingFromProb(T* probs,
|
||||
const T* min_p_arr,
|
||||
T* renormed_prob,
|
||||
uint32_t batch_size,
|
||||
uint32_t d, bool deterministic,
|
||||
cudaStream_t stream = 0){
|
||||
uint32_t d,
|
||||
bool deterministic,
|
||||
cudaStream_t stream = 0) {
|
||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||
|
||||
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
const uint32_t smem_size =
|
||||
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
void* args[] = {&probs, &min_p_arr,&renormed_prob,&d};
|
||||
void* args[] = {&probs, &min_p_arr, &renormed_prob, &d};
|
||||
DISPATCH_ALIGNED_VEC_SIZE(
|
||||
vec_size, VEC_SIZE,
|
||||
vec_size,
|
||||
VEC_SIZE,
|
||||
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel =
|
||||
MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
|
||||
VEC_SIZE, DETERMINISTIC, T,IdType>;
|
||||
auto kernel = MinPSamplingFromProbKernel<BLOCK_THREADS,
|
||||
SCAN_ALGO,
|
||||
REDUCE_ALGO,
|
||||
VEC_SIZE,
|
||||
DETERMINISTIC,
|
||||
T,
|
||||
IdType>;
|
||||
CUDA_CALL(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
|
||||
smem_size, stream));
|
||||
CUDA_CALL(cudaLaunchKernel(
|
||||
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
})});
|
||||
return cudaSuccess;
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename IdType>
|
||||
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
|
||||
uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
|
||||
uint32_t d, bool deterministic,
|
||||
uint64_t philox_seed, uint64_t philox_offset,
|
||||
cudaError_t TopKTopPSamplingFromProb(T* probs,
|
||||
IdType* output,
|
||||
uint32_t batch_size,
|
||||
const T* top_p_val,
|
||||
const IdType* top_k_val,
|
||||
uint32_t d,
|
||||
bool deterministic,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset,
|
||||
cudaStream_t stream = 0) {
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||
|
||||
auto compute_capacity = GetCudaComputeCapability();
|
||||
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
|
||||
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
const uint32_t smem_size =
|
||||
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
void* args[] = {&probs, &output, &top_p_val, &top_k_val,
|
||||
&d, &philox_seed, &philox_offset};
|
||||
void* args[] = {&probs,
|
||||
&output,
|
||||
&top_p_val,
|
||||
&top_k_val,
|
||||
&d,
|
||||
&philox_seed,
|
||||
&philox_offset};
|
||||
|
||||
DISPATCH_ALIGNED_VEC_SIZE(
|
||||
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
|
||||
VEC_SIZE, DETERMINISTIC, T, IdType>;
|
||||
CUDA_CALL(
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(
|
||||
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
vec_size,
|
||||
VEC_SIZE,
|
||||
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS,
|
||||
SCAN_ALGO,
|
||||
REDUCE_ALGO,
|
||||
VEC_SIZE,
|
||||
DETERMINISTIC,
|
||||
T,
|
||||
IdType>;
|
||||
CUDA_CALL(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel(
|
||||
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
})});
|
||||
return cudaSuccess;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DType, typename IdType>
|
||||
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
|
||||
uint32_t batch_size, uint32_t d,
|
||||
cudaError_t TopKRenormProb(DType* probs,
|
||||
DType* renormed_prob,
|
||||
IdType* top_k_arr,
|
||||
uint32_t batch_size,
|
||||
uint32_t d,
|
||||
cudaStream_t stream = 0) {
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
||||
|
||||
auto compute_capacity = GetCudaComputeCapability();
|
||||
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
|
||||
const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
|
||||
const uint32_t smem_size =
|
||||
sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
|
||||
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
|
||||
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
|
||||
CUDA_CALL(
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
auto kernel = TopKRenormProbKernel<BLOCK_THREADS,
|
||||
REDUCE_ALGO,
|
||||
VEC_SIZE,
|
||||
DType,
|
||||
IdType>;
|
||||
CUDA_CALL(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
CUDA_CALL(cudaLaunchKernel(
|
||||
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
});
|
||||
return cudaSuccess;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace sampling
|
||||
} // namespace sampling
|
||||
|
||||
@@ -23,221 +23,235 @@
|
||||
#include <cuda_device_runtime_api.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <curand.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <curand_philox4x32_x.h>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include <curand.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <curand_philox4x32_x.h>
|
||||
|
||||
/******************* utils *******************/
|
||||
#define STR_HELPER(x) #x
|
||||
#define STR(x) STR_HELPER(x)
|
||||
|
||||
#ifndef NDEBUG
|
||||
#define CUDA_CALL(func, ...) \
|
||||
{ \
|
||||
cudaError_t e = (func); \
|
||||
if (e != cudaSuccess) { \
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
|
||||
<< ") " << __FILE__ << ": line " << __LINE__ \
|
||||
<< " at function " << STR(func) << std::endl; \
|
||||
return e; \
|
||||
} \
|
||||
#define CUDA_CALL(func, ...) \
|
||||
{ \
|
||||
cudaError_t e = (func); \
|
||||
if (e != cudaSuccess) { \
|
||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
|
||||
<< ") " << __FILE__ << ": line " << __LINE__ \
|
||||
<< " at function " << STR(func) << std::endl; \
|
||||
return e; \
|
||||
} \
|
||||
}
|
||||
#else
|
||||
#define CUDA_CALL(func, ...) \
|
||||
{ \
|
||||
cudaError_t e = (func); \
|
||||
if (e != cudaSuccess) { \
|
||||
return e; \
|
||||
} \
|
||||
#define CUDA_CALL(func, ...) \
|
||||
{ \
|
||||
cudaError_t e = (func); \
|
||||
if (e != cudaSuccess) { \
|
||||
return e; \
|
||||
} \
|
||||
}
|
||||
#endif
|
||||
|
||||
#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
|
||||
if (deterministic) { \
|
||||
constexpr bool DETERMINISTIC = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
constexpr bool DETERMINISTIC = false; \
|
||||
__VA_ARGS__ \
|
||||
#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
|
||||
if (deterministic) { \
|
||||
constexpr bool DETERMINISTIC = true; \
|
||||
__VA_ARGS__ \
|
||||
} else { \
|
||||
constexpr bool DETERMINISTIC = false; \
|
||||
__VA_ARGS__ \
|
||||
}
|
||||
|
||||
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
|
||||
switch (aligned_vec_size) { \
|
||||
case 16: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 8: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 4: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 4; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 2: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 2; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 1: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 1; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
|
||||
throw std::invalid_argument(err_msg.str()); \
|
||||
} \
|
||||
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
|
||||
switch (aligned_vec_size) { \
|
||||
case 16: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 8: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 8; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 4: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 4; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 2: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 2; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 1: { \
|
||||
constexpr size_t ALIGNED_VEC_SIZE = 1; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
|
||||
throw std::invalid_argument(err_msg.str()); \
|
||||
} \
|
||||
}
|
||||
|
||||
/******************* vec_t<float> *******************/
|
||||
#define SAMPLING_INLINE inline __attribute__((always_inline)) __device__
|
||||
template <typename float_t, size_t vec_size> struct vec_t {
|
||||
SAMPLING_INLINE float_t &operator[](size_t i);
|
||||
SAMPLING_INLINE const float_t &operator[](size_t i) const;
|
||||
template <typename float_t, size_t vec_size>
|
||||
struct vec_t {
|
||||
SAMPLING_INLINE float_t& operator[](size_t i);
|
||||
SAMPLING_INLINE const float_t& operator[](size_t i) const;
|
||||
SAMPLING_INLINE void fill(float_t val);
|
||||
SAMPLING_INLINE void load(const float_t *ptr);
|
||||
SAMPLING_INLINE void store(float_t *ptr) const;
|
||||
SAMPLING_INLINE void load(const float_t* ptr);
|
||||
SAMPLING_INLINE void store(float_t* ptr) const;
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src);
|
||||
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr);
|
||||
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const;
|
||||
SAMPLING_INLINE static void memcpy(float_t *dst, const float_t *src);
|
||||
SAMPLING_INLINE float_t *ptr();
|
||||
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size>& src);
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_load(const T* ptr);
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_store(T* ptr) const;
|
||||
SAMPLING_INLINE static void memcpy(float_t* dst, const float_t* src);
|
||||
SAMPLING_INLINE float_t* ptr();
|
||||
};
|
||||
|
||||
// float x 1
|
||||
template <> struct vec_t<float, 1> {
|
||||
template <>
|
||||
struct vec_t<float, 1> {
|
||||
float data;
|
||||
|
||||
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
|
||||
SAMPLING_INLINE const float &operator[](size_t i) const {
|
||||
return ((const float *)(&data))[i];
|
||||
SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
|
||||
SAMPLING_INLINE const float& operator[](size_t i) const {
|
||||
return ((const float*)(&data))[i];
|
||||
}
|
||||
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
|
||||
SAMPLING_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
|
||||
SAMPLING_INLINE void fill(float val);
|
||||
SAMPLING_INLINE void load(const float *ptr);
|
||||
SAMPLING_INLINE void store(float *ptr) const;
|
||||
template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 1> &src) {
|
||||
SAMPLING_INLINE void load(const float* ptr);
|
||||
SAMPLING_INLINE void store(float* ptr) const;
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_from(const vec_t<T, 1>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
SAMPLING_INLINE static void memcpy(float *dst, const float *src);
|
||||
SAMPLING_INLINE static void memcpy(float* dst, const float* src);
|
||||
};
|
||||
|
||||
SAMPLING_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
|
||||
|
||||
SAMPLING_INLINE void vec_t<float, 1>::load(const float *ptr) { data = *ptr; }
|
||||
SAMPLING_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }
|
||||
|
||||
SAMPLING_INLINE void vec_t<float, 1>::store(float *ptr) const { *ptr = data; }
|
||||
SAMPLING_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }
|
||||
|
||||
SAMPLING_INLINE void vec_t<float, 1>::memcpy(float *dst, const float *src) {
|
||||
SAMPLING_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
// float x 2
|
||||
template <> struct vec_t<float, 2> {
|
||||
template <>
|
||||
struct vec_t<float, 2> {
|
||||
float2 data;
|
||||
|
||||
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
|
||||
SAMPLING_INLINE const float &operator[](size_t i) const {
|
||||
return ((const float *)(&data))[i];
|
||||
SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
|
||||
SAMPLING_INLINE const float& operator[](size_t i) const {
|
||||
return ((const float*)(&data))[i];
|
||||
}
|
||||
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
|
||||
SAMPLING_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
|
||||
SAMPLING_INLINE void fill(float val);
|
||||
SAMPLING_INLINE void load(const float *ptr);
|
||||
SAMPLING_INLINE void store(float *ptr) const;
|
||||
template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 2> &src) {
|
||||
SAMPLING_INLINE void load(const float* ptr);
|
||||
SAMPLING_INLINE void store(float* ptr) const;
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_from(const vec_t<T, 2>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
SAMPLING_INLINE static void memcpy(float *dst, const float *src);
|
||||
SAMPLING_INLINE static void memcpy(float* dst, const float* src);
|
||||
};
|
||||
|
||||
SAMPLING_INLINE void vec_t<float, 2>::fill(float val) {
|
||||
data = make_float2(val, val);
|
||||
}
|
||||
|
||||
SAMPLING_INLINE void vec_t<float, 2>::load(const float *ptr) {
|
||||
data = *((float2 *)ptr);
|
||||
SAMPLING_INLINE void vec_t<float, 2>::load(const float* ptr) {
|
||||
data = *((float2*)ptr);
|
||||
}
|
||||
|
||||
SAMPLING_INLINE void vec_t<float, 2>::store(float *ptr) const {
|
||||
*((float2 *)ptr) = data;
|
||||
SAMPLING_INLINE void vec_t<float, 2>::store(float* ptr) const {
|
||||
*((float2*)ptr) = data;
|
||||
}
|
||||
|
||||
SAMPLING_INLINE void vec_t<float, 2>::memcpy(float *dst, const float *src) {
|
||||
*((float2 *)dst) = *((float2 *)src);
|
||||
SAMPLING_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
|
||||
*((float2*)dst) = *((float2*)src);
|
||||
}
|
||||
|
||||
// float x 4 or more
|
||||
template <size_t vec_size> struct vec_t<float, vec_size> {
|
||||
template <size_t vec_size>
|
||||
struct vec_t<float, vec_size> {
|
||||
float4 data[vec_size / 4];
|
||||
|
||||
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; }
|
||||
SAMPLING_INLINE const float &operator[](size_t i) const {
|
||||
return ((const float *)(data))[i];
|
||||
SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
|
||||
SAMPLING_INLINE const float& operator[](size_t i) const {
|
||||
return ((const float*)(data))[i];
|
||||
}
|
||||
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
|
||||
SAMPLING_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
|
||||
SAMPLING_INLINE void fill(float val) {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||
data[i] = make_float4(val, val, val, val);
|
||||
}
|
||||
}
|
||||
SAMPLING_INLINE void load(const float *ptr) {
|
||||
SAMPLING_INLINE void load(const float* ptr) {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||
data[i] = ((float4 *)ptr)[i];
|
||||
data[i] = ((float4*)ptr)[i];
|
||||
}
|
||||
}
|
||||
SAMPLING_INLINE void store(float *ptr) const {
|
||||
SAMPLING_INLINE void store(float* ptr) const {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||
((float4 *)ptr)[i] = data[i];
|
||||
((float4*)ptr)[i] = data[i];
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src) {
|
||||
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
||||
cast_from_impl(*this, src);
|
||||
}
|
||||
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_load(const T* ptr) {
|
||||
cast_load_impl(*this, ptr);
|
||||
}
|
||||
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
|
||||
template <typename T>
|
||||
SAMPLING_INLINE void cast_store(T* ptr) const {
|
||||
cast_store_impl(ptr, *this);
|
||||
}
|
||||
SAMPLING_INLINE static void memcpy(float *dst, const float *src) {
|
||||
SAMPLING_INLINE static void memcpy(float* dst, const float* src) {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||
((float4 *)dst)[i] = ((float4 *)src)[i];
|
||||
((float4*)dst)[i] = ((float4*)src)[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
|
||||
SAMPLING_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
|
||||
const src_float_t* src_ptr) {
|
||||
const src_float_t* src_ptr) {
|
||||
if constexpr (std::is_same_v<src_float_t, tgt_float_t>) {
|
||||
dst.load(src_ptr);
|
||||
} else {
|
||||
@@ -260,11 +274,16 @@ inline std::pair<int, int> GetCudaComputeCapability() {
|
||||
__forceinline__ __device__ float ptx_rcp(float x) {
|
||||
#ifdef PADDLE_WITH_COREX
|
||||
return __ivcorex_rcpf(x);
|
||||
#else
|
||||
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
||||
return __frcp_rn(x);
|
||||
#else
|
||||
float y;
|
||||
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
|
||||
return y;
|
||||
#endif
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T1, typename T2>
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <paddle/extension.h>
|
||||
#include <algorithm>
|
||||
#include "helper.h"
|
||||
|
||||
#define THREADS_PER_BLOCK 128
|
||||
|
||||
template <typename T>
|
||||
struct Converter;
|
||||
|
||||
template <>
|
||||
struct Converter<__half> {
|
||||
// __half -> float
|
||||
__device__ static float to_float(__half val) { return __half2float(val); }
|
||||
// float -> __half
|
||||
__device__ static __half from_float(float val) {
|
||||
return __float2half_rn(val);
|
||||
}
|
||||
// int -> __half
|
||||
__device__ static __half from_int(float val) { return __int2half_rn(val); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Converter<__nv_bfloat16> {
|
||||
// __nv_bfloat16 -> float
|
||||
__device__ static float to_float(__nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
// float -> __nv_bfloat16
|
||||
__device__ static __nv_bfloat16 from_float(float val) {
|
||||
return __float2bfloat16_rn(val);
|
||||
}
|
||||
// int -> __nv_bfloat16
|
||||
__device__ static __nv_bfloat16 from_int(int val) {
|
||||
return __int2bfloat16_rn(val);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ void RotateQKVec4(const T* qk_ptr,
|
||||
const T* rot_cos_ptr,
|
||||
const T* rot_sin_ptr,
|
||||
const int head_num,
|
||||
const int base_idx,
|
||||
const int rot_base_idx,
|
||||
T* out) {
|
||||
using VecT = AlignedVector<T, 4>;
|
||||
|
||||
VecT qk_vec;
|
||||
Load(qk_ptr + base_idx, &qk_vec);
|
||||
VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]};
|
||||
VecT cos_vec, sin_vec;
|
||||
Load(rot_cos_ptr + rot_base_idx, &cos_vec);
|
||||
Load(rot_sin_ptr + rot_base_idx, &sin_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
*(out + base_idx + i) =
|
||||
qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void RotateQKVec4(const T* qk_ptr,
|
||||
const float* rot_cos_ptr,
|
||||
const float* rot_sin_ptr,
|
||||
const int head_num,
|
||||
const int base_idx,
|
||||
const int rot_base_idx,
|
||||
T* out) {
|
||||
using VecT = AlignedVector<T, 4>;
|
||||
using VecF = AlignedVector<float, 4>;
|
||||
auto to_float = [] __device__(T val) -> float {
|
||||
return Converter<T>::to_float(val);
|
||||
};
|
||||
auto from_float = [] __device__(float val) -> T {
|
||||
return Converter<T>::from_float(val);
|
||||
};
|
||||
|
||||
VecT qk_vec;
|
||||
Load(qk_ptr + base_idx, &qk_vec);
|
||||
VecF rot_half_vec = {-to_float(qk_vec[1]),
|
||||
to_float(qk_vec[0]),
|
||||
-to_float(qk_vec[3]),
|
||||
to_float(qk_vec[2])};
|
||||
VecF cos_vec, sin_vec;
|
||||
Load(rot_cos_ptr + rot_base_idx, &cos_vec);
|
||||
Load(rot_sin_ptr + rot_base_idx, &sin_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
*(out + base_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] +
|
||||
rot_half_vec[i] * sin_vec[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// qk and rope have a same type
|
||||
template <typename T>
|
||||
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
|
||||
const T* k,
|
||||
const T* rot_cos,
|
||||
const T* rot_sin,
|
||||
const int q_num_elements,
|
||||
const int k_num_elements,
|
||||
const int q_head_num,
|
||||
const int k_head_num,
|
||||
const int head_dim,
|
||||
T* q_out,
|
||||
T* k_out) {
|
||||
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
|
||||
int head_dim_idx = idx % head_dim;
|
||||
|
||||
if (idx < q_num_elements) {
|
||||
int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
|
||||
RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
|
||||
}
|
||||
|
||||
if (idx < k_num_elements) {
|
||||
int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
|
||||
RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
|
||||
}
|
||||
}
|
||||
|
||||
// rope dtype is float32
|
||||
template <typename T>
|
||||
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
|
||||
const T* k,
|
||||
const float* rot_cos,
|
||||
const float* rot_sin,
|
||||
const int q_num_elements,
|
||||
const int k_num_elements,
|
||||
const int q_head_num,
|
||||
const int k_head_num,
|
||||
const int head_dim,
|
||||
T* q_out,
|
||||
T* k_out) {
|
||||
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
|
||||
int head_dim_idx = idx % head_dim;
|
||||
|
||||
if (idx < q_num_elements) {
|
||||
int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
|
||||
RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
|
||||
}
|
||||
|
||||
if (idx < k_num_elements) {
|
||||
int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
|
||||
RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
|
||||
}
|
||||
}
|
||||
|
||||
// Launches DispatchApplyRopeVec4Kernel for q/k tensors of Paddle dtype D.
//
// The grid is sized so that each thread covers 4 elements of the larger of
// q and k. The rotary tables may either share the q/k dtype or be float32;
// any other combination raises via PD_THROW.
template <paddle::DataType D>
void ApplyRopeKernel(const paddle::Tensor& q,
                     const paddle::Tensor& k,
                     const paddle::Tensor& rot_cos,
                     const paddle::Tensor& rot_sin,
                     paddle::Tensor& q_out,
                     paddle::Tensor& k_out) {
  using Traits = PDTraits<D>;
  using DataType_ = typename Traits::DataType;
  using data_t = typename Traits::data_t;

  const auto q_shape = q.shape();
  const auto k_shape = k.shape();
  const auto rank = q_shape.size();
  // The head count is read from the second-to-last axis, head_dim from the
  // last one. Note that q's rank is also used to index k's shape.
  const auto q_heads = q_shape[rank - 2];
  const auto k_heads = k_shape[rank - 2];
  const auto head_dim = q_shape.back();
  const auto q_total = q.numel();
  const auto k_total = k.numel();

  // Ceil-divide by the number of elements one block processes.
  int grid =
      (std::max(q_total, k_total) + (THREADS_PER_BLOCK * 4) - 1) /
      (THREADS_PER_BLOCK * 4);
  auto stream = q.stream();

  const bool rope_matches_qk = q.dtype() == rot_cos.dtype();
  const bool rope_is_fp32 = rot_cos.dtype() == paddle::DataType::FLOAT32;
  if (!rope_matches_qk && !rope_is_fp32) {
    PD_THROW("Unsupported qk dtype and rope dtype.");
  }

  const DataType_* q_ptr = reinterpret_cast<const DataType_*>(q.data<data_t>());
  const DataType_* k_ptr = reinterpret_cast<const DataType_*>(k.data<data_t>());
  DataType_* q_out_ptr = reinterpret_cast<DataType_*>(q_out.data<data_t>());
  DataType_* k_out_ptr = reinterpret_cast<DataType_*>(k_out.data<data_t>());

  if (rope_matches_qk) {
    // cos/sin share the q/k element type.
    DispatchApplyRopeVec4Kernel<DataType_>
        <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
            q_ptr,
            k_ptr,
            reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()),
            q_total,
            k_total,
            q_heads,
            k_heads,
            head_dim,
            q_out_ptr,
            k_out_ptr);
  } else {
    // cos/sin are float32.
    DispatchApplyRopeVec4Kernel<DataType_>
        <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
            q_ptr,
            k_ptr,
            reinterpret_cast<const float*>(rot_cos.data<float>()),
            reinterpret_cast<const float*>(rot_sin.data<float>()),
            q_total,
            k_total,
            q_heads,
            k_heads,
            head_dim,
            q_out_ptr,
            k_out_ptr);
  }
}
|
||||
|
||||
// Host entry point of the `apply_rope` custom op.
//
// Allocates q_out / k_out with the same shape and dtype as q / k, validates
// the inputs and dispatches on the q dtype (BF16 or F16 only).
// Returns {q_out, k_out}; when q or k is empty the freshly allocated outputs
// are returned without launching a kernel.
//
// Raises (via PADDLE_ENFORCE_EQ / PD_THROW) when head_dim is not a multiple
// of 4, when cos has a different rank or last dimension than q, or when the
// q dtype is unsupported.
std::vector<paddle::Tensor> ApplyRope(const paddle::Tensor& q,
                                      const paddle::Tensor& k,
                                      const paddle::Tensor& rot_cos,
                                      const paddle::Tensor& rot_sin) {
  auto q_shape = q.shape();
  auto cos_shape = rot_cos.shape();

  auto q_out = paddle::empty_like(q);
  auto k_out = paddle::empty_like(k);

  if (q.numel() == 0 || k.numel() == 0) {
    return {q_out, k_out};
  }

  PADDLE_ENFORCE_EQ(
      q_shape.back() % 2,
      0,
      "The last dimension (head_dim) of qk must be an even number "
      "for RoPE, but got %d",
      q_shape.back());
  // The kernel assigns 4 consecutive elements to each thread. If head_dim is
  // not a multiple of 4, one such group straddles two heads and the matching
  // cos/sin read runs past the current rotary row.
  PADDLE_ENFORCE_EQ(
      q_shape.back() % 4,
      0,
      "The last dimension (head_dim) of qk must be a multiple of 4 "
      "for the vectorized RoPE kernel, but got %d",
      q_shape.back());
  PADDLE_ENFORCE_EQ(q_shape.size(),
                    cos_shape.size(),
                    "The shape size of cos mismatches the shape size of q, "
                    "expect %d but got %d",
                    q_shape.size(),
                    cos_shape.size());
  PADDLE_ENFORCE_EQ(q_shape.back(),
                    cos_shape.back(),
                    "The shape.back() of cos mismatches the shape.back() of q, "
                    "expect %d but got %d",
                    q_shape.back(),
                    cos_shape.back());

  auto input_type = q.dtype();
  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      ApplyRopeKernel<paddle::DataType::BFLOAT16>(
          q, k, rot_cos, rot_sin, q_out, k_out);
      break;
    case paddle::DataType::FLOAT16:
      ApplyRopeKernel<paddle::DataType::FLOAT16>(
          q, k, rot_cos, rot_sin, q_out, k_out);
      break;
    default:
      PD_THROW("Only support qk dtype of BF16 and F16");
  }

  return {q_out, k_out};
}
|
||||
|
||||
// Shape inference for `apply_rope`.
//
// The op has exactly two outputs (q_out, k_out), allocated with
// empty_like(q) / empty_like(k), so exactly two shapes are reported.
// cos_shape / sin_shape are accepted because the inference function receives
// one shape per op input, but they do not influence the result.
std::vector<std::vector<int64_t>> ApplyRopeInferShape(
    const std::vector<int64_t>& q_shape,
    const std::vector<int64_t>& k_shape,
    const std::vector<int64_t>& cos_shape,
    const std::vector<int64_t>& sin_shape) {
  (void)cos_shape;
  (void)sin_shape;
  return {q_shape, k_shape};
}
|
||||
|
||||
// Dtype inference for `apply_rope`.
//
// The two outputs (q_out, k_out) are created with empty_like(q) /
// empty_like(k), so they carry the q and k dtypes. The rotary table dtypes
// are accepted for signature compatibility but are not output dtypes.
std::vector<paddle::DataType> ApplyRopeInferDtype(
    const paddle::DataType& q_dtype,
    const paddle::DataType& k_dtype,
    const paddle::DataType& cos_dtype,
    const paddle::DataType& sin_dtype) {
  (void)cos_dtype;
  (void)sin_dtype;
  return {q_dtype, k_dtype};
}
|
||||
|
||||
// Registers the `apply_rope` custom operator: four tensor inputs, two tensor
// outputs, with ApplyRope as the kernel function and the matching
// shape / dtype inference functions.
PD_BUILD_OP(apply_rope)
    .Inputs({"q", "k", "rot_cos", "rot_sin"})
    .Outputs({"q_out", "k_out"})
    .SetKernelFn(PD_KERNEL(ApplyRope))
    .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeInferDtype));
|
||||
329
custom_ops/metax_ops/apply_rope_qkv.cu
Normal file
329
custom_ops/metax_ops/apply_rope_qkv.cu
Normal file
@@ -0,0 +1,329 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <paddle/extension.h>
|
||||
#include <algorithm>
|
||||
#include "helper.h"
|
||||
|
||||
// Scalar conversion helpers between a reduced-precision element type and
// float / int. Only the specializations defined below exist; any other T is
// a compile error because the primary template is left undefined.
template <typename T>
struct Converter;

template <>
struct Converter<__half> {
  // __half -> float
  __device__ static float to_float(__half val) { return __half2float(val); }
  // float -> __half
  __device__ static __half from_float(float val) {
    return __float2half_rn(val);
  }
  // int -> __half
  // Takes an int (not a float) so the value reaches __int2half_rn without an
  // intermediate float round-trip, matching the __nv_bfloat16 specialization.
  __device__ static __half from_int(int val) { return __int2half_rn(val); }
};

template <>
struct Converter<__nv_bfloat16> {
  // __nv_bfloat16 -> float
  __device__ static float to_float(__nv_bfloat16 val) {
    return __bfloat162float(val);
  }
  // float -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_float(float val) {
    return __float2bfloat16_rn(val);
  }
  // int -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_int(int val) {
    return __int2bfloat16_rn(val);
  }
};
|
||||
|
||||
// Index parameters passed by value to DispatchApplyRopeQKVVec4Kernel.
// All strides and offsets are expressed in elements (strides) or heads
// (offsets), as used by the kernel's index arithmetic.
struct ApplyRopeQKVParams {
  // Elements per head.
  int head_dim;
  // Input strides: elements between consecutive tokens / heads of packed qkv.
  int token_stride, head_stride;
  // Output strides: elements per token in q_out and in k_out / v_out.
  int q_stride, kv_stride;
  // Index of the first q / k / v head inside one packed qkv token.
  int q_head_offset, k_head_offset, v_head_offset;
  // Number of query heads and of key/value heads.
  int q_head_num, kv_head_num;
};
|
||||
|
||||
// Applies the rotary embedding to 4 consecutive elements when the cos/sin
// tables share the element type T of the q/k data.
//
// Reads 4 values at qkv_ptr + load_idx and 4 cos/sin values at
// rot_base_idx, then writes
//   out[store_idx + i] = x[i] * cos[i] + partner[i] * sin[i]
// where partner = {-x1, x0, -x3, x2} (adjacent elements form one pair).
// All arithmetic is carried out in T.
template <typename T>
__device__ __forceinline__ void RotateQKVec4(const T* qkv_ptr,
                                             const T* rot_cos_ptr,
                                             const T* rot_sin_ptr,
                                             const int load_idx,
                                             const int store_idx,
                                             const int rot_base_idx,
                                             T* out) {
  using Vec4 = AlignedVector<T, 4>;

  Vec4 x;
  Load(qkv_ptr + load_idx, &x);

  // Pair-swapped copy of the input with the sign flipped on even slots.
  Vec4 partner;
  partner[0] = -x[1];
  partner[1] = x[0];
  partner[2] = -x[3];
  partner[3] = x[2];

  Vec4 cos4, sin4;
  Load(rot_cos_ptr + rot_base_idx, &cos4);
  Load(rot_sin_ptr + rot_base_idx, &sin4);

  T* dst = out + store_idx;
#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    dst[lane] = x[lane] * cos4[lane] + partner[lane] * sin4[lane];
  }
}
|
||||
|
||||
// Applies the rotary embedding to 4 consecutive elements when the cos/sin
// tables are float32 and the q/k data has element type T.
//
// The inputs are converted to float through Converter<T>, combined as
//   x[i] * cos[i] + partner[i] * sin[i],  partner = {-x1, x0, -x3, x2}
// in float, and the result is converted back to T before being written to
// out[store_idx + i].
template <typename T>
__device__ __forceinline__ void RotateQKVec4(const T* qkv_ptr,
                                             const float* rot_cos_ptr,
                                             const float* rot_sin_ptr,
                                             const int load_idx,
                                             const int store_idx,
                                             const int rot_base_idx,
                                             T* out) {
  using Vec4 = AlignedVector<T, 4>;
  using Vec4f = AlignedVector<float, 4>;

  Vec4 x;
  Load(qkv_ptr + load_idx, &x);

  // Pair-swapped copy of the input (in float) with alternating sign.
  Vec4f partner;
  partner[0] = -Converter<T>::to_float(x[1]);
  partner[1] = Converter<T>::to_float(x[0]);
  partner[2] = -Converter<T>::to_float(x[3]);
  partner[3] = Converter<T>::to_float(x[2]);

  Vec4f cos4, sin4;
  Load(rot_cos_ptr + rot_base_idx, &cos4);
  Load(rot_sin_ptr + rot_base_idx, &sin4);

  T* dst = out + store_idx;
#pragma unroll
  for (int lane = 0; lane < 4; ++lane) {
    dst[lane] = Converter<T>::from_float(
        Converter<T>::to_float(x[lane]) * cos4[lane] +
        partner[lane] * sin4[lane]);
  }
}
|
||||
|
||||
// Copies 4 consecutive elements from qkv_ptr + load_idx to out + store_idx
// using one vectorized Load and one vectorized Store.
template <typename T>
__device__ __forceinline__ void StoreValue(const T* qkv_ptr,
                                           const int load_idx,
                                           const int store_idx,
                                           T* out) {
  AlignedVector<T, 4> chunk;
  Load(&qkv_ptr[load_idx], &chunk);
  Store(chunk, &out[store_idx]);
}
|
||||
|
||||
// Splits a packed qkv buffer into q_out / k_out / v_out and applies the
// rotary embedding to q and k.
//
// Thread mapping: x -> token, y -> head, z -> group of 4 elements inside a
// head. A thread whose head index is below q_head_num rotates one q chunk; a
// thread whose head index is below kv_head_num rotates one k chunk and copies
// the matching v chunk unchanged.
// There is no bound check on the token index, so the launch must not create
// more x threads than there are tokens.
template <typename T, typename WeightType>
__global__ void DispatchApplyRopeQKVVec4Kernel(const T* qkv,
                                               const WeightType* rot_cos,
                                               const WeightType* rot_sin,
                                               ApplyRopeQKVParams param,
                                               T* q_out,
                                               T* k_out,
                                               T* v_out) {
  const int token = blockIdx.x * blockDim.x + threadIdx.x;
  const int head = blockIdx.y * blockDim.y + threadIdx.y;
  const int dim_off = (blockIdx.z * blockDim.z + threadIdx.z) * 4;

  // Both branches below require a valid in-head offset.
  if (dim_off >= param.head_dim) {
    return;
  }

  const int rot_idx = token * param.head_dim + dim_off;

  if (head < param.q_head_num) {  // q
    const int q_src = token * param.token_stride +
                      (head + param.q_head_offset) * param.head_stride +
                      dim_off;
    const int q_dst = token * param.q_stride + head * param.head_dim + dim_off;
    RotateQKVec4(qkv, rot_cos, rot_sin, q_src, q_dst, rot_idx, q_out);
  }

  if (head < param.kv_head_num) {  // kv
    const int k_src = token * param.token_stride +
                      (head + param.k_head_offset) * param.head_stride +
                      dim_off;
    // k_out and v_out share the same layout, hence one destination index.
    const int kv_dst =
        token * param.kv_stride + head * param.head_dim + dim_off;
    RotateQKVec4(qkv, rot_cos, rot_sin, k_src, kv_dst, rot_idx, k_out);

    const int v_src = token * param.token_stride +
                      (head + param.v_head_offset) * param.head_stride +
                      dim_off;
    StoreValue(qkv, v_src, kv_dst, v_out);
  }
}
|
||||
|
||||
// Configures and launches DispatchApplyRopeQKVVec4Kernel for a packed qkv
// tensor of Paddle dtype D.
//
// Launch layout: blocks of (1, 4, 32) threads; grid x = number of tokens
// (numel / elements per token), grid y covers max(q_head_num, kv_head_num)
// heads, grid z covers head_dim in groups of 4 elements per thread.
// The rotary tables may share the qkv dtype or be float32; otherwise the
// function raises via PD_THROW.
template <paddle::DataType D>
void ApplyRopeQKVKernel(const paddle::Tensor& qkv,
                        const paddle::Tensor& rot_cos,
                        const paddle::Tensor& rot_sin,
                        const int q_head_num,
                        const int kv_head_num,
                        const int head_dim,
                        paddle::Tensor& q_out,
                        paddle::Tensor& k_out,
                        paddle::Tensor& v_out) {
  using Traits = PDTraits<D>;
  using DataType_ = typename Traits::DataType;
  using data_t = typename Traits::data_t;

  const int total_elems = qkv.numel();
  const int heads_per_token = q_head_num + 2 * kv_head_num;
  auto stream = qkv.stream();

  dim3 threads(1, 4, 32);
  dim3 blocks(total_elems / (heads_per_token * head_dim),  // token
              (std::max(q_head_num, kv_head_num) + threads.y - 1) /
                  threads.y,  // head
              (head_dim + (threads.z * 4) - 1) /
                  (threads.z * 4)  // dim: 4 elements per thread
  );

  ApplyRopeQKVParams param;
  param.head_dim = head_dim;
  param.token_stride = heads_per_token * head_dim;
  param.head_stride = head_dim;
  param.q_stride = q_head_num * head_dim;
  param.kv_stride = kv_head_num * head_dim;
  // Packed layout per token: [q heads | k heads | v heads].
  param.q_head_offset = 0;
  param.k_head_offset = q_head_num;
  param.v_head_offset = q_head_num + kv_head_num;
  param.q_head_num = q_head_num;
  param.kv_head_num = kv_head_num;

  const bool rope_matches_qkv = qkv.dtype() == rot_cos.dtype();
  const bool rope_is_fp32 = rot_cos.dtype() == paddle::DataType::FLOAT32;
  if (!rope_matches_qkv && !rope_is_fp32) {
    PD_THROW("Unsupported qk dtype and rope dtype.");
  }

  const DataType_* qkv_ptr =
      reinterpret_cast<const DataType_*>(qkv.data<data_t>());
  DataType_* q_out_ptr = reinterpret_cast<DataType_*>(q_out.data<data_t>());
  DataType_* k_out_ptr = reinterpret_cast<DataType_*>(k_out.data<data_t>());
  DataType_* v_out_ptr = reinterpret_cast<DataType_*>(v_out.data<data_t>());

  if (rope_matches_qkv) {
    DispatchApplyRopeQKVVec4Kernel<DataType_, DataType_>
        <<<blocks, threads, 0, stream>>>(
            qkv_ptr,
            reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()),
            reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()),
            param,
            q_out_ptr,
            k_out_ptr,
            v_out_ptr);
  } else {
    DispatchApplyRopeQKVVec4Kernel<DataType_, float>
        <<<blocks, threads, 0, stream>>>(
            qkv_ptr,
            reinterpret_cast<const float*>(rot_cos.data<float>()),
            reinterpret_cast<const float*>(rot_sin.data<float>()),
            param,
            q_out_ptr,
            k_out_ptr,
            v_out_ptr);
  }
}
|
||||
|
||||
// Host entry point of the `apply_rope_qkv` custom op.
//
// Splits a packed qkv tensor into q / k / v, applying RoPE to q and k.
// Output shapes are [token_num, heads, head_dim] when rot_cos is 3-D and
// [token_num, 1, heads, head_dim] otherwise; outputs use the qkv dtype and
// place. Returns {q_out, k_out, v_out}; with zero tokens the freshly
// allocated outputs are returned without launching a kernel.
//
// Raises when the last qkv dimension is not
// (q_head_num + 2 * kv_head_num) * head_dim, when head_dim is not a multiple
// of 4, when the last dimension of rot_cos differs from head_dim, or when
// the qkv dtype is neither BF16 nor F16.
std::vector<paddle::Tensor> ApplyRopeQKV(const paddle::Tensor& qkv,
                                         const paddle::Tensor& rot_cos,
                                         const paddle::Tensor& rot_sin,
                                         const int q_head_num,
                                         const int kv_head_num,
                                         const int head_dim) {
  auto qkv_shape = qkv.shape();
  auto token_num = qkv_shape[0];
  auto place = qkv.place();
  auto dtype = qkv.dtype();
  common::DDim q_out_shape, kv_out_shape;
  if (rot_cos.shape().size() == 3) {
    q_out_shape = {token_num, q_head_num, head_dim};
    kv_out_shape = {token_num, kv_head_num, head_dim};
  } else {
    q_out_shape = {token_num, 1, q_head_num, head_dim};
    kv_out_shape = {token_num, 1, kv_head_num, head_dim};
  }
  auto q_out = GetEmptyTensor(q_out_shape, dtype, place);
  auto k_out = GetEmptyTensor(kv_out_shape, dtype, place);
  auto v_out = GetEmptyTensor(kv_out_shape, dtype, place);

  if (token_num == 0) {
    return {q_out, k_out, v_out};
  }

  PADDLE_ENFORCE_EQ(qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim),
                    "The last dimension of qkv [%d] must equal to {(q_head_num "
                    "+ 2 * kv_head_num) * head_dim [%d].",
                    qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim));
  PADDLE_ENFORCE_EQ(
      head_dim % 2,
      0,
      "The last dimension (head_dim) of qkv must be an even number "
      "for RoPE, but got %d",
      head_dim);
  // The kernel loads and stores 4 elements per thread. With head_dim not a
  // multiple of 4, the last chunk of a head would run into the next head and
  // past the end of the token's cos/sin row.
  PADDLE_ENFORCE_EQ(
      head_dim % 4,
      0,
      "The last dimension (head_dim) of qkv must be a multiple of 4 "
      "for the vectorized RoPE kernel, but got %d",
      head_dim);
  PADDLE_ENFORCE_EQ(q_out.shape().back(),
                    rot_cos.shape().back(),
                    "The last dimension of cos mismatches that of q, "
                    "expect %d but got %d",
                    q_out.shape().back(),
                    rot_cos.shape().back());

  switch (dtype) {
    case paddle::DataType::BFLOAT16:
      ApplyRopeQKVKernel<paddle::DataType::BFLOAT16>(qkv,
                                                     rot_cos,
                                                     rot_sin,
                                                     q_head_num,
                                                     kv_head_num,
                                                     head_dim,
                                                     q_out,
                                                     k_out,
                                                     v_out);
      break;
    case paddle::DataType::FLOAT16:
      ApplyRopeQKVKernel<paddle::DataType::FLOAT16>(qkv,
                                                    rot_cos,
                                                    rot_sin,
                                                    q_head_num,
                                                    kv_head_num,
                                                    head_dim,
                                                    q_out,
                                                    k_out,
                                                    v_out);
      break;
    default:
      PD_THROW("Only support qk dtype of BF16 and F16");
  }

  return {q_out, k_out, v_out};
}
|
||||
|
||||
// Shape inference for `apply_rope_qkv`: reports the three input shapes, in
// input order, as the shapes of the three outputs.
// NOTE(review): ApplyRopeQKV allocates q_out / k_out / v_out as
// [token_num, (1,) heads, head_dim], which does not match what is returned
// here. Reporting the real shapes would require the head-count attributes —
// confirm whether this function is relied upon.
std::vector<std::vector<int64_t>> ApplyRopeQKVInferShape(
    const std::vector<int64_t>& qkv_shape,
    const std::vector<int64_t>& cos_shape,
    const std::vector<int64_t>& sin_shape) {
  std::vector<std::vector<int64_t>> inferred;
  inferred.reserve(3);
  inferred.push_back(qkv_shape);
  inferred.push_back(cos_shape);
  inferred.push_back(sin_shape);
  return inferred;
}
|
||||
|
||||
// Dtype inference for `apply_rope_qkv`.
//
// ApplyRopeQKV allocates q_out, k_out and v_out with the dtype of qkv, so all
// three outputs report qkv_dtype. The rotary tables may be float32 while qkv
// is half precision, so their dtypes must not be reported as output dtypes.
std::vector<paddle::DataType> ApplyRopeQKVInferDtype(
    const paddle::DataType& qkv_dtype,
    const paddle::DataType& cos_dtype,
    const paddle::DataType& sin_dtype) {
  (void)cos_dtype;
  (void)sin_dtype;
  return {qkv_dtype, qkv_dtype, qkv_dtype};
}
|
||||
|
||||
// Registers the `apply_rope_qkv` custom operator: packed qkv plus rotary
// tables in, separate q / k / v out, with the head layout supplied as integer
// attributes.
PD_BUILD_OP(apply_rope_qkv)
    .Inputs({"qkv", "rot_cos", "rot_sin"})
    .Outputs({"q_out", "k_out", "v_out"})
    .Attrs({"q_head_num:int", "kv_head_num:int", "head_dim:int"})
    .SetKernelFn(PD_KERNEL(ApplyRopeQKV))
    .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeQKVInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeQKVInferDtype));
|
||||
477
custom_ops/metax_ops/cache_kv_with_rope.cu
Normal file
477
custom_ops/metax_ops/cache_kv_with_rope.cu
Normal file
@@ -0,0 +1,477 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <paddle/extension.h>
|
||||
#include <algorithm>
|
||||
#include "helper.h"
|
||||
|
||||
// Scalar conversion helpers between a reduced-precision element type and
// float / int. Only the specializations defined below exist; any other T is
// a compile error because the primary template is left undefined.
template <typename T>
struct Converter;

template <>
struct Converter<__half> {
  // __half -> float
  __device__ static float to_float(__half val) { return __half2float(val); }
  // float -> __half
  __device__ static __half from_float(float val) {
    return __float2half_rn(val);
  }
  // int -> __half
  // Takes an int (not a float): callers pass an integer sign flag, and the
  // value is forwarded to __int2half_rn. This also matches the
  // __nv_bfloat16 specialization.
  __device__ static __half from_int(int val) { return __int2half_rn(val); }
};

template <>
struct Converter<__nv_bfloat16> {
  // __nv_bfloat16 -> float
  __device__ static float to_float(__nv_bfloat16 val) {
    return __bfloat162float(val);
  }
  // float -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_float(float val) {
    return __float2bfloat16_rn(val);
  }
  // int -> __nv_bfloat16
  __device__ static __nv_bfloat16 from_int(int val) {
    return __int2bfloat16_rn(val);
  }
};
|
||||
|
||||
// Index parameters passed by value to DispatchCacheKVWithRopeVecKernel.
// Strides are expressed in elements, head offsets in heads.
struct CacheKVWithRopeParams {
  // Elements per head.
  int head_dim;
  // Tokens per cache block and block-table entries per batch row.
  int block_size, block_num;
  // Elements per cache block.
  int cache_stride;
  // Input strides: elements between consecutive tokens / heads of packed qkv.
  int token_stride, head_stride;
  // Output strides: elements per token in q_out and in k_out / v_out.
  int q_stride, kv_stride;
  // Index of the first q / k / v head inside one packed qkv token.
  int q_head_offset, k_head_offset, v_head_offset;
  // Number of query heads and of key/value heads.
  int q_head_num, kv_head_num;
};
|
||||
|
||||
// Applies the rotary embedding to VecSize consecutive elements when the
// cos/sin tables share the element type T of the q/k data.
//
// For each slot i the partner is the other element of its adjacent pair,
// negated on even slots:  partner = {-x1, x0, -x3, x2, ...}. The result
//   x[i] * cos[i] + partner[i] * sin[i]
// is written to out[store_idx + i] and, when WriteCache is true, also to
// caches[cache_store_idx + i]. With WriteCache == false, `caches` and
// `cache_store_idx` are never dereferenced.
template <typename T, int VecSize = 4, bool WriteCache = true>
__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr,
                                            const T* rotary_cos_ptr,
                                            const T* rotary_sin_ptr,
                                            const int load_idx,
                                            const int store_idx,
                                            const int cache_store_idx,
                                            const int rot_base_idx,
                                            T* caches,
                                            T* out) {
  using Vec = AlignedVector<T, VecSize>;

  Vec x;
  Load(qkv_ptr + load_idx, &x);

  Vec partner;
#pragma unroll
  for (int lane = 0; lane < VecSize; ++lane) {
    // +1 on even lanes (partner is the next element), -1 on odd lanes.
    const int sign = (lane % 2 == 0) ? 1 : -1;
    partner[lane] = -x[lane + sign] * Converter<T>::from_int(sign);
  }

  Vec cos_v, sin_v;
  Load(rotary_cos_ptr + rot_base_idx, &cos_v);
  Load(rotary_sin_ptr + rot_base_idx, &sin_v);

#pragma unroll
  for (int lane = 0; lane < VecSize; ++lane) {
    T rotated = x[lane] * cos_v[lane] + partner[lane] * sin_v[lane];
    out[store_idx + lane] = rotated;

    if (WriteCache) {
      caches[cache_store_idx + lane] = rotated;
    }
  }
}
|
||||
|
||||
// Applies the rotary embedding to VecSize consecutive elements when the
// cos/sin tables are float32 and the q/k data has element type T.
//
// Inputs are converted to float through Converter<T>; for each slot i the
// partner is the other element of its adjacent pair, negated on even slots.
// The float result  x[i] * cos[i] + partner[i] * sin[i]  is converted back to
// T and written to out[store_idx + i] and, when WriteCache is true, also to
// caches[cache_store_idx + i].
template <typename T, int VecSize = 4, bool WriteCache = true>
__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr,
                                            const float* rotary_cos_ptr,
                                            const float* rotary_sin_ptr,
                                            const int load_idx,
                                            const int store_idx,
                                            const int cache_store_idx,
                                            const int rot_base_idx,
                                            T* caches,
                                            T* out) {
  using Vec = AlignedVector<T, VecSize>;
  using VecFloat = AlignedVector<float, VecSize>;

  Vec x;
  Load(qkv_ptr + load_idx, &x);

  VecFloat partner;
#pragma unroll
  for (int lane = 0; lane < VecSize; ++lane) {
    // +1 on even lanes (partner is the next element), -1 on odd lanes.
    const int sign = (lane % 2 == 0) ? 1 : -1;
    partner[lane] =
        -Converter<T>::to_float(x[lane + sign]) * static_cast<float>(sign);
  }

  VecFloat cos_v, sin_v;
  Load(rotary_cos_ptr + rot_base_idx, &cos_v);
  Load(rotary_sin_ptr + rot_base_idx, &sin_v);

#pragma unroll
  for (int lane = 0; lane < VecSize; ++lane) {
    T rotated = Converter<T>::from_float(
        Converter<T>::to_float(x[lane]) * cos_v[lane] +
        partner[lane] * sin_v[lane]);
    out[store_idx + lane] = rotated;
    if (WriteCache) {
      caches[cache_store_idx + lane] = rotated;
    }
  }
}
|
||||
|
||||
// Copies VecSize consecutive elements from qkv_ptr + load_idx to two
// destinations: out + store_idx and caches + cache_store_idx.
template <typename T, int VecSize = 4>
__device__ __forceinline__ void StoreValue(const T* qkv_ptr,
                                           const int load_idx,
                                           const int store_idx,
                                           const int cache_store_idx,
                                           T* caches,
                                           T* out) {
  AlignedVector<T, VecSize> chunk;
  Load(&qkv_ptr[load_idx], &chunk);
  Store(chunk, &out[store_idx]);
  Store(chunk, &caches[cache_store_idx]);
}
|
||||
|
||||
// Splits a packed qkv buffer into q_out / k_out / v_out, applies the rotary
// embedding to q and k, and additionally writes the rotated k and the raw v
// into the block-organised caches.
//
// Thread mapping: x -> token, y -> head, z -> group of VecSize elements
// inside a head. The token index is not bound-checked, so the launch must not
// create more x threads than there are tokens.
//
// Cache addressing: the token's batch comes from batch_ids_q, its position
// inside the batch from cu_seqlens_q, and block_tables translates
// (batch, position / block_size) into the cache block that receives the data.
template <typename T, typename WeightType, int VecSize>
__global__ void DispatchCacheKVWithRopeVecKernel(const T* qkv,
                                                 T* caches_k,
                                                 T* caches_v,
                                                 const int* block_tables,
                                                 const WeightType* rotary_cos,
                                                 const WeightType* rotary_sin,
                                                 const int* cu_seqlens_q,
                                                 const int* batch_ids_q,
                                                 CacheKVWithRopeParams param,
                                                 T* q_out,
                                                 T* k_out,
                                                 T* v_out) {
  const int token = blockIdx.x * blockDim.x + threadIdx.x;
  const int head = blockIdx.y * blockDim.y + threadIdx.y;
  const int dim_off = (blockIdx.z * blockDim.z + threadIdx.z) * VecSize;

  // Locate the cache slot of this token.
  const int batch = batch_ids_q[token];
  const int pos_in_batch = token - cu_seqlens_q[batch];
  const int logical_block = pos_in_batch / param.block_size;
  const int slot_in_block = pos_in_batch % param.block_size;
  const int block_idx = block_tables[batch * param.block_num + logical_block];

  assert(block_idx != -1);

  if (dim_off >= param.head_dim) {
    return;
  }

  const int rot_idx = token * param.head_dim + dim_off;

  if (head < param.q_head_num) {  // q: rotate only, nothing is cached
    const int q_src = token * param.token_stride +
                      (head + param.q_head_offset) * param.head_stride +
                      dim_off;
    const int q_dst = token * param.q_stride + head * param.head_dim + dim_off;
    RotateQKVec<T, VecSize, false>(qkv,
                                   rotary_cos,
                                   rotary_sin,
                                   q_src,
                                   q_dst,
                                   -1,
                                   rot_idx,
                                   static_cast<T*>(nullptr),
                                   q_out);
  }

  if (head < param.kv_head_num) {  // kv: write outputs and caches
    const int k_src = token * param.token_stride +
                      (head + param.k_head_offset) * param.head_stride +
                      dim_off;
    // k_out / v_out share one layout, as do caches_k / caches_v.
    const int kv_dst =
        token * param.kv_stride + head * param.head_dim + dim_off;
    const int cache_dst = block_idx * param.cache_stride +
                          slot_in_block * param.kv_stride +
                          head * param.head_dim + dim_off;
    RotateQKVec<T, VecSize, true>(qkv,
                                  rotary_cos,
                                  rotary_sin,
                                  k_src,
                                  kv_dst,
                                  cache_dst,
                                  rot_idx,
                                  caches_k,
                                  k_out);

    const int v_src = token * param.token_stride +
                      (head + param.v_head_offset) * param.head_stride +
                      dim_off;
    StoreValue<T, VecSize>(qkv, v_src, kv_dst, cache_dst, caches_v, v_out);
  }
}
|
||||
|
||||
// Configures and launches DispatchCacheKVWithRopeVecKernel for a packed qkv
// tensor of Paddle dtype D.
//
// Launch layout: grid x = number of tokens (numel / elements per token),
// y covers max(q_head_num, kv_head_num) heads in blocks of 4, z covers
// head_dim with VecSize elements per thread.
// The rotary tables may share the qkv dtype or be float32; any other
// combination raises via PD_THROW. A failed kernel launch also raises instead
// of being printed and ignored, because the outputs would otherwise be
// returned uninitialised.
template <paddle::DataType D, int VecSize = 4>
void CacheKVWithRopeKernel(
    const paddle::Tensor& qkv,  // token_num, head_num * head_dim
    paddle::Tensor&
        caches_k,  // max_block_num, block_size, kv_head_num, head_dim
    paddle::Tensor&
        caches_v,  // max_block_num, block_size, kv_head_num, head_dim
    const paddle::Tensor& block_tables,  // bs, block_num
    const paddle::Tensor& rotary_cos,
    const paddle::Tensor& rotary_sin,
    const paddle::Tensor& cu_seqlens_q,  // bs + 1
    const paddle::Tensor& batch_ids_q,   // token_num
    const int q_head_num,
    const int kv_head_num,
    const int head_dim,
    const int block_size,
    paddle::Tensor& q_out,
    paddle::Tensor& k_out,
    paddle::Tensor& v_out) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int all_num_elements = qkv.numel();
  const int all_num_heads = q_head_num + 2 * kv_head_num;
  auto stream = qkv.stream();

  // CUDA limits blockDim.z to 64. One thread per vector of a head would
  // exceed that for head_dim > 64 * VecSize and make the launch fail, so the
  // block is clamped and the remainder is covered by gridDim.z (the kernel
  // indexes with blockIdx.z * blockDim.z + threadIdx.z).
  constexpr unsigned int kMaxBlockDimZ = 64;
  const unsigned int vecs_per_head = (head_dim + VecSize - 1) / VecSize;
  dim3 block_dims(1, 4, std::min(vecs_per_head, kMaxBlockDimZ));
  dim3 grid_dims(all_num_elements / (all_num_heads * head_dim),  // token
                 (std::max(q_head_num, kv_head_num) + block_dims.y - 1) /
                     block_dims.y,  // head
                 (head_dim + (block_dims.z * VecSize) - 1) /
                     (block_dims.z * VecSize)  // dim: load Vec at a time
  );

  CacheKVWithRopeParams param;
  param.head_dim = head_dim;
  param.block_size = block_size;
  param.block_num = static_cast<int>(block_tables.shape().back());
  param.cache_stride = block_size * kv_head_num * head_dim;
  param.token_stride = all_num_heads * head_dim;
  param.head_stride = head_dim;
  param.q_stride = q_head_num * head_dim;
  param.kv_stride = kv_head_num * head_dim;
  // Packed layout per token: [q heads | k heads | v heads].
  param.q_head_offset = 0;
  param.k_head_offset = q_head_num;
  param.v_head_offset = q_head_num + kv_head_num;
  param.q_head_num = q_head_num;
  param.kv_head_num = kv_head_num;

  if (qkv.dtype() == rotary_cos.dtype()) {
    DispatchCacheKVWithRopeVecKernel<DataType_, DataType_, VecSize>
        <<<grid_dims, block_dims, 0, stream>>>(
            reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_k.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_v.data<data_t>()),
            reinterpret_cast<const int*>(block_tables.data<int>()),
            reinterpret_cast<const DataType_*>(rotary_cos.data<data_t>()),
            reinterpret_cast<const DataType_*>(rotary_sin.data<data_t>()),
            reinterpret_cast<const int*>(cu_seqlens_q.data<int>()),
            reinterpret_cast<const int*>(batch_ids_q.data<int>()),
            param,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()),
            reinterpret_cast<DataType_*>(v_out.data<data_t>()));
  } else if (rotary_cos.dtype() == paddle::DataType::FLOAT32) {
    DispatchCacheKVWithRopeVecKernel<DataType_, float, VecSize>
        <<<grid_dims, block_dims, 0, stream>>>(
            reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_k.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_v.data<data_t>()),
            reinterpret_cast<const int*>(block_tables.data<int>()),
            reinterpret_cast<const float*>(rotary_cos.data<float>()),
            reinterpret_cast<const float*>(rotary_sin.data<float>()),
            reinterpret_cast<const int*>(cu_seqlens_q.data<int>()),
            reinterpret_cast<const int*>(batch_ids_q.data<int>()),
            param,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()),
            reinterpret_cast<DataType_*>(v_out.data<data_t>()));
  } else {
    PD_THROW("Unsupported qk dtype and rope dtype.");
  }

  // Launch-configuration errors are only visible through cudaGetLastError().
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    PD_THROW("CUDA Error: ", cudaGetErrorString(err));
  }
}
|
||||
|
||||
// Host entry point of the `cache_kv_with_rope` custom op.
//
// Splits a packed qkv tensor into q / k / v, applies RoPE to q and k, and
// writes the rotated k and the raw v into caches_k / caches_v (modified in
// place). Output shapes are [token_num, heads, head_dim] when rotary_cos is
// 3-D and [token_num, 1, heads, head_dim] otherwise; outputs use the qkv
// dtype and place. Returns {q_out, k_out, v_out}; with zero tokens the
// freshly allocated outputs are returned without launching a kernel.
//
// Raises when the last qkv dimension is not
// (q_head_num + 2 * kv_head_num) * head_dim, when head_dim is not a multiple
// of 4, when the last dimension of rotary_cos differs from head_dim, or when
// the qkv dtype is neither BF16 nor F16.
std::vector<paddle::Tensor> CacheKVWithRope(
    const paddle::Tensor& qkv,  // token_num, head_num * head_dim
    paddle::Tensor&
        caches_k,  // max_block_num, block_size, kv_head_num, head_dim
    paddle::Tensor&
        caches_v,  // max_block_num, block_size, kv_head_num, head_dim
    const paddle::Tensor& block_tables,  // bs, block_num
    const paddle::Tensor& rotary_cos,
    const paddle::Tensor& rotary_sin,
    const paddle::Tensor& cu_seqlens_q,  // bs + 1
    const paddle::Tensor& batch_ids_q,   // token_num
    const int q_head_num,
    const int kv_head_num,
    const int head_dim,
    const int block_size) {
  auto qkv_shape = qkv.shape();
  auto token_num = qkv_shape[0];
  auto place = qkv.place();
  auto dtype = qkv.dtype();
  common::DDim q_out_shape, kv_out_shape;
  if (rotary_cos.shape().size() == 3) {
    q_out_shape = {token_num, q_head_num, head_dim};
    kv_out_shape = {token_num, kv_head_num, head_dim};
  } else {
    q_out_shape = {token_num, 1, q_head_num, head_dim};
    kv_out_shape = {token_num, 1, kv_head_num, head_dim};
  }
  auto q_out = GetEmptyTensor(q_out_shape, dtype, place);
  auto k_out = GetEmptyTensor(kv_out_shape, dtype, place);
  auto v_out = GetEmptyTensor(kv_out_shape, dtype, place);

  if (token_num == 0) {
    return {q_out, k_out, v_out};
  }

  PADDLE_ENFORCE_EQ(qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim),
                    "The last dimension of qkv [%d] must equal to {(q_head_num "
                    "+ 2 * kv_head_num) * head_dim [%d].",
                    qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim));
  PADDLE_ENFORCE_EQ(
      head_dim % 2,
      0,
      "The last dimension (head_dim) of qkv must be an even number "
      "for RoPE, but got %d",
      head_dim);
  // The kernel is launched with its default vector width of 4. With head_dim
  // not a multiple of 4, the last chunk of a head would run into the next
  // head and past the end of the token's cos/sin row.
  PADDLE_ENFORCE_EQ(
      head_dim % 4,
      0,
      "The last dimension (head_dim) of qkv must be a multiple of 4 "
      "for the vectorized RoPE kernel, but got %d",
      head_dim);
  PADDLE_ENFORCE_EQ(q_out.shape().back(),
                    rotary_cos.shape().back(),
                    "The last dimension of cos mismatches that of q, "
                    "expect %d but got %d",
                    q_out.shape().back(),
                    rotary_cos.shape().back());

  switch (dtype) {
    case paddle::DataType::BFLOAT16:
      CacheKVWithRopeKernel<paddle::DataType::BFLOAT16>(qkv,
                                                        caches_k,
                                                        caches_v,
                                                        block_tables,
                                                        rotary_cos,
                                                        rotary_sin,
                                                        cu_seqlens_q,
                                                        batch_ids_q,
                                                        q_head_num,
                                                        kv_head_num,
                                                        head_dim,
                                                        block_size,
                                                        q_out,
                                                        k_out,
                                                        v_out);
      break;
    case paddle::DataType::FLOAT16:
      CacheKVWithRopeKernel<paddle::DataType::FLOAT16>(qkv,
                                                       caches_k,
                                                       caches_v,
                                                       block_tables,
                                                       rotary_cos,
                                                       rotary_sin,
                                                       cu_seqlens_q,
                                                       batch_ids_q,
                                                       q_head_num,
                                                       kv_head_num,
                                                       head_dim,
                                                       block_size,
                                                       q_out,
                                                       k_out,
                                                       v_out);
      break;
    default:
      PD_THROW("Only support qk dtype of BF16 and F16");
  }

  return {q_out, k_out, v_out};
}
|
||||
|
||||
// Shape inference for `cache_kv_with_rope`: reports all eight input shapes,
// in input order.
// NOTE(review): the op declares three outputs, allocated by CacheKVWithRope
// as [token_num, (1,) heads, head_dim]; neither the count nor the shapes
// returned here match that. A correct result needs the head-count attributes
// — confirm whether this function is relied upon.
std::vector<std::vector<int64_t>> CacheKVWithRopeInferShape(
    const std::vector<int64_t>& qkv_shape,
    const std::vector<int64_t>& caches_k_shape,
    const std::vector<int64_t>& caches_v_shape,
    const std::vector<int64_t>& block_tables_shape,
    const std::vector<int64_t>& cos_shape,
    const std::vector<int64_t>& sin_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& batch_ids_q_shape) {
  std::vector<std::vector<int64_t>> inferred;
  inferred.reserve(8);
  inferred.push_back(qkv_shape);
  inferred.push_back(caches_k_shape);
  inferred.push_back(caches_v_shape);
  inferred.push_back(block_tables_shape);
  inferred.push_back(cos_shape);
  inferred.push_back(sin_shape);
  inferred.push_back(cu_seqlens_q_shape);
  inferred.push_back(batch_ids_q_shape);
  return inferred;
}
|
||||
|
||||
std::vector<paddle::DataType> CacheKVWithRopeInferDtype(
|
||||
const paddle::DataType& qkv_dtype,
|
||||
const paddle::DataType& caches_k_dtype,
|
||||
const paddle::DataType& caches_v_dtype,
|
||||
const paddle::DataType& block_tables_dtype,
|
||||
const paddle::DataType& cos_dtype,
|
||||
const paddle::DataType& sin_dtype,
|
||||
const paddle::DataType& cu_seqlens_q_dtype,
|
||||
const paddle::DataType& batch_ids_q_dtype) {
|
||||
return {qkv_dtype,
|
||||
caches_k_dtype,
|
||||
caches_v_dtype,
|
||||
block_tables_dtype,
|
||||
cos_dtype,
|
||||
sin_dtype,
|
||||
cu_seqlens_q_dtype,
|
||||
batch_ids_q_dtype};
|
||||
}
|
||||
|
||||
// Registers the cache_kv_with_rope custom op: eight tensor inputs, three
// tensor outputs and four integer attributes, together with its kernel
// function and the shape/dtype inference hooks defined above.
PD_BUILD_OP(cache_kv_with_rope)
    .Inputs({"qkv", "caches_k", "caches_v", "block_tables",
             "rotary_cos", "rotary_sin", "cu_seqlen_q", "batch_ids_q"})
    .Outputs({"q_out", "k_out", "v_out"})
    .Attrs({"q_head_num:int",
            "kv_head_num:int",
            "head_dim:int",
            "block_size:int"})
    .SetKernelFn(PD_KERNEL(CacheKVWithRope))
    .SetInferShapeFn(PD_INFER_SHAPE(CacheKVWithRopeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(CacheKVWithRopeInferDtype));
|
||||
@@ -14,9 +14,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moe_op.h"
|
||||
#include "fused_moe_helper.h"
|
||||
#include "helper.h"
|
||||
#include "mc_fused_moe_helper.h"
|
||||
|
||||
namespace phi {
|
||||
|
||||
__global__ void compute_total_rows_before_expert_kernel(
|
||||
int* sorted_experts,
|
||||
@@ -42,58 +43,61 @@ void compute_total_rows_before_expert(int* sorted_indices,
|
||||
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||
}
|
||||
|
||||
template <paddle::DataType T,
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC>
|
||||
} // namespace phi
|
||||
|
||||
template <paddle::DataType T>
|
||||
void FusedMoeKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& gate_weight,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const std::string& quant_method,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
paddle::Tensor* output) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
|
||||
auto* output_data = output->data<data_t>();
|
||||
|
||||
auto moe_compute =
|
||||
McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
|
||||
auto int8_moe_gemm_runner = McMoeGemmRunner<DataType_, int8_t>();
|
||||
|
||||
moe_compute.computeFFN(&input,
|
||||
&gate_weight,
|
||||
&ffn1_weight,
|
||||
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
|
||||
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
|
||||
&ffn2_weight,
|
||||
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
|
||||
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
|
||||
nullptr,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
1.0, // ComputeFFN
|
||||
"ffn",
|
||||
output);
|
||||
auto moe_compute =
|
||||
McMoeHelper<data_t, DataType_>(quant_method, &int8_moe_gemm_runner);
|
||||
|
||||
moe_compute.computeFFN(
|
||||
&input,
|
||||
&gate_weight,
|
||||
&up_gate_proj_weight,
|
||||
up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
|
||||
up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
|
||||
&down_proj_weight,
|
||||
down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
|
||||
down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
|
||||
nullptr,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
1.0, // ComputeFFN
|
||||
"ffn",
|
||||
output);
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gate_weight,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const std::string& quant_method,
|
||||
const int moe_topk,
|
||||
const bool norm_topk_prob,
|
||||
@@ -107,40 +111,22 @@ std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16,
|
||||
maca_bfloat16,
|
||||
int8_t,
|
||||
maca_bfloat16>(input,
|
||||
gate_weight,
|
||||
ffn1_weight,
|
||||
ffn1_scale,
|
||||
ffn1_bias,
|
||||
ffn2_weight,
|
||||
ffn2_scale,
|
||||
ffn2_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
|
||||
gate_weight,
|
||||
up_gate_proj_weight,
|
||||
up_gate_proj_scale,
|
||||
up_gate_proj_bias,
|
||||
down_proj_weight,
|
||||
down_proj_scale,
|
||||
down_proj_bias,
|
||||
quant_method,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
norm_topk_prob,
|
||||
&output);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// FusedMoeKernel<paddle::DataType::FLOAT16>(input,
|
||||
// gate_weight,
|
||||
// ffn1_weight,
|
||||
// ffn1_scale,
|
||||
// ffn1_bias,
|
||||
// ffn2_weight,
|
||||
// ffn2_scale,
|
||||
// ffn2_bias,
|
||||
// quant_method,
|
||||
// moe_topk,
|
||||
// group_moe,
|
||||
// norm_topk_prob,
|
||||
// &output);
|
||||
// break;
|
||||
default:
|
||||
PD_THROW("Only support bf16 for FusedMoeKernel");
|
||||
PD_THROW("Unsupported data type for FusedMoeKernel");
|
||||
}
|
||||
return {output};
|
||||
}
|
||||
@@ -148,36 +134,36 @@ std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& gate_weight_shape,
|
||||
const std::vector<int64_t>& ffn1_weight_shape,
|
||||
const std::vector<int64_t>& ffn2_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
|
||||
const std::vector<int64_t>& up_gate_proj_weight_shape,
|
||||
const std::vector<int64_t>& down_proj_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
|
||||
return {input_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
const paddle::DataType& gate_weight_dtype,
|
||||
const paddle::DataType& ffn1_weight_dtype,
|
||||
const paddle::DataType& ffn2_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
|
||||
const paddle::DataType& up_gate_proj_weight_dtype,
|
||||
const paddle::DataType& down_proj_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
|
||||
return {input_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(fused_expert_moe)
|
||||
PD_BUILD_STATIC_OP(fused_expert_moe)
|
||||
.Inputs({"input",
|
||||
"gate_weight",
|
||||
"ffn1_weight",
|
||||
"ffn2_weight",
|
||||
paddle::Optional("ffn1_bias"),
|
||||
paddle::Optional("ffn1_scale"),
|
||||
paddle::Optional("ffn2_bias"),
|
||||
paddle::Optional("ffn2_scale")})
|
||||
"up_gate_proj_weight",
|
||||
"down_proj_weight",
|
||||
paddle::Optional("up_gate_proj_bias"),
|
||||
paddle::Optional("up_gate_proj_scale"),
|
||||
paddle::Optional("down_proj_bias"),
|
||||
paddle::Optional("down_proj_scale")})
|
||||
.Outputs({"output"})
|
||||
.Attrs({"quant_method:std::string",
|
||||
"moe_topk:int",
|
||||
|
||||
199
custom_ops/metax_ops/fused_moe_gemm_kernels.h
Normal file
199
custom_ops/metax_ops/fused_moe_gemm_kernels.h
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "mctlass/numeric_conversion.h"
|
||||
#include "mctlassEx/mctlassEx.h"
|
||||
|
||||
namespace phi {
|
||||
|
||||
template <typename T>
|
||||
struct mctlassExDataTraits;
|
||||
|
||||
template <>
|
||||
struct mctlassExDataTraits<maca_bfloat16> {
|
||||
static constexpr mctlassExDataType type =
|
||||
mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mctlassExDataTraits<int8_t> {
|
||||
static constexpr mctlassExDataType type =
|
||||
mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
|
||||
};
|
||||
|
||||
template <typename T, typename WeightType>
|
||||
class McMoeGemmRunner {
|
||||
public:
|
||||
McMoeGemmRunner() {}
|
||||
|
||||
void mc_grouped_gemm_basic_kernel(const T* ptrA,
|
||||
mctlassExOrder_t majorA,
|
||||
const WeightType* ptrB,
|
||||
mctlassExOrder_t majorB,
|
||||
const T* ptrScale,
|
||||
const T* ptrBias,
|
||||
T* ptrC,
|
||||
mctlassExOrder_t majorC,
|
||||
const int* ptrSegInd,
|
||||
int* ptrMNumTilesInd,
|
||||
int numExperts,
|
||||
int m, // expanded_active_expert_rows
|
||||
int n, // inter_dim
|
||||
int k, // hidden_size
|
||||
mcStream_t stream) {
|
||||
mctlassExHandle_t handle;
|
||||
mctlassExHandleCreate(&handle);
|
||||
|
||||
mctlassExDataType DataType_ = mctlassExDataTraits<T>::type;
|
||||
mctlassExDataType WeightType_ = mctlassExDataTraits<WeightType>::type;
|
||||
|
||||
mctlassExMatrixLayout_t matLayoutA;
|
||||
mctlassExMatrixLayout_t matLayoutB;
|
||||
mctlassExMatrixLayout_t matLayoutC;
|
||||
|
||||
// mat A: (m, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutA, DataType_, m, k, k);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutA,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorA,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutA,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// mat B: (num_experts, n, k)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutB, WeightType_, k, n, k);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutB,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorB,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutB,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// mat C: (m, n)
|
||||
mctlassExMatrixLayoutCreate(&matLayoutC, DataType_, m, n, n);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutC,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorC,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutC,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// bias: (num_experts, n)
|
||||
// scale: (num, n)
|
||||
|
||||
mctlassExDesc_t mctlass_desc;
|
||||
mctlassExCreateDesc(&mctlass_desc);
|
||||
mctlassExDataType input_type = DataType_;
|
||||
mctlassExDataType scale_type = WeightType_;
|
||||
mctlassExDataType compute_type =
|
||||
mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
|
||||
mctlassExEpilogueType epilogue_type =
|
||||
mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
|
||||
if (ptrBias) {
|
||||
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
|
||||
}
|
||||
// set scale
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
|
||||
&ptrScale,
|
||||
sizeof(ptrScale));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
|
||||
&input_type,
|
||||
sizeof(mctlassExDataType));
|
||||
// set bias
|
||||
if (ptrBias) {
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
|
||||
&ptrBias,
|
||||
sizeof(ptrBias));
|
||||
}
|
||||
// set coumpute type
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
|
||||
&compute_type,
|
||||
sizeof(mctlassExDataType));
|
||||
// set epilogue type
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
|
||||
&epilogue_type,
|
||||
sizeof(mctlassExEpilogueType));
|
||||
|
||||
const mctlassExContiguousGroupedGemmAlgo_t algo =
|
||||
mctlassExContiguousGroupedGemmAlgo_t::
|
||||
MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
|
||||
mctlassExContiguousGroupedDesc_t contiguous_group_desc;
|
||||
mctlassExContiguousGroupedDescCreate(
|
||||
&contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);
|
||||
int blocksizeM;
|
||||
mctlassExContiguousGroupedGemmGetBlocksizeM(handle,
|
||||
mctlass_desc,
|
||||
matLayoutA,
|
||||
matLayoutB,
|
||||
matLayoutC,
|
||||
&algo,
|
||||
&blocksizeM);
|
||||
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle,
|
||||
mctlass_desc,
|
||||
matLayoutA,
|
||||
matLayoutB,
|
||||
matLayoutC,
|
||||
&algo,
|
||||
contiguous_group_desc,
|
||||
numExperts,
|
||||
blocksizeM,
|
||||
stream);
|
||||
|
||||
mctlassExContiguousGroupedGemmBasic(handle,
|
||||
mctlass_desc,
|
||||
ptrA,
|
||||
matLayoutA,
|
||||
ptrB,
|
||||
matLayoutB,
|
||||
ptrC,
|
||||
matLayoutC,
|
||||
contiguous_group_desc,
|
||||
&algo,
|
||||
nullptr,
|
||||
0,
|
||||
stream);
|
||||
|
||||
mctlassExHandleDestroy(handle);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutA);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutB);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutC);
|
||||
mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
|
||||
mctlassExDestroyDesc(mctlass_desc);
|
||||
}
|
||||
};
|
||||
|
||||
// Explicit instantiation for bf16 activations with int8 expert weights.
// NOTE(review): this explicit instantiation definition lives in a
// `#pragma once` header, so every including translation unit repeats it --
// confirm that this is intended.
template class McMoeGemmRunner<maca_bfloat16, int8_t>;
|
||||
|
||||
} // namespace phi
|
||||
@@ -14,14 +14,17 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
||||
#include "fused_moe_gemm_kernels.h"
|
||||
#include "fused_moe_imp_op.h"
|
||||
#include "fused_moe_op.h"
|
||||
#include "mctlass/numeric_conversion.h"
|
||||
#include "mctlassEx/mctlassEx.h"
|
||||
|
||||
using namespace phi;
|
||||
namespace phi {
|
||||
|
||||
template <typename T, int VecSize>
|
||||
__global__ void moe_token_type_ids_kernel(T *gating_output,
|
||||
const int *moe_token_type_ids_out,
|
||||
__global__ void moe_token_type_ids_kernel(T* gating_output,
|
||||
const int* moe_token_type_ids_out,
|
||||
const int num_rows,
|
||||
const int num_experts,
|
||||
const int k) {
|
||||
@@ -40,8 +43,8 @@ __global__ void moe_token_type_ids_kernel(T *gating_output,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||
const int *moe_token_type_ids_out,
|
||||
void moe_token_type_ids_kernelLauncher(T* gating_output,
|
||||
const int* moe_token_type_ids_out,
|
||||
const int num_rows,
|
||||
const int num_experts,
|
||||
const int k,
|
||||
@@ -51,3 +54,338 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
|
||||
moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
|
||||
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
||||
}
|
||||
|
||||
template <typename T, typename MacaType>
|
||||
class McMoeHelper {
|
||||
public:
|
||||
McMoeHelper(const std::string gemm_method,
|
||||
McMoeGemmRunner<MacaType, int8_t>* int8_moe_gemm_runner)
|
||||
: gemm_method_(gemm_method),
|
||||
int8_moe_gemm_runner_(int8_moe_gemm_runner) {}
|
||||
|
||||
// -------- getWorkspaceSize -------- //
|
||||
template <typename KeyT>
|
||||
size_t getWorkspaceSize(const int64_t num_rows,
|
||||
const int64_t hidden_size,
|
||||
const int64_t inter_size,
|
||||
const int64_t num_experts,
|
||||
const int64_t k) {
|
||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const size_t padded_experts = AlignTo16(num_experts);
|
||||
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
// softmax output, permuted_rows and permuted_experts have moved to outside
|
||||
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
||||
// FfnLayer forward.
|
||||
size_t total_ws_bytes =
|
||||
5 * num_moe_inputs *
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
total_ws_bytes +=
|
||||
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
||||
|
||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||
const size_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
||||
sorter_.update_num_experts(num_experts);
|
||||
|
||||
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
||||
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
||||
int64_t remaining_bytes =
|
||||
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
||||
bytes_for_intermediate_and_sorting += remaining_bytes;
|
||||
}
|
||||
|
||||
total_ws_bytes +=
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
|
||||
int64_t num_softmax_outs = 0;
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
||||
}
|
||||
|
||||
total_ws_bytes += num_softmax_outs * sizeof(float);
|
||||
|
||||
return total_ws_bytes;
|
||||
}
|
||||
|
||||
void computeFFN(const paddle::Tensor* input,
|
||||
const paddle::Tensor* gate_weight,
|
||||
const paddle::Tensor* up_gate_proj_weight,
|
||||
const paddle::Tensor* up_gate_proj_scale,
|
||||
const paddle::Tensor* up_gate_proj_bias,
|
||||
const paddle::Tensor* down_proj_weight,
|
||||
const paddle::Tensor* down_proj_scale,
|
||||
const paddle::Tensor* down_proj_bias,
|
||||
const paddle::Tensor* moe_token_type_ids,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const std::string moe_type,
|
||||
paddle::Tensor* output) {
|
||||
auto* input_activations = input->data<T>();
|
||||
auto* gating_weights = gate_weight->data<float>();
|
||||
const T* fc1_expert_biases =
|
||||
up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
|
||||
const T* fc2_expert_biases =
|
||||
down_proj_bias ? down_proj_bias->data<T>() : nullptr;
|
||||
|
||||
auto* output_ = output->data<T>();
|
||||
auto stream = input->stream();
|
||||
auto place = input->place();
|
||||
auto input_type = input->dtype();
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto up_gate_proj_dims = up_gate_proj_weight->dims();
|
||||
int64_t token_num = 0;
|
||||
if (input_dims.size() == 3) {
|
||||
token_num = input_dims[0] * input_dims[1];
|
||||
} else {
|
||||
token_num = input_dims[0];
|
||||
}
|
||||
const int64_t num_rows = token_num;
|
||||
|
||||
const int64_t hidden_size = up_gate_proj_dims[2];
|
||||
int64_t inter_dim = 0;
|
||||
if (moe_type == "qkv") {
|
||||
inter_dim =
|
||||
up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
|
||||
} else {
|
||||
inter_dim = up_gate_proj_dims[1];
|
||||
}
|
||||
|
||||
// if (gemm_method_ == "weight_only_int4") {
|
||||
// inter_dim = inter_dim * 2;
|
||||
// }
|
||||
|
||||
const int64_t inter_size = inter_dim;
|
||||
const int64_t num_experts = up_gate_proj_dims[0];
|
||||
const int64_t k = moe_topk;
|
||||
|
||||
int64_t bytes =
|
||||
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
||||
|
||||
// Pointers
|
||||
int* expert_for_source_row;
|
||||
int* source_rows_;
|
||||
int* permuted_rows_;
|
||||
int* permuted_experts_;
|
||||
int* expanded_source_row_to_expanded_dest_row;
|
||||
|
||||
T* permuted_data_;
|
||||
int32_t* total_rows_before_expert_;
|
||||
T* fc1_result_;
|
||||
float* softmax_out_;
|
||||
|
||||
paddle::Tensor ws_ptr_tensor =
|
||||
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
||||
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||
|
||||
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const int64_t padded_experts = AlignTo16(num_experts);
|
||||
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
|
||||
expert_for_source_row = reinterpret_cast<int*>(ws_ptr);
|
||||
source_rows_ = expert_for_source_row + num_moe_inputs;
|
||||
permuted_rows_ = source_rows_ + num_moe_inputs;
|
||||
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
||||
expanded_source_row_to_expanded_dest_row =
|
||||
permuted_experts_ + num_moe_inputs;
|
||||
permuted_data_ = reinterpret_cast<T*>(
|
||||
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
||||
total_rows_before_expert_ =
|
||||
reinterpret_cast<int32_t*>(permuted_data_ + buf_size);
|
||||
fc1_result_ =
|
||||
reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
|
||||
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
softmax_out_ = reinterpret_cast<float*>(fc1_result_ + interbuf_size);
|
||||
} else {
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
paddle::Tensor expert_scales_float_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
float* expert_scales_float = expert_scales_float_tensor.data<float>();
|
||||
|
||||
float* softmax_max_prob = nullptr;
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
||||
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
||||
paddle::Tensor fc1_out_tensor =
|
||||
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
||||
T* fc1_out = fc1_out_tensor.data<T>();
|
||||
|
||||
auto input_cast_tensor =
|
||||
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
||||
auto gate_tensor =
|
||||
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
||||
float* gating_output = gate_tensor.data<float>();
|
||||
|
||||
if (moe_token_type_ids) {
|
||||
auto* moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||
moe_token_type_ids_out,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
stream);
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float, int>(gating_output,
|
||||
nullptr,
|
||||
expert_scales_float,
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
group_moe,
|
||||
stream);
|
||||
|
||||
const int64_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||
|
||||
sorter_.run(fc1_result_,
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
k * num_rows,
|
||||
false,
|
||||
stream);
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations,
|
||||
permuted_data_,
|
||||
permuted_rows_,
|
||||
nullptr,
|
||||
nullptr,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
stream);
|
||||
|
||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
expanded_active_expert_rows,
|
||||
num_experts,
|
||||
total_rows_before_expert_,
|
||||
stream);
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
||||
mctlassExOrder_t column_major =
|
||||
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
auto m_num_tile =
|
||||
GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
|
||||
int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());
|
||||
|
||||
if (gemm_method_ == "weight_only_int8") {
|
||||
int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel(
|
||||
reinterpret_cast<const MacaType*>(permuted_data_),
|
||||
row_major,
|
||||
reinterpret_cast<const int8_t*>(up_gate_proj_weight->data<int8_t>()),
|
||||
column_major,
|
||||
reinterpret_cast<const MacaType*>(up_gate_proj_scale->data<T>()),
|
||||
reinterpret_cast<const MacaType*>(fc1_expert_biases),
|
||||
reinterpret_cast<MacaType*>(fc1_out),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
m_num_tile_ptr,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported gemm method: " + gemm_method_);
|
||||
}
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
auto act_out_tensor =
|
||||
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<T>();
|
||||
|
||||
paddle::Tensor fc2_output_tensor =
|
||||
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
||||
T* fc2_result = fc2_output_tensor.data<T>();
|
||||
|
||||
if (gemm_method_ == "weight_only_int8") {
|
||||
int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel(
|
||||
reinterpret_cast<const MacaType*>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const int8_t*>(down_proj_weight->data<int8_t>()),
|
||||
column_major,
|
||||
reinterpret_cast<const MacaType*>(down_proj_scale->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<MacaType*>(fc2_result),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
m_num_tile_ptr,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported gemm method: " + gemm_method_);
|
||||
}
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result,
|
||||
output_,
|
||||
fc2_expert_biases,
|
||||
reinterpret_cast<float*>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out,
|
||||
output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
reinterpret_cast<float*>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
inter_size,
|
||||
k,
|
||||
static_cast<int>(0),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
McMoeGemmRunner<MacaType, int8_t>* int8_moe_gemm_runner_;
|
||||
std::string gemm_method_;
|
||||
CubKeyValueSorter sorter_;
|
||||
};
|
||||
|
||||
} // namespace phi
|
||||
|
||||
@@ -20,6 +20,8 @@
|
||||
#include <string>
|
||||
#include "cub/cub.cuh"
|
||||
|
||||
namespace phi {
|
||||
|
||||
static const float HALF_FLT_MAX = 65504.F;
|
||||
static const float HALF_FLT_MIN = -65504.F;
|
||||
static inline size_t AlignTo16(const size_t& input) {
|
||||
@@ -121,3 +123,5 @@ class CubKeyValueSorter {
|
||||
int num_experts_;
|
||||
int num_bits_;
|
||||
};
|
||||
|
||||
} // namespace phi
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,486 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fused_moe_helper.h"
|
||||
#include "mctlass/numeric_conversion.h"
|
||||
#include "mctlassEx/mctlassEx.h"
|
||||
|
||||
template <typename ElementA, typename ElementB, typename ElementC>
|
||||
void mc_grouped_gemm_basic_kernel(const ElementA* ptrA,
|
||||
mctlassExOrder_t majorA,
|
||||
const ElementB* ptrB,
|
||||
mctlassExOrder_t majorB,
|
||||
const ElementA* ptrScale,
|
||||
const ElementA* ptrBias,
|
||||
ElementC* ptrC,
|
||||
mctlassExOrder_t majorC,
|
||||
const int* ptrSegInd,
|
||||
int* ptrMNumTilesInd,
|
||||
int numExperts,
|
||||
int m, // expanded_active_expert_rows
|
||||
int n, // inter_dim
|
||||
int k, // hidden_size
|
||||
mcStream_t stream) {
|
||||
mctlassExHandle_t handle;
|
||||
mctlassExHandleCreate(&handle);
|
||||
|
||||
mctlassExMatrixLayout_t matLayoutA;
|
||||
mctlassExMatrixLayout_t matLayoutB;
|
||||
mctlassExMatrixLayout_t matLayoutC;
|
||||
|
||||
// mat A: (m, k)
|
||||
mctlassExMatrixLayoutCreate(
|
||||
&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutA,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorA,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutA,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// mat B: (num_experts, n, k)
|
||||
mctlassExMatrixLayoutCreate(
|
||||
&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutB,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorB,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutB,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// mat C: (m, n)
|
||||
mctlassExMatrixLayoutCreate(
|
||||
&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutC,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
||||
&majorC,
|
||||
sizeof(mctlassExOrder_t));
|
||||
mctlassExMatrixLayoutSetAttribute(
|
||||
matLayoutC,
|
||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
||||
&numExperts,
|
||||
sizeof(int));
|
||||
// bias: (num_experts, n)
|
||||
// scale: (num, n)
|
||||
|
||||
mctlassExDesc_t mctlass_desc;
|
||||
mctlassExCreateDesc(&mctlass_desc);
|
||||
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
|
||||
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
|
||||
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
|
||||
mctlassExEpilogueType epilogue_type =
|
||||
mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
|
||||
if (ptrBias) {
|
||||
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
|
||||
}
|
||||
// set scale
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
|
||||
&ptrScale,
|
||||
sizeof(ptrScale));
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
|
||||
&input_type,
|
||||
sizeof(mctlassExDataType));
|
||||
// set bias
|
||||
if (ptrBias) {
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
|
||||
&ptrBias,
|
||||
sizeof(ptrBias));
|
||||
}
|
||||
// set coumpute type
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
|
||||
&compute_type,
|
||||
sizeof(mctlassExDataType));
|
||||
// set epilogue type
|
||||
mctlassExDescSetAttribute(
|
||||
mctlass_desc,
|
||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
|
||||
&epilogue_type,
|
||||
sizeof(mctlassExEpilogueType));
|
||||
|
||||
const mctlassExContiguousGroupedGemmAlgo_t algo =
|
||||
mctlassExContiguousGroupedGemmAlgo_t::
|
||||
MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
|
||||
mctlassExContiguousGroupedDesc_t contiguous_group_desc;
|
||||
mctlassExContiguousGroupedDescCreate(
|
||||
&contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);
|
||||
int blocksizeM;
|
||||
mctlassExContiguousGroupedGemmGetBlocksizeM(handle,
|
||||
mctlass_desc,
|
||||
matLayoutA,
|
||||
matLayoutB,
|
||||
matLayoutC,
|
||||
&algo,
|
||||
&blocksizeM);
|
||||
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle,
|
||||
mctlass_desc,
|
||||
matLayoutA,
|
||||
matLayoutB,
|
||||
matLayoutC,
|
||||
&algo,
|
||||
contiguous_group_desc,
|
||||
numExperts,
|
||||
blocksizeM,
|
||||
stream);
|
||||
|
||||
mctlassExContiguousGroupedGemmBasic(handle,
|
||||
mctlass_desc,
|
||||
ptrA,
|
||||
matLayoutA,
|
||||
ptrB,
|
||||
matLayoutB,
|
||||
ptrC,
|
||||
matLayoutC,
|
||||
contiguous_group_desc,
|
||||
&algo,
|
||||
nullptr,
|
||||
0,
|
||||
stream);
|
||||
|
||||
mctlassExHandleDestroy(handle);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutA);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutB);
|
||||
mctlassExMatrixLayoutDestroy(matLayoutC);
|
||||
mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
|
||||
mctlassExDestroyDesc(mctlass_desc);
|
||||
}
|
||||
|
||||
template <typename T, typename ElementA, typename ElementB, typename ElementC>
|
||||
class McMoeHelper {
|
||||
public:
|
||||
McMoeHelper(const std::string gemm_method) : gemm_method_(gemm_method) {}
|
||||
|
||||
// -------- getWorkspaceSize -------- //
|
||||
template <typename KeyT>
|
||||
size_t getWorkspaceSize(const int64_t num_rows,
|
||||
const int64_t hidden_size,
|
||||
const int64_t inter_size,
|
||||
const int64_t num_experts,
|
||||
const int64_t k) {
|
||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const size_t padded_experts = AlignTo16(num_experts);
|
||||
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
// softmax output, permuted_rows and permuted_experts have moved to outside
|
||||
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
||||
// FfnLayer forward.
|
||||
size_t total_ws_bytes =
|
||||
5 * num_moe_inputs *
|
||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||
total_ws_bytes +=
|
||||
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
||||
|
||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||
const size_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
||||
sorter_.update_num_experts(num_experts);
|
||||
|
||||
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
||||
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
||||
int64_t remaining_bytes =
|
||||
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
||||
bytes_for_intermediate_and_sorting += remaining_bytes;
|
||||
}
|
||||
|
||||
total_ws_bytes +=
|
||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||
// sorting workspace
|
||||
|
||||
int64_t num_softmax_outs = 0;
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
||||
}
|
||||
|
||||
total_ws_bytes += num_softmax_outs * sizeof(float);
|
||||
|
||||
return total_ws_bytes;
|
||||
}
|
||||
|
||||
void computeFFN(const paddle::Tensor* input,
|
||||
const paddle::Tensor* gate_weight,
|
||||
const paddle::Tensor* ffn1_weight,
|
||||
const paddle::Tensor* ffn1_scale,
|
||||
const paddle::Tensor* ffn1_bias,
|
||||
const paddle::Tensor* ffn2_weight,
|
||||
const paddle::Tensor* ffn2_scale,
|
||||
const paddle::Tensor* ffn2_bias,
|
||||
const paddle::Tensor* moe_token_type_ids,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const std::string moe_type,
|
||||
paddle::Tensor* output) {
|
||||
auto* input_activations = input->data<T>();
|
||||
auto* gating_weights = gate_weight->data<float>();
|
||||
const T* fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
|
||||
const T* fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
|
||||
|
||||
auto* output_ = output->data<T>();
|
||||
auto stream = input->stream();
|
||||
auto place = input->place();
|
||||
auto input_type = input->dtype();
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto ffn1_dims = ffn1_weight->dims();
|
||||
int64_t token_num = 0;
|
||||
if (input_dims.size() == 3) {
|
||||
token_num = input_dims[0] * input_dims[1];
|
||||
} else {
|
||||
token_num = input_dims[0];
|
||||
}
|
||||
const int64_t num_rows = token_num;
|
||||
|
||||
const int64_t hidden_size = ffn1_dims[2];
|
||||
int64_t inter_dim = 0;
|
||||
if (moe_type == "qkv") {
|
||||
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
|
||||
} else {
|
||||
inter_dim = ffn1_dims[1];
|
||||
}
|
||||
|
||||
// if (gemm_method == "weight_only_int4") {
|
||||
// inter_dim = inter_dim * 2;
|
||||
// }
|
||||
|
||||
const int64_t inter_size = inter_dim;
|
||||
const int64_t num_experts = ffn1_dims[0];
|
||||
const int64_t k = moe_topk;
|
||||
|
||||
int64_t bytes =
|
||||
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
||||
|
||||
// Pointers
|
||||
int* expert_for_source_row;
|
||||
int* source_rows_;
|
||||
int* permuted_rows_;
|
||||
int* permuted_experts_;
|
||||
int* expanded_source_row_to_expanded_dest_row;
|
||||
|
||||
T* permuted_data_;
|
||||
int32_t* total_rows_before_expert_;
|
||||
T* fc1_result_;
|
||||
float* softmax_out_;
|
||||
|
||||
paddle::Tensor ws_ptr_tensor =
|
||||
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
||||
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||
|
||||
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||
const int64_t padded_experts = AlignTo16(num_experts);
|
||||
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||
|
||||
expert_for_source_row = reinterpret_cast<int*>(ws_ptr);
|
||||
source_rows_ = expert_for_source_row + num_moe_inputs;
|
||||
permuted_rows_ = source_rows_ + num_moe_inputs;
|
||||
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
||||
expanded_source_row_to_expanded_dest_row =
|
||||
permuted_experts_ + num_moe_inputs;
|
||||
permuted_data_ = reinterpret_cast<T*>(
|
||||
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
||||
total_rows_before_expert_ =
|
||||
reinterpret_cast<int32_t*>(permuted_data_ + buf_size);
|
||||
fc1_result_ =
|
||||
reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
|
||||
|
||||
const bool is_pow_2 =
|
||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
if (!is_pow_2 || num_experts > 256) {
|
||||
softmax_out_ = reinterpret_cast<float*>(fc1_result_ + interbuf_size);
|
||||
} else {
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
paddle::Tensor expert_scales_float_tensor =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
float* expert_scales_float = expert_scales_float_tensor.data<float>();
|
||||
|
||||
float* softmax_max_prob = nullptr;
|
||||
if (group_moe) {
|
||||
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
||||
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
// (TODO: check fill success ?)
|
||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||
}
|
||||
|
||||
paddle::Tensor fc1_out_tensor =
|
||||
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
||||
T* fc1_out = fc1_out_tensor.data<T>();
|
||||
|
||||
auto input_cast_tensor =
|
||||
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
||||
auto gate_tensor =
|
||||
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
||||
float* gating_output = gate_tensor.data<float>();
|
||||
|
||||
if (moe_token_type_ids) {
|
||||
auto* moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||
moe_token_type_ids_out,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
stream);
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float>(gating_output,
|
||||
expert_scales_float,
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
num_experts,
|
||||
k,
|
||||
group_moe,
|
||||
stream);
|
||||
|
||||
const int64_t sorter_ws_size_bytes =
|
||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||
|
||||
sorter_.run(fc1_result_,
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
k * num_rows,
|
||||
false,
|
||||
stream);
|
||||
|
||||
initialize_moe_routing_kernelLauncher(
|
||||
input_activations,
|
||||
permuted_data_,
|
||||
permuted_rows_,
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
num_rows,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
stream);
|
||||
|
||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
expanded_active_expert_rows,
|
||||
num_experts,
|
||||
total_rows_before_expert_,
|
||||
stream);
|
||||
|
||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
||||
mctlassExOrder_t column_major =
|
||||
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||
auto m_num_tile =
|
||||
GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
|
||||
int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA*>(permuted_data_),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB*>(ffn1_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA*>(ffn1_scale->data<T>()),
|
||||
reinterpret_cast<const ElementA*>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC*>(fc1_out),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
m_num_tile_ptr,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_size,
|
||||
hidden_size,
|
||||
stream);
|
||||
|
||||
if (moe_type == "ffn") {
|
||||
auto act_out_tensor =
|
||||
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<T>();
|
||||
|
||||
paddle::Tensor fc2_output_tensor =
|
||||
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
||||
T* fc2_result = fc2_output_tensor.data<T>();
|
||||
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA*>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB*>(ffn2_weight->data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA*>(ffn2_scale->data<T>()),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC*>(fc2_result),
|
||||
row_major,
|
||||
total_rows_before_expert_,
|
||||
m_num_tile_ptr,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_size / 2,
|
||||
stream);
|
||||
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
fc2_result,
|
||||
output_,
|
||||
fc2_expert_biases,
|
||||
reinterpret_cast<float*>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
hidden_size,
|
||||
k,
|
||||
static_cast<int>(1),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
} else {
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
// fc2_result,
|
||||
fc1_out,
|
||||
output_,
|
||||
fc1_expert_biases, // fc2_expert_biases,
|
||||
reinterpret_cast<float*>(expert_scales_float),
|
||||
expanded_source_row_to_expanded_dest_row,
|
||||
expert_for_source_row,
|
||||
num_rows,
|
||||
inter_size,
|
||||
k,
|
||||
static_cast<int>(0),
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::string gemm_method_;
|
||||
CubKeyValueSorter sorter_;
|
||||
};
|
||||
@@ -17,26 +17,35 @@
|
||||
#pragma GCC diagnostic ignored "-Wunused-function"
|
||||
#pragma once
|
||||
|
||||
#include "fused_moe_helper.h"
|
||||
#include "fused_moe_imp_op.h"
|
||||
#include "fused_moe_op.h"
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "helper.h"
|
||||
|
||||
template <paddle::DataType T>
|
||||
void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool topk_only_mode,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
const int expert_num,
|
||||
paddle::Tensor* permute_input,
|
||||
paddle::Tensor* tokens_expert_prefix_sum,
|
||||
paddle::Tensor* permute_indices_per_token,
|
||||
paddle::Tensor* top_k_weight,
|
||||
paddle::Tensor* top_k_indices) {
|
||||
void MoeDispatchKernel(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const bool topk_only_mode,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
const int expert_num,
|
||||
paddle::Tensor* permute_input,
|
||||
paddle::Tensor* tokens_expert_prefix_sum,
|
||||
paddle::Tensor* permute_indices_per_token,
|
||||
paddle::Tensor* topk_weight,
|
||||
paddle::Tensor* topk_idx,
|
||||
paddle::Tensor* expert_idx_per_token) {
|
||||
using namespace phi;
|
||||
|
||||
if (num_rows == 0) {
|
||||
return;
|
||||
}
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -78,7 +87,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
reinterpret_cast<int*>(sorter_ws_ptr + sorter_ws_size_bytes);
|
||||
int* permuted_rows_ = permuted_experts_ + num_moe_inputs;
|
||||
|
||||
int* expert_for_source_row = top_k_indices->data<int>();
|
||||
int* topk_idx_ptr = topk_idx->data<int>();
|
||||
|
||||
float* softmax_max_prob = nullptr;
|
||||
if (group_moe) {
|
||||
@@ -103,23 +112,25 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
softmax_out_ = nullptr;
|
||||
}
|
||||
|
||||
topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
|
||||
top_k_weight->data<float>(),
|
||||
softmax_out_,
|
||||
expert_for_source_row,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
expert_num,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
stream,
|
||||
topk_only_mode);
|
||||
topk_gating_softmax_kernelLauncher(
|
||||
gating_output.data<float>(),
|
||||
static_cast<const float*>(nullptr), // no gating_correction_bias
|
||||
topk_weight->data<float>(),
|
||||
softmax_out_,
|
||||
topk_idx_ptr,
|
||||
source_rows_,
|
||||
softmax_max_prob,
|
||||
num_rows,
|
||||
expert_num,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
stream,
|
||||
topk_only_mode);
|
||||
|
||||
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
|
||||
sorter_ws_size_bytes,
|
||||
expert_for_source_row,
|
||||
permuted_experts_,
|
||||
topk_idx_ptr,
|
||||
expert_idx_per_token->data<int32_t>(),
|
||||
source_rows_,
|
||||
permuted_rows_,
|
||||
moe_topk * num_rows,
|
||||
@@ -130,6 +141,8 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
input.data<data_t>(),
|
||||
permute_input->data<data_t>(),
|
||||
permuted_rows_,
|
||||
expert_idx_per_token->data<int32_t>(),
|
||||
nullptr,
|
||||
permute_indices_per_token->data<int32_t>(),
|
||||
num_rows,
|
||||
num_rows,
|
||||
@@ -137,7 +150,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
moe_topk,
|
||||
stream);
|
||||
|
||||
compute_total_rows_before_expert(permuted_experts_,
|
||||
compute_total_rows_before_expert(expert_idx_per_token->data<int32_t>(),
|
||||
moe_topk * num_rows,
|
||||
expert_num,
|
||||
tokens_expert_prefix_sum->data<int32_t>(),
|
||||
@@ -147,8 +160,11 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
const paddle::Tensor& input,
|
||||
const paddle::Tensor& gating_output,
|
||||
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
|
||||
const int moe_topk,
|
||||
const bool group_moe,
|
||||
const std::string& moe_quant_type,
|
||||
const bool topk_only_mode) {
|
||||
const auto input_type = input.dtype();
|
||||
auto place = input.place();
|
||||
@@ -168,9 +184,9 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto permute_input =
|
||||
GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place);
|
||||
// correspond to the weighted coefficients of the results from each expert.
|
||||
auto top_k_weight =
|
||||
auto topk_weight =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||
auto top_k_indices =
|
||||
auto topk_idx =
|
||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);
|
||||
|
||||
auto tokens_expert_prefix_sum =
|
||||
@@ -178,18 +194,24 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto permute_indices_per_token =
|
||||
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
|
||||
|
||||
auto expert_idx_per_token =
|
||||
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
||||
|
||||
if (token_rows == 0) {
|
||||
return {permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
permute_indices_per_token,
|
||||
top_k_weight,
|
||||
top_k_indices};
|
||||
topk_weight,
|
||||
topk_idx,
|
||||
expert_idx_per_token};
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
||||
gating_output,
|
||||
gating_correction_bias,
|
||||
w4a8_in_scale,
|
||||
moe_topk,
|
||||
group_moe,
|
||||
topk_only_mode,
|
||||
@@ -199,37 +221,25 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
&permute_input,
|
||||
&tokens_expert_prefix_sum,
|
||||
&permute_indices_per_token,
|
||||
&top_k_weight,
|
||||
&top_k_indices);
|
||||
&topk_weight,
|
||||
&topk_idx,
|
||||
&expert_idx_per_token);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
|
||||
// gating_output,
|
||||
// moe_topk,
|
||||
// group_moe,
|
||||
// topk_only_mode,
|
||||
// num_rows,
|
||||
// hidden_size,
|
||||
// expert_num,
|
||||
// &permute_input,
|
||||
// &tokens_expert_prefix_sum,
|
||||
// &permute_indices_per_token,
|
||||
// &top_k_weight,
|
||||
// &top_k_indices);
|
||||
// break;
|
||||
default:
|
||||
PD_THROW("Only support bf16 for MoeDispatchKernel");
|
||||
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||
}
|
||||
return {permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
permute_indices_per_token,
|
||||
top_k_weight,
|
||||
top_k_indices};
|
||||
topk_weight,
|
||||
topk_idx,
|
||||
expert_idx_per_token};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
const std::vector<int64_t>& input_shape,
|
||||
const std::vector<int64_t>& gating_output_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& bias_shape,
|
||||
const int moe_topk) {
|
||||
int token_rows = -1;
|
||||
|
||||
@@ -241,33 +251,44 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
|
||||
const int num_rows = token_rows;
|
||||
const int hidden_size = input_shape[input_shape.size() - 1];
|
||||
const int permuted_rows = num_rows == -1 ? -1 : moe_topk * num_rows;
|
||||
|
||||
return {{moe_topk * num_rows, hidden_size},
|
||||
return {{permuted_rows, hidden_size},
|
||||
{expert_num},
|
||||
{moe_topk, num_rows},
|
||||
{num_rows, moe_topk},
|
||||
{num_rows, moe_topk}};
|
||||
{num_rows, moe_topk},
|
||||
{permuted_rows}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||
const paddle::DataType& input_dtype,
|
||||
const paddle::DataType& gating_output_dtype,
|
||||
const paddle::optional<paddle::DataType>& bias_type,
|
||||
const int moe_topk) {
|
||||
return {input_dtype,
|
||||
paddle::DataType::INT64,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::FLOAT32,
|
||||
paddle::DataType::INT32,
|
||||
paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(moe_expert_dispatch)
|
||||
.Inputs({"input", "gating_output"})
|
||||
PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
||||
.Inputs({"input",
|
||||
"gating_output",
|
||||
paddle::Optional("gating_correction_bias"),
|
||||
paddle::Optional("w4a8_in_scale")})
|
||||
.Outputs({"permute_input",
|
||||
"tokens_expert_prefix_sum",
|
||||
"permute_indices_per_token",
|
||||
"top_k_weight",
|
||||
"top_k_indices"})
|
||||
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
|
||||
"topk_weight",
|
||||
"topk_idx",
|
||||
"expert_idx_per_token"})
|
||||
.Attrs({"moe_topk:int",
|
||||
"group_moe:bool",
|
||||
"moe_quant_type:std::string",
|
||||
"topk_only_mode:bool"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
||||
|
||||
@@ -12,23 +12,21 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// BUILD_MARK
|
||||
#pragma once
|
||||
#include "fused_moe_helper.h"
|
||||
#include "helper.h"
|
||||
#include "mc_fused_moe_helper.h"
|
||||
|
||||
template <paddle::DataType T,
|
||||
typename ElementA,
|
||||
typename ElementB,
|
||||
typename ElementC>
|
||||
void McMoeFFNKernel(paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const std::string& quant_method) {
|
||||
template <paddle::DataType T>
|
||||
void MoeFFNKernel(paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -38,11 +36,13 @@ void McMoeFFNKernel(paddle::Tensor& permute_input,
|
||||
auto input_type = permute_input.dtype();
|
||||
auto stream = permute_input.stream();
|
||||
|
||||
auto int8_moe_gemm_runner = McMoeGemmRunner<DataType_, int8_t>();
|
||||
|
||||
const int expanded_active_expert_rows =
|
||||
permute_input.dims()[0]; // permute_input.dims(): m, k
|
||||
const int num_experts = ffn1_weight.dims()[0]; // batchsize
|
||||
const int hidden_size = ffn1_weight.dims()[2]; // n
|
||||
int inter_dim = ffn1_weight.dims()[1]; // k
|
||||
permute_input.dims()[0]; // permute_input.dims(): m, k
|
||||
const int num_experts = up_gate_proj_weight.dims()[0]; // batchsize
|
||||
const int hidden_size = up_gate_proj_weight.dims()[2]; // n
|
||||
int inter_dim = up_gate_proj_weight.dims()[1]; // k
|
||||
|
||||
const int64_t inter_size = inter_dim; // since weight_only_int_8
|
||||
paddle::Tensor fc1_out_tensor = GetEmptyTensor(
|
||||
@@ -58,60 +58,71 @@ void McMoeFFNKernel(paddle::Tensor& permute_input,
|
||||
|
||||
// ffn1
|
||||
auto fc1_expert_biases =
|
||||
ffn1_bias
|
||||
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
|
||||
up_gate_proj_bias
|
||||
? const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr())
|
||||
->data<data_t>()
|
||||
: nullptr;
|
||||
auto fc1_expert_scales =
|
||||
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA*>(permuted_input_ptr),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB*>(ffn1_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA*>(fc1_expert_scales),
|
||||
reinterpret_cast<const ElementA*>(fc1_expert_biases),
|
||||
reinterpret_cast<ElementC*>(fc1_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
m_num_tile_ptr,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_dim,
|
||||
hidden_size,
|
||||
stream);
|
||||
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())->data<data_t>();
|
||||
if (quant_method == "weight_only_int8") {
|
||||
int8_moe_gemm_runner.mc_grouped_gemm_basic_kernel(
|
||||
reinterpret_cast<const DataType_*>(permuted_input_ptr),
|
||||
row_major,
|
||||
reinterpret_cast<const int8_t*>(up_gate_proj_weight.data<int8_t>()),
|
||||
column_major,
|
||||
reinterpret_cast<const DataType_*>(fc1_expert_scales),
|
||||
reinterpret_cast<const DataType_*>(fc1_expert_biases),
|
||||
reinterpret_cast<DataType_*>(fc1_out_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
m_num_tile_ptr,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
inter_dim,
|
||||
hidden_size,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported gemm method: " + quant_method);
|
||||
}
|
||||
|
||||
// swiglu
|
||||
auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||
auto act_out = act_out_tensor.data<data_t>();
|
||||
|
||||
auto fc2_expert_scales =
|
||||
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
|
||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
||||
reinterpret_cast<const ElementA*>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const ElementB*>(ffn2_weight.data<ElementB>()),
|
||||
column_major,
|
||||
reinterpret_cast<const ElementA*>(fc2_expert_scales),
|
||||
nullptr,
|
||||
reinterpret_cast<ElementC*>(permuted_input_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
m_num_tile_ptr,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_dim / 2,
|
||||
stream);
|
||||
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<data_t>();
|
||||
|
||||
if (quant_method == "weight_only_int8") {
|
||||
int8_moe_gemm_runner.mc_grouped_gemm_basic_kernel(
|
||||
reinterpret_cast<const DataType_*>(act_out),
|
||||
row_major,
|
||||
reinterpret_cast<const int8_t*>(down_proj_weight.data<int8_t>()),
|
||||
column_major,
|
||||
reinterpret_cast<const DataType_*>(fc2_expert_scales),
|
||||
nullptr,
|
||||
reinterpret_cast<DataType_*>(permuted_input_ptr),
|
||||
row_major,
|
||||
tokens_expert_prefix_sum.data<int>(),
|
||||
m_num_tile_ptr,
|
||||
num_experts,
|
||||
expanded_active_expert_rows,
|
||||
hidden_size,
|
||||
inter_dim / 2,
|
||||
stream);
|
||||
} else {
|
||||
throw std::runtime_error("Unsupported gemm method: " + quant_method);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
paddle::Tensor& permute_input,
|
||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||
const paddle::Tensor& ffn1_weight,
|
||||
const paddle::Tensor& ffn2_weight,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
||||
const paddle::Tensor& up_gate_proj_weight,
|
||||
const paddle::Tensor& down_proj_weight,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||
const std::string& quant_method) {
|
||||
assert(quant_method == "weight_only_int8");
|
||||
const auto input_type = permute_input.dtype();
|
||||
@@ -122,31 +133,18 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
McMoeFFNKernel<paddle::DataType::BFLOAT16,
|
||||
maca_bfloat16,
|
||||
int8_t,
|
||||
maca_bfloat16>(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
ffn1_weight,
|
||||
ffn2_weight,
|
||||
ffn1_bias,
|
||||
ffn1_scale,
|
||||
ffn2_scale,
|
||||
quant_method);
|
||||
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
up_gate_proj_weight,
|
||||
down_proj_weight,
|
||||
up_gate_proj_bias,
|
||||
up_gate_proj_scale,
|
||||
down_proj_scale,
|
||||
expert_idx_per_token,
|
||||
quant_method);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
||||
// tokens_expert_prefix_sum,
|
||||
// ffn1_weight,
|
||||
// ffn2_weight,
|
||||
// ffn1_bias,
|
||||
// ffn1_scale,
|
||||
// ffn2_scale,
|
||||
// quant_method,
|
||||
// ffn_out);
|
||||
// break;
|
||||
default:
|
||||
PD_THROW("Unsupported data type for MoeExpertFFN");
|
||||
PD_THROW("Unsupported data type for MoeFFNhKernel");
|
||||
}
|
||||
return {permute_input};
|
||||
}
|
||||
@@ -154,33 +152,37 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||
const std::vector<int64_t>& permute_input_shape,
|
||||
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
|
||||
const std::vector<int64_t>& ffn1_weight_shape,
|
||||
const std::vector<int64_t>& ffn2_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
|
||||
const std::vector<int64_t>& up_gate_proj_weight_shape,
|
||||
const std::vector<int64_t>& down_proj_weight_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
|
||||
const std::string& quant_method) {
|
||||
return {permute_input_shape};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||
const paddle::DataType& permute_input_dtype,
|
||||
const paddle::DataType& tokens_expert_prefix_sum_dtype,
|
||||
const paddle::DataType& ffn1_weight_dtype,
|
||||
const paddle::DataType& ffn2_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
|
||||
const paddle::DataType& up_gate_proj_weight_dtype,
|
||||
const paddle::DataType& down_proj_weight_dtype,
|
||||
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
|
||||
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
|
||||
const paddle::optional<paddle::DataType>& expert_idx_per_token_dtype) {
|
||||
return {permute_input_dtype};
|
||||
}
|
||||
|
||||
PD_BUILD_OP(moe_expert_ffn)
|
||||
.Inputs({"permute_input",
|
||||
"tokens_expert_prefix_sum",
|
||||
"ffn1_weight",
|
||||
"ffn2_weight",
|
||||
paddle::Optional("ffn1_bias"),
|
||||
paddle::Optional("ffn1_scale"),
|
||||
paddle::Optional("ffn2_scale")})
|
||||
"up_gate_proj_weight",
|
||||
"down_proj_weight",
|
||||
paddle::Optional("up_gate_proj_bias"),
|
||||
paddle::Optional("up_gate_proj_scale"),
|
||||
paddle::Optional("down_proj_scale"),
|
||||
paddle::Optional("expert_idx_per_token")})
|
||||
.Outputs({"output_tensor"})
|
||||
.Attrs({"quant_method:std::string"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moe_helper.h"
|
||||
#include "fused_moe_op.h"
|
||||
#include "helper.h"
|
||||
|
||||
@@ -23,13 +22,14 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||
const paddle::Tensor& top_k_weight,
|
||||
const paddle::Tensor& permute_indices_per_token,
|
||||
const paddle::Tensor& top_k_indices,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor,
|
||||
const int num_rows,
|
||||
const int hidden_size,
|
||||
const int topk,
|
||||
paddle::Tensor* output) {
|
||||
using namespace phi;
|
||||
typedef PDTraits<T> traits_;
|
||||
typedef typename traits_::DataType DataType_;
|
||||
typedef typename traits_::data_t data_t;
|
||||
@@ -38,7 +38,7 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
||||
finalize_moe_routing_kernelLauncher(
|
||||
ffn_out.data<data_t>(),
|
||||
output->data<data_t>(),
|
||||
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
|
||||
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||
top_k_weight.data<float>(),
|
||||
permute_indices_per_token.data<int32_t>(),
|
||||
top_k_indices.data<int>(),
|
||||
@@ -56,7 +56,7 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
const paddle::Tensor& top_k_weight,
|
||||
const paddle::Tensor& permute_indices_per_token,
|
||||
const paddle::Tensor& top_k_indices,
|
||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
||||
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||
const bool norm_topk_prob,
|
||||
const float routed_scaling_factor) {
|
||||
const auto input_type = ffn_out.dtype();
|
||||
@@ -69,7 +69,7 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||
|
||||
// Avoids ‘invalid configuration argument’ when we launch the kernel.
|
||||
if (ffn_out.dims()[0] == 0) return {output};
|
||||
if (num_rows == 0) return {output};
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
@@ -77,7 +77,7 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
top_k_weight,
|
||||
permute_indices_per_token,
|
||||
top_k_indices,
|
||||
ffn2_bias,
|
||||
down_proj_bias,
|
||||
norm_topk_prob,
|
||||
routed_scaling_factor,
|
||||
num_rows,
|
||||
@@ -85,21 +85,8 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
||||
topk,
|
||||
&output);
|
||||
break;
|
||||
// case paddle::DataType::FLOAT16:
|
||||
// MoeReduceKernel<paddle::DataType::FLOAT16>(ffn_out,
|
||||
// top_k_weight,
|
||||
// permute_indices_per_token,
|
||||
// top_k_indices,
|
||||
// ffn2_bias,
|
||||
// norm_topk_prob,
|
||||
// routed_scaling_factor,
|
||||
// num_rows,
|
||||
// hidden_size,
|
||||
// topk,
|
||||
// &output);
|
||||
// break;
|
||||
default:
|
||||
PD_THROW("Only support bf16 for MoeDispatchKernel");
|
||||
PD_THROW("Unsupported data type for MoeReduceKernel");
|
||||
}
|
||||
return {output};
|
||||
}
|
||||
@@ -109,7 +96,7 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
|
||||
const std::vector<int64_t>& top_k_weight_shape,
|
||||
const std::vector<int64_t>& permute_indices_per_token_shape,
|
||||
const std::vector<int64_t>& top_k_indices_shape,
|
||||
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape) {
|
||||
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape) {
|
||||
const int topk = top_k_indices_shape[1];
|
||||
std::vector<int64_t> fused_moe_out_shape = {ffn_out_shape[0] / topk,
|
||||
ffn_out_shape[1]};
|
||||
@@ -122,7 +109,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
||||
const paddle::DataType& top_k_weight_dtype,
|
||||
const paddle::DataType& permute_indices_per_token_dtype,
|
||||
const paddle::DataType& top_k_indices_dtype,
|
||||
const paddle::optional<paddle::DataType>& ffn2_bias_dtype) {
|
||||
const paddle::optional<paddle::DataType>& down_proj_bias_dtype) {
|
||||
return {ffn_out_dtype};
|
||||
}
|
||||
|
||||
@@ -131,7 +118,7 @@ PD_BUILD_OP(moe_expert_reduce)
|
||||
"top_k_weight",
|
||||
"permute_indices_per_token",
|
||||
"top_k_indices",
|
||||
paddle::Optional("ffn2_bias")})
|
||||
paddle::Optional("down_proj_bias")})
|
||||
.Outputs({"output"})
|
||||
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
|
||||
|
||||
@@ -627,11 +627,17 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
|
||||
"gpu_ops/moe/tritonmoe_preprocess.cu",
|
||||
"gpu_ops/moe/moe_topk_select.cu",
|
||||
"gpu_ops/get_img_boundaries.cc",
|
||||
"gpu_ops/remote_cache_kv_ipc.cc",
|
||||
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
|
||||
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
|
||||
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
|
||||
"metax_ops/moe_dispatch.cu",
|
||||
"metax_ops/moe_ffn.cu",
|
||||
"metax_ops/moe_reduce.cu",
|
||||
"metax_ops/fused_moe.cu",
|
||||
"metax_ops/apply_rope.cu",
|
||||
"metax_ops/apply_rope_qkv.cu",
|
||||
"metax_ops/cache_kv_with_rope.cu",
|
||||
]
|
||||
|
||||
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
|
||||
@@ -657,6 +663,11 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
||||
os.path.join(maca_path, "include"),
|
||||
os.path.join(maca_path, "include/mcr"),
|
||||
os.path.join(maca_path, "include/common"),
|
||||
os.path.join(maca_path, "include/mcfft"),
|
||||
os.path.join(maca_path, "include/mcrand"),
|
||||
os.path.join(maca_path, "include/mcsparse"),
|
||||
os.path.join(maca_path, "include/mcblas"),
|
||||
os.path.join(maca_path, "include/mcsolver"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user