mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Metax] refactor cutlass moe and optimize flash attention (#5361)
* [Metax] refactor moe and flash attention backend --------- Co-authored-by: zhangchenyi_dl <16219492+zhangchenyidl@user.noreply.gitee.com>
This commit is contained in:
@@ -48,14 +48,15 @@ constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
|
|||||||
#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
#define SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <typename T> struct Pair {
|
template <typename T>
|
||||||
|
struct Pair {
|
||||||
T value;
|
T value;
|
||||||
int count;
|
int count;
|
||||||
|
|
||||||
__device__ Pair operator+(const Pair &other) const {
|
__device__ Pair operator+(const Pair& other) const {
|
||||||
return {value + other.value, count + other.count};
|
return {value + other.value, count + other.count};
|
||||||
}
|
}
|
||||||
__device__ Pair &operator+=(const Pair &other) {
|
__device__ Pair& operator+=(const Pair& other) {
|
||||||
value += other.value;
|
value += other.value;
|
||||||
count += other.count;
|
count += other.count;
|
||||||
return *this;
|
return *this;
|
||||||
@@ -78,22 +79,25 @@ struct ValueCount {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct BoolDiffOp {
|
struct BoolDiffOp {
|
||||||
__device__ __forceinline__ bool operator()(const bool &lhs,
|
__device__ __forceinline__ bool operator()(const bool& lhs,
|
||||||
const bool &rhs) const {
|
const bool& rhs) const {
|
||||||
return lhs != rhs;
|
return lhs != rhs;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
template <uint32_t BLOCK_THREADS,
|
||||||
|
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||||
BlockReduceAlgorithm REDUCE_ALGORITHM>
|
BlockReduceAlgorithm REDUCE_ALGORITHM>
|
||||||
struct SamplingTempStorage {
|
struct SamplingTempStorage {
|
||||||
union {
|
union {
|
||||||
float deterministic_scan[BLOCK_THREADS / 32];
|
float deterministic_scan[BLOCK_THREADS / 32];
|
||||||
typename BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
|
typename BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
|
||||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
|
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
|
reduce;
|
||||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||||
reduce_value_count;
|
reduce_int;
|
||||||
|
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::
|
||||||
|
TempStorage reduce_value_count;
|
||||||
typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
|
typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
|
||||||
} block_prim;
|
} block_prim;
|
||||||
struct {
|
struct {
|
||||||
@@ -112,14 +116,17 @@ struct SamplingTempStorage {
|
|||||||
* algorithm. \note This implementation is slower than the cub::BlockScan, but
|
* algorithm. \note This implementation is slower than the cub::BlockScan, but
|
||||||
* it is deterministic.
|
* it is deterministic.
|
||||||
*/
|
*/
|
||||||
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
|
template <uint32_t VEC_SIZE,
|
||||||
|
uint32_t BLOCK_THREADS,
|
||||||
BlockScanAlgorithm SCAN_ALGORITHM,
|
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||||
BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
|
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||||
__device__ __forceinline__ void
|
typename T>
|
||||||
DeterministicInclusiveSum(const T *in_data, T *out_data,
|
__device__ __forceinline__ void DeterministicInclusiveSum(
|
||||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM,
|
const T* in_data,
|
||||||
REDUCE_ALGORITHM> *temp_storage) {
|
T* out_data,
|
||||||
T *smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
|
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
|
||||||
|
temp_storage) {
|
||||||
|
T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
|
||||||
T thread_data[VEC_SIZE];
|
T thread_data[VEC_SIZE];
|
||||||
T thread_sum = 0;
|
T thread_sum = 0;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -138,8 +145,8 @@ DeterministicInclusiveSum(const T *in_data, T *out_data,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum,
|
T warp_sum = __shfl_sync(
|
||||||
threadIdx.x | 0xffffffff);
|
0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff);
|
||||||
if (threadIdx.x % 32 == 31) {
|
if (threadIdx.x % 32 == 31) {
|
||||||
thread_exclusive_prefix_sum = 0;
|
thread_exclusive_prefix_sum = 0;
|
||||||
}
|
}
|
||||||
@@ -197,12 +204,21 @@ DeterministicInclusiveSum(const T *in_data, T *out_data,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
template <uint32_t VEC_SIZE,
|
||||||
BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
|
uint32_t BLOCK_THREADS,
|
||||||
|
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||||
|
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||||
|
bool DETERMINISTIC,
|
||||||
|
typename Predicate>
|
||||||
__device__ __forceinline__ void DeviceSamplingFromProb(
|
__device__ __forceinline__ void DeviceSamplingFromProb(
|
||||||
uint32_t i, uint32_t d, Predicate pred, float u, vec_t<float, VEC_SIZE> prob_vec,
|
uint32_t i,
|
||||||
|
uint32_t d,
|
||||||
|
Predicate pred,
|
||||||
|
float u,
|
||||||
|
vec_t<float, VEC_SIZE> prob_vec,
|
||||||
float& aggregate,
|
float& aggregate,
|
||||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>* temp_storage) {
|
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
|
||||||
|
temp_storage) {
|
||||||
const uint32_t tx = threadIdx.x;
|
const uint32_t tx = threadIdx.x;
|
||||||
float prob_greater_than_threshold[VEC_SIZE];
|
float prob_greater_than_threshold[VEC_SIZE];
|
||||||
float inclusive_cdf[VEC_SIZE];
|
float inclusive_cdf[VEC_SIZE];
|
||||||
@@ -212,14 +228,14 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
|||||||
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
|
prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0;
|
||||||
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
|
valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
|
||||||
}
|
}
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
float aggregate_local =
|
float aggregate_local = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
|
temp_storage->block_prim.reduce)
|
||||||
.Sum(prob_greater_than_threshold);
|
.Sum(prob_greater_than_threshold);
|
||||||
#else
|
#else
|
||||||
float aggregate_local =
|
float aggregate_local = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce)
|
temp_storage->block_prim.reduce)
|
||||||
.Sum<VEC_SIZE>(prob_greater_than_threshold);
|
.Sum<VEC_SIZE>(prob_greater_than_threshold);
|
||||||
#endif
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage->block_aggregate.value = aggregate_local;
|
temp_storage->block_aggregate.value = aggregate_local;
|
||||||
@@ -229,14 +245,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
|||||||
|
|
||||||
if (aggregate + aggregate_local > u) {
|
if (aggregate + aggregate_local > u) {
|
||||||
if constexpr (DETERMINISTIC) {
|
if constexpr (DETERMINISTIC) {
|
||||||
DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>(
|
DeterministicInclusiveSum<VEC_SIZE,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
SCAN_ALGORITHM,
|
||||||
|
REDUCE_ALGORITHM>(
|
||||||
prob_greater_than_threshold, inclusive_cdf, temp_storage);
|
prob_greater_than_threshold, inclusive_cdf, temp_storage);
|
||||||
} else {
|
} else {
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
|
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(
|
||||||
|
temp_storage->block_prim.scan)
|
||||||
.InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
|
.InclusiveSum(prob_greater_than_threshold, inclusive_cdf);
|
||||||
#else
|
#else
|
||||||
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
|
BlockScan<float, BLOCK_THREADS, SCAN_ALGORITHM>(
|
||||||
|
temp_storage->block_prim.scan)
|
||||||
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
|
.InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -250,28 +271,35 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
|||||||
|
|
||||||
bool greater_than_u_diff[VEC_SIZE];
|
bool greater_than_u_diff[VEC_SIZE];
|
||||||
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
#ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||||
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
temp_storage->block_prim.adj_diff)
|
||||||
#else
|
.SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
|
||||||
.SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff, BoolDiffOp());
|
|
||||||
#endif
|
|
||||||
#else
|
#else
|
||||||
#ifdef PADDLE_WITH_COREX
|
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
temp_storage->block_prim.adj_diff)
|
||||||
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
.SubtractLeft<VEC_SIZE>(
|
||||||
#else
|
greater_than_u, greater_than_u_diff, BoolDiffOp());
|
||||||
BlockAdjacentDifference<bool, BLOCK_THREADS>(temp_storage->block_prim.adj_diff)
|
#endif
|
||||||
.FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
#else
|
||||||
#endif
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
|
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||||
|
temp_storage->block_prim.adj_diff)
|
||||||
|
.FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||||
|
#else
|
||||||
|
BlockAdjacentDifference<bool, BLOCK_THREADS>(
|
||||||
|
temp_storage->block_prim.adj_diff)
|
||||||
|
.FlagHeads<VEC_SIZE>(
|
||||||
|
greater_than_u_diff, greater_than_u, BoolDiffOp(), 0);
|
||||||
|
#endif
|
||||||
#endif
|
#endif
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
if (greater_than_u_diff[j]) {
|
if (greater_than_u_diff[j]) {
|
||||||
atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
|
atomicMin(&(temp_storage->sampled_id),
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE + j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
@@ -287,9 +315,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
|||||||
valid_index[j] = -1;
|
valid_index[j] = -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int max_valid_index =
|
int max_valid_index = BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage->block_prim.reduce_int)
|
temp_storage->block_prim.reduce_int)
|
||||||
.Reduce(valid_index, cub::Max());
|
.Reduce(valid_index, cub::Max());
|
||||||
if (tx == 0 && max_valid_index != -1) {
|
if (tx == 0 && max_valid_index != -1) {
|
||||||
temp_storage->last_valid_id = max_valid_index;
|
temp_storage->last_valid_id = max_valid_index;
|
||||||
}
|
}
|
||||||
@@ -297,15 +325,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
|
|||||||
aggregate += aggregate_local;
|
aggregate += aggregate_local;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <uint32_t BLOCK_THREADS,
|
||||||
|
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||||
|
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
uint32_t VEC_SIZE,
|
||||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
|
bool DETERMINISTIC,
|
||||||
typename DType, typename IdType>
|
typename DType,
|
||||||
__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
typename IdType>
|
||||||
float* top_p_arr, IdType* top_k_arr,
|
__global__ void TopKTopPSamplingFromProbKernel(DType* probs,
|
||||||
uint32_t d, uint64_t philox_seed,
|
IdType* output,
|
||||||
|
float* top_p_arr,
|
||||||
|
IdType* top_k_arr,
|
||||||
|
uint32_t d,
|
||||||
|
uint64_t philox_seed,
|
||||||
uint64_t philox_offset) {
|
uint64_t philox_offset) {
|
||||||
const uint32_t batch_size = gridDim.x;
|
const uint32_t batch_size = gridDim.x;
|
||||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||||
@@ -315,12 +347,12 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
||||||
const float p = top_p_arr[row_idx];
|
const float p = top_p_arr[row_idx];
|
||||||
|
|
||||||
extern __shared__ __align__(
|
extern __shared__ __align__(alignof(
|
||||||
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||||
uint8_t smem_sampling[];
|
uint8_t smem_sampling[];
|
||||||
auto& temp_storage =
|
auto& temp_storage = reinterpret_cast<
|
||||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||||
smem_sampling);
|
smem_sampling);
|
||||||
|
|
||||||
vec_t<float, VEC_SIZE> probs_vec;
|
vec_t<float, VEC_SIZE> probs_vec;
|
||||||
float aggregate;
|
float aggregate;
|
||||||
@@ -336,12 +368,22 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
probs_vec.fill(0);
|
probs_vec.fill(0);
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
probs_vec.cast_load(probs + row_idx * d +
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
|
DeviceSamplingFromProb<VEC_SIZE,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
SCAN_ALGORITHM,
|
||||||
|
REDUCE_ALGORITHM,
|
||||||
DETERMINISTIC>(
|
DETERMINISTIC>(
|
||||||
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
|
i,
|
||||||
|
d,
|
||||||
|
[&](float x) { return x > low; },
|
||||||
|
u,
|
||||||
|
probs_vec,
|
||||||
|
aggregate,
|
||||||
|
&temp_storage);
|
||||||
if (aggregate > u) {
|
if (aggregate > u) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -362,28 +404,29 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
probs_vec.fill(0);
|
probs_vec.fill(0);
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
probs_vec.cast_load(probs + row_idx * d +
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueCount<float> probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
|
ValueCount<float> probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
probs_gt_pivot_0[j] = {
|
probs_gt_pivot_0[j] = {(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
||||||
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
(probs_vec[j] > pivot_0 &&
|
||||||
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||||
probs_gt_pivot_1[j] = {
|
probs_gt_pivot_1[j] = {(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
||||||
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
(probs_vec[j] > pivot_1 &&
|
||||||
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
aggregate_gt_pivot_0 +=
|
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
temp_storage.block_prim.reduce_value_count)
|
||||||
.Sum(probs_gt_pivot_0);
|
.Sum(probs_gt_pivot_0);
|
||||||
#else
|
#else
|
||||||
aggregate_gt_pivot_0 +=
|
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
temp_storage.block_prim.reduce_value_count)
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||||
#endif
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
|
temp_storage.block_aggregate.pair = aggregate_gt_pivot_0;
|
||||||
@@ -391,14 +434,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
|
aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair;
|
||||||
|
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
aggregate_gt_pivot_1 +=
|
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
temp_storage.block_prim.reduce_value_count)
|
||||||
.Sum(probs_gt_pivot_1);
|
.Sum(probs_gt_pivot_1);
|
||||||
#else
|
#else
|
||||||
aggregate_gt_pivot_1 +=
|
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS>(
|
||||||
BlockReduce<ValueCount<float>, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count)
|
temp_storage.block_prim.reduce_value_count)
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||||
#endif
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
|
temp_storage.block_aggregate.pair = aggregate_gt_pivot_1;
|
||||||
@@ -427,14 +470,19 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <uint32_t BLOCK_THREADS,
|
||||||
|
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
|
uint32_t VEC_SIZE,
|
||||||
bool DETERMINISTIC, typename DType, typename IdType>
|
bool DETERMINISTIC,
|
||||||
__global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
typename DType,
|
||||||
float* top_p_arr, uint32_t d,
|
typename IdType>
|
||||||
uint64_t philox_seed, uint64_t philox_offset) {
|
__global__ void TopPSamplingFromProbKernel(DType* probs,
|
||||||
|
IdType* output,
|
||||||
|
float* top_p_arr,
|
||||||
|
uint32_t d,
|
||||||
|
uint64_t philox_seed,
|
||||||
|
uint64_t philox_offset) {
|
||||||
const uint32_t batch_size = gridDim.x;
|
const uint32_t batch_size = gridDim.x;
|
||||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||||
curandStatePhilox4_32_10_t state;
|
curandStatePhilox4_32_10_t state;
|
||||||
@@ -442,12 +490,12 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
const uint32_t row_idx = bx;
|
const uint32_t row_idx = bx;
|
||||||
float top_p = top_p_arr[row_idx];
|
float top_p = top_p_arr[row_idx];
|
||||||
|
|
||||||
extern __shared__ __align__(
|
extern __shared__ __align__(alignof(
|
||||||
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||||
uint8_t smem_sampling[];
|
uint8_t smem_sampling[];
|
||||||
auto& temp_storage =
|
auto& temp_storage = reinterpret_cast<
|
||||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||||
smem_sampling);
|
smem_sampling);
|
||||||
|
|
||||||
vec_t<float, VEC_SIZE> probs_vec;
|
vec_t<float, VEC_SIZE> probs_vec;
|
||||||
float aggregate;
|
float aggregate;
|
||||||
@@ -463,12 +511,22 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
probs_vec.fill(0);
|
probs_vec.fill(0);
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
probs_vec.cast_load(probs + row_idx * d +
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
|
DeviceSamplingFromProb<VEC_SIZE,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
SCAN_ALGORITHM,
|
||||||
|
REDUCE_ALGORITHM,
|
||||||
DETERMINISTIC>(
|
DETERMINISTIC>(
|
||||||
i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
|
i,
|
||||||
|
d,
|
||||||
|
[&](float x) { return x > low; },
|
||||||
|
u,
|
||||||
|
probs_vec,
|
||||||
|
aggregate,
|
||||||
|
&temp_storage);
|
||||||
if (aggregate > u) {
|
if (aggregate > u) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -489,7 +547,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
probs_vec.fill(0);
|
probs_vec.fill(0);
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
probs_vec.cast_load(probs + row_idx * d +
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
|
float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
|
||||||
@@ -499,12 +558,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
|
probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
aggregate_gt_pivot_0 +=
|
||||||
.Sum(probs_gt_pivot_0);
|
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||||
|
.Sum(probs_gt_pivot_0);
|
||||||
#else
|
#else
|
||||||
aggregate_gt_pivot_0 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
aggregate_gt_pivot_0 +=
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||||
|
.Sum<VEC_SIZE>(probs_gt_pivot_0);
|
||||||
#endif
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
|
temp_storage.block_aggregate.value = aggregate_gt_pivot_0;
|
||||||
@@ -512,12 +573,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
|
aggregate_gt_pivot_0 = temp_storage.block_aggregate.value;
|
||||||
|
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
aggregate_gt_pivot_1 +=
|
||||||
.Sum(probs_gt_pivot_1);
|
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||||
|
.Sum(probs_gt_pivot_1);
|
||||||
#else
|
#else
|
||||||
aggregate_gt_pivot_1 += BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
aggregate_gt_pivot_1 +=
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
BlockReduce<float, BLOCK_THREADS>(temp_storage.block_prim.reduce)
|
||||||
|
.Sum<VEC_SIZE>(probs_gt_pivot_1);
|
||||||
#endif
|
#endif
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
|
temp_storage.block_aggregate.value = aggregate_gt_pivot_1;
|
||||||
@@ -546,9 +609,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
|
template <uint32_t VEC_SIZE,
|
||||||
|
uint32_t BLOCK_THREADS,
|
||||||
|
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||||
typename TempStorage>
|
typename TempStorage>
|
||||||
__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
|
__device__ __forceinline__ float GetMaxValue(float* in_data,
|
||||||
|
uint32_t row_idx,
|
||||||
|
uint32_t d,
|
||||||
TempStorage& temp_storage) {
|
TempStorage& temp_storage) {
|
||||||
const uint32_t tx = threadIdx.x;
|
const uint32_t tx = threadIdx.x;
|
||||||
vec_t<float, VEC_SIZE> in_data_vec;
|
vec_t<float, VEC_SIZE> in_data_vec;
|
||||||
@@ -557,21 +624,24 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
|
|||||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
in_data_vec.fill(0);
|
in_data_vec.fill(0);
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
in_data_vec.cast_load(in_data + row_idx * d +
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||||
}
|
}
|
||||||
float in_data_[VEC_SIZE];
|
float in_data_[VEC_SIZE];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
in_data_[j] = in_data_vec[j];
|
in_data_[j] = in_data_vec[j];
|
||||||
}
|
}
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
max_val = max(
|
max_val = max(max_val,
|
||||||
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
.Reduce(in_data_, cub::Max()));
|
temp_storage.block_prim.reduce)
|
||||||
|
.Reduce(in_data_, cub::Max()));
|
||||||
#else
|
#else
|
||||||
max_val = max(
|
max_val = max(max_val,
|
||||||
max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
temp_storage.block_prim.reduce)
|
||||||
|
.Reduce<VEC_SIZE>(in_data_, cub::Max()));
|
||||||
#endif
|
#endif
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
@@ -585,10 +655,12 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
|
|||||||
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
|
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM>
|
||||||
struct RenormTempStorage {
|
struct RenormTempStorage {
|
||||||
union {
|
union {
|
||||||
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce;
|
typename BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||||
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_int;
|
reduce;
|
||||||
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
|
||||||
reduce_value_count;
|
reduce_int;
|
||||||
|
typename BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>::
|
||||||
|
TempStorage reduce_value_count;
|
||||||
} block_prim;
|
} block_prim;
|
||||||
struct {
|
struct {
|
||||||
float max_val;
|
float max_val;
|
||||||
@@ -607,24 +679,33 @@ struct RenormTempStorage {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
|
template <uint32_t BLOCK_THREADS,
|
||||||
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
|
BlockScanAlgorithm SCAN_ALGORITHM,
|
||||||
typename DType,typename IdType>
|
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||||
__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
|
uint32_t VEC_SIZE,
|
||||||
DType* renormed_prob,uint32_t d) {
|
bool DETERMINISTIC,
|
||||||
|
typename DType,
|
||||||
|
typename IdType>
|
||||||
|
__global__ void MinPSamplingFromProbKernel(DType* probs,
|
||||||
|
const float* min_p_arr,
|
||||||
|
DType* renormed_prob,
|
||||||
|
uint32_t d) {
|
||||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||||
float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
|
float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
|
||||||
const uint32_t row_idx = bx;
|
const uint32_t row_idx = bx;
|
||||||
|
|
||||||
extern __shared__ __align__(
|
extern __shared__ __align__(alignof(
|
||||||
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||||
uint8_t smem_sampling[];
|
uint8_t smem_sampling[];
|
||||||
auto& temp_storage =
|
auto& temp_storage = reinterpret_cast<
|
||||||
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
|
||||||
smem_sampling);
|
smem_sampling);
|
||||||
|
|
||||||
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
|
float max_val = GetMaxValue<
|
||||||
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
|
VEC_SIZE,
|
||||||
|
BLOCK_THREADS,
|
||||||
|
REDUCE_ALGORITHM,
|
||||||
|
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
|
||||||
probs, row_idx, d, temp_storage);
|
probs, row_idx, d, temp_storage);
|
||||||
float pivot = max_val * p;
|
float pivot = max_val * p;
|
||||||
|
|
||||||
@@ -633,7 +714,8 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
|
|||||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
probs_vec.fill(0);
|
probs_vec.fill(0);
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
|
probs_vec.cast_load(probs + row_idx * d +
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@@ -641,42 +723,51 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
|
|||||||
probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
|
probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
|
||||||
}
|
}
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
probs_vec.store(renormed_prob + row_idx * d +
|
||||||
|
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <uint32_t BLOCK_THREADS,
|
||||||
template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
|
BlockReduceAlgorithm REDUCE_ALGORITHM,
|
||||||
typename DType, typename IdType>
|
uint32_t VEC_SIZE,
|
||||||
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
|
typename DType,
|
||||||
|
typename IdType>
|
||||||
|
__global__ void TopKRenormProbKernel(DType* probs,
|
||||||
|
DType* renormed_prob,
|
||||||
|
IdType* top_k_arr,
|
||||||
|
uint32_t d) {
|
||||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||||
const uint32_t row_idx = bx;
|
const uint32_t row_idx = bx;
|
||||||
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx];
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
double pivot = std::numeric_limits<float>::infinity(), normalizer = 1;
|
double pivot = std::numeric_limits<float>::infinity(), normalizer = 1;
|
||||||
#else
|
#else
|
||||||
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
|
double pivot = -cuda::std::numeric_limits<float>::infinity(), normalizer = 1;
|
||||||
#endif
|
#endif
|
||||||
vec_t<float, VEC_SIZE> probs_vec;
|
vec_t<float, VEC_SIZE> probs_vec;
|
||||||
if (k < d) {
|
if (k < d) {
|
||||||
extern __shared__ __align__(alignof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>))
|
extern __shared__ __align__(alignof(
|
||||||
uint8_t smem_renorm[];
|
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>)) uint8_t smem_renorm[];
|
||||||
auto& temp_storage =
|
auto& temp_storage =
|
||||||
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
|
reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(
|
||||||
|
smem_renorm);
|
||||||
temp_storage.max_val = 0;
|
temp_storage.max_val = 0;
|
||||||
|
|
||||||
float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
|
float max_val =
|
||||||
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
|
GetMaxValue<VEC_SIZE,
|
||||||
probs, row_idx, d, temp_storage);
|
BLOCK_THREADS,
|
||||||
|
REDUCE_ALGORITHM,
|
||||||
|
RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(
|
||||||
|
probs, row_idx, d, temp_storage);
|
||||||
|
|
||||||
double low = 0, high = max_val;
|
double low = 0, high = max_val;
|
||||||
float min_gt_low, max_le_high;
|
float min_gt_low, max_le_high;
|
||||||
float sum_low = 1;
|
float sum_low = 1;
|
||||||
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
|
// f(x) = len(nonzero(probs > x)), f(x) is non-increasing
|
||||||
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
|
// min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs |
|
||||||
// loop invariant:
|
// p <= high} loop invariant:
|
||||||
// - f(low) >= k, f(high) < k
|
// - f(low) >= k, f(high) < k
|
||||||
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
|
// - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
|
||||||
// stopping condition: min_gt_low == max_le_high
|
// stopping condition: min_gt_low == max_le_high
|
||||||
@@ -692,55 +783,65 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
|
|||||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
probs_vec.fill(0);
|
probs_vec.fill(0);
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
probs_vec.cast_load(probs + row_idx * d +
|
||||||
|
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||||
}
|
}
|
||||||
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
|
ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE],
|
||||||
|
probs_gt_pivot_1_pair[VEC_SIZE];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
probs_gt_pivot_0_pair[j] = {
|
probs_gt_pivot_0_pair[j] = {
|
||||||
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
(probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
|
||||||
(probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
(probs_vec[j] > pivot_0 &&
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||||
probs_gt_pivot_1_pair[j] = {
|
probs_gt_pivot_1_pair[j] = {
|
||||||
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
(probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
|
||||||
(probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
(probs_vec[j] > pivot_1 &&
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
|
||||||
|
|
||||||
if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
if (probs_vec[j] > low &&
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||||
min_gt_low = min(min_gt_low, probs_vec[j]);
|
min_gt_low = min(min_gt_low, probs_vec[j]);
|
||||||
}
|
}
|
||||||
if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
if (probs_vec[j] <= high &&
|
||||||
|
(i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
|
||||||
max_le_high = max(max_le_high, probs_vec[j]);
|
max_le_high = max(max_le_high, probs_vec[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
aggregate_gt_pivot_0 +=
|
||||||
temp_storage.block_prim.reduce_value_count)
|
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
.Sum(probs_gt_pivot_0_pair);
|
temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum(probs_gt_pivot_0_pair);
|
||||||
#else
|
#else
|
||||||
aggregate_gt_pivot_0 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
aggregate_gt_pivot_0 +=
|
||||||
temp_storage.block_prim.reduce_value_count)
|
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
|
temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum<VEC_SIZE>(probs_gt_pivot_0_pair);
|
||||||
#endif
|
#endif
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
#ifdef PADDLE_WITH_COREX
|
#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU)
|
||||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
aggregate_gt_pivot_1 +=
|
||||||
temp_storage.block_prim.reduce_value_count)
|
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
.Sum(probs_gt_pivot_1_pair);
|
temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum(probs_gt_pivot_1_pair);
|
||||||
#else
|
#else
|
||||||
aggregate_gt_pivot_1 += BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
aggregate_gt_pivot_1 +=
|
||||||
temp_storage.block_prim.reduce_value_count)
|
BlockReduce<ValueCount<float>, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
|
temp_storage.block_prim.reduce_value_count)
|
||||||
|
.Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
|
||||||
#endif
|
#endif
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
min_gt_low =
|
min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
temp_storage.block_prim.reduce)
|
||||||
.Reduce(min_gt_low, cub::Min());
|
.Reduce(min_gt_low, cub::Min());
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
max_le_high =
|
max_le_high = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(
|
||||||
BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
|
temp_storage.block_prim.reduce)
|
||||||
.Reduce(max_le_high, cub::Max());
|
.Reduce(max_le_high, cub::Max());
|
||||||
if (tx == 0) {
|
if (tx == 0) {
|
||||||
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
|
temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0;
|
||||||
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
|
temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1;
|
||||||
@@ -774,23 +875,29 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
|
|||||||
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
|
||||||
probs_vec.fill(0);
|
probs_vec.fill(0);
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
|
||||||
|
tx * VEC_SIZE);
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
|
||||||
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
|
probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0;
|
||||||
}
|
}
|
||||||
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
|
||||||
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
probs_vec.store(renormed_prob + row_idx * d +
|
||||||
|
i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename IdType>
|
template <typename T, typename IdType>
|
||||||
cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
|
cudaError_t TopPSamplingFromProb(T* probs,
|
||||||
uint32_t batch_size, const T *top_p_val,
|
IdType* output,
|
||||||
uint32_t d, bool deterministic,
|
uint32_t batch_size,
|
||||||
uint64_t philox_seed, uint64_t philox_offset,
|
const T* top_p_val,
|
||||||
|
uint32_t d,
|
||||||
|
bool deterministic,
|
||||||
|
uint64_t philox_seed,
|
||||||
|
uint64_t philox_offset,
|
||||||
cudaStream_t stream = 0) {
|
cudaStream_t stream = 0) {
|
||||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||||
@@ -799,99 +906,139 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
|
|||||||
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||||
dim3 nblks(batch_size);
|
dim3 nblks(batch_size);
|
||||||
dim3 nthrs(BLOCK_THREADS);
|
dim3 nthrs(BLOCK_THREADS);
|
||||||
void* args[] = {&probs, &output, &top_p_val,
|
void* args[] = {
|
||||||
&d, &philox_seed, &philox_offset};
|
&probs, &output, &top_p_val, &d, &philox_seed, &philox_offset};
|
||||||
|
|
||||||
DISPATCH_ALIGNED_VEC_SIZE(
|
DISPATCH_ALIGNED_VEC_SIZE(
|
||||||
vec_size, VEC_SIZE,
|
vec_size,
|
||||||
|
VEC_SIZE,
|
||||||
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||||
auto kernel =
|
auto kernel = TopPSamplingFromProbKernel<BLOCK_THREADS,
|
||||||
TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
|
SCAN_ALGO,
|
||||||
VEC_SIZE, DETERMINISTIC, T, IdType>;
|
REDUCE_ALGO,
|
||||||
|
VEC_SIZE,
|
||||||
|
DETERMINISTIC,
|
||||||
|
T,
|
||||||
|
IdType>;
|
||||||
CUDA_CALL(cudaFuncSetAttribute(
|
CUDA_CALL(cudaFuncSetAttribute(
|
||||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||||
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
|
CUDA_CALL(cudaLaunchKernel(
|
||||||
smem_size, stream));
|
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||||
})});
|
})});
|
||||||
return cudaSuccess;
|
return cudaSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T,typename IdType>
|
template <typename T, typename IdType>
|
||||||
cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,T *renormed_prob,
|
cudaError_t MinPSamplingFromProb(T* probs,
|
||||||
|
const T* min_p_arr,
|
||||||
|
T* renormed_prob,
|
||||||
uint32_t batch_size,
|
uint32_t batch_size,
|
||||||
uint32_t d, bool deterministic,
|
uint32_t d,
|
||||||
cudaStream_t stream = 0){
|
bool deterministic,
|
||||||
|
cudaStream_t stream = 0) {
|
||||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||||
|
|
||||||
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
const uint32_t smem_size =
|
||||||
|
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||||
dim3 nblks(batch_size);
|
dim3 nblks(batch_size);
|
||||||
dim3 nthrs(BLOCK_THREADS);
|
dim3 nthrs(BLOCK_THREADS);
|
||||||
void* args[] = {&probs, &min_p_arr,&renormed_prob,&d};
|
void* args[] = {&probs, &min_p_arr, &renormed_prob, &d};
|
||||||
DISPATCH_ALIGNED_VEC_SIZE(
|
DISPATCH_ALIGNED_VEC_SIZE(
|
||||||
vec_size, VEC_SIZE,
|
vec_size,
|
||||||
|
VEC_SIZE,
|
||||||
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||||
auto kernel =
|
auto kernel = MinPSamplingFromProbKernel<BLOCK_THREADS,
|
||||||
MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
|
SCAN_ALGO,
|
||||||
VEC_SIZE, DETERMINISTIC, T,IdType>;
|
REDUCE_ALGO,
|
||||||
|
VEC_SIZE,
|
||||||
|
DETERMINISTIC,
|
||||||
|
T,
|
||||||
|
IdType>;
|
||||||
CUDA_CALL(cudaFuncSetAttribute(
|
CUDA_CALL(cudaFuncSetAttribute(
|
||||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||||
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
|
CUDA_CALL(cudaLaunchKernel(
|
||||||
smem_size, stream));
|
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||||
})});
|
})});
|
||||||
return cudaSuccess;
|
return cudaSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T, typename IdType>
|
template <typename T, typename IdType>
|
||||||
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
|
cudaError_t TopKTopPSamplingFromProb(T* probs,
|
||||||
uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
|
IdType* output,
|
||||||
uint32_t d, bool deterministic,
|
uint32_t batch_size,
|
||||||
uint64_t philox_seed, uint64_t philox_offset,
|
const T* top_p_val,
|
||||||
|
const IdType* top_k_val,
|
||||||
|
uint32_t d,
|
||||||
|
bool deterministic,
|
||||||
|
uint64_t philox_seed,
|
||||||
|
uint64_t philox_offset,
|
||||||
cudaStream_t stream = 0) {
|
cudaStream_t stream = 0) {
|
||||||
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
|
||||||
|
|
||||||
auto compute_capacity = GetCudaComputeCapability();
|
auto compute_capacity = GetCudaComputeCapability();
|
||||||
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
|
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
|
||||||
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
const uint32_t smem_size =
|
||||||
|
sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||||
dim3 nblks(batch_size);
|
dim3 nblks(batch_size);
|
||||||
dim3 nthrs(BLOCK_THREADS);
|
dim3 nthrs(BLOCK_THREADS);
|
||||||
void* args[] = {&probs, &output, &top_p_val, &top_k_val,
|
void* args[] = {&probs,
|
||||||
&d, &philox_seed, &philox_offset};
|
&output,
|
||||||
|
&top_p_val,
|
||||||
|
&top_k_val,
|
||||||
|
&d,
|
||||||
|
&philox_seed,
|
||||||
|
&philox_offset};
|
||||||
|
|
||||||
DISPATCH_ALIGNED_VEC_SIZE(
|
DISPATCH_ALIGNED_VEC_SIZE(
|
||||||
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
vec_size,
|
||||||
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
|
VEC_SIZE,
|
||||||
VEC_SIZE, DETERMINISTIC, T, IdType>;
|
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||||
CUDA_CALL(
|
auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS,
|
||||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
SCAN_ALGO,
|
||||||
CUDA_CALL(
|
REDUCE_ALGO,
|
||||||
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
VEC_SIZE,
|
||||||
|
DETERMINISTIC,
|
||||||
|
T,
|
||||||
|
IdType>;
|
||||||
|
CUDA_CALL(cudaFuncSetAttribute(
|
||||||
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||||
|
CUDA_CALL(cudaLaunchKernel(
|
||||||
|
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||||
})});
|
})});
|
||||||
return cudaSuccess;
|
return cudaSuccess;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DType, typename IdType>
|
template <typename DType, typename IdType>
|
||||||
cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr,
|
cudaError_t TopKRenormProb(DType* probs,
|
||||||
uint32_t batch_size, uint32_t d,
|
DType* renormed_prob,
|
||||||
|
IdType* top_k_arr,
|
||||||
|
uint32_t batch_size,
|
||||||
|
uint32_t d,
|
||||||
cudaStream_t stream = 0) {
|
cudaStream_t stream = 0) {
|
||||||
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
||||||
|
|
||||||
auto compute_capacity = GetCudaComputeCapability();
|
auto compute_capacity = GetCudaComputeCapability();
|
||||||
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
|
DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, {
|
||||||
const uint32_t smem_size = sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
|
const uint32_t smem_size =
|
||||||
|
sizeof(RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>);
|
||||||
dim3 nblks(batch_size);
|
dim3 nblks(batch_size);
|
||||||
dim3 nthrs(BLOCK_THREADS);
|
dim3 nthrs(BLOCK_THREADS);
|
||||||
void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
|
void* args[] = {&probs, &renormed_prob, &top_k_arr, &d};
|
||||||
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
|
DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
|
||||||
auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType, IdType>;
|
auto kernel = TopKRenormProbKernel<BLOCK_THREADS,
|
||||||
CUDA_CALL(
|
REDUCE_ALGO,
|
||||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
VEC_SIZE,
|
||||||
CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
DType,
|
||||||
|
IdType>;
|
||||||
|
CUDA_CALL(cudaFuncSetAttribute(
|
||||||
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||||
|
CUDA_CALL(cudaLaunchKernel(
|
||||||
|
(void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||||
});
|
});
|
||||||
return cudaSuccess;
|
return cudaSuccess;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace sampling
|
} // namespace sampling
|
||||||
|
|||||||
@@ -23,221 +23,235 @@
|
|||||||
#include <cuda_device_runtime_api.h>
|
#include <cuda_device_runtime_api.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#include <curand.h>
|
||||||
|
#include <curand_kernel.h>
|
||||||
|
#include <curand_philox4x32_x.h>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <curand.h>
|
|
||||||
#include <curand_kernel.h>
|
|
||||||
#include <curand_philox4x32_x.h>
|
|
||||||
|
|
||||||
/******************* utils *******************/
|
/******************* utils *******************/
|
||||||
#define STR_HELPER(x) #x
|
#define STR_HELPER(x) #x
|
||||||
#define STR(x) STR_HELPER(x)
|
#define STR(x) STR_HELPER(x)
|
||||||
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
#define CUDA_CALL(func, ...) \
|
#define CUDA_CALL(func, ...) \
|
||||||
{ \
|
{ \
|
||||||
cudaError_t e = (func); \
|
cudaError_t e = (func); \
|
||||||
if (e != cudaSuccess) { \
|
if (e != cudaSuccess) { \
|
||||||
std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
|
std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
|
||||||
<< ") " << __FILE__ << ": line " << __LINE__ \
|
<< ") " << __FILE__ << ": line " << __LINE__ \
|
||||||
<< " at function " << STR(func) << std::endl; \
|
<< " at function " << STR(func) << std::endl; \
|
||||||
return e; \
|
return e; \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
#define CUDA_CALL(func, ...) \
|
#define CUDA_CALL(func, ...) \
|
||||||
{ \
|
{ \
|
||||||
cudaError_t e = (func); \
|
cudaError_t e = (func); \
|
||||||
if (e != cudaSuccess) { \
|
if (e != cudaSuccess) { \
|
||||||
return e; \
|
return e; \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
|
#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
|
||||||
if (deterministic) { \
|
if (deterministic) { \
|
||||||
constexpr bool DETERMINISTIC = true; \
|
constexpr bool DETERMINISTIC = true; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
} else { \
|
} else { \
|
||||||
constexpr bool DETERMINISTIC = false; \
|
constexpr bool DETERMINISTIC = false; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
|
#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
|
||||||
switch (aligned_vec_size) { \
|
switch (aligned_vec_size) { \
|
||||||
case 16: { \
|
case 16: { \
|
||||||
constexpr size_t ALIGNED_VEC_SIZE = 16; \
|
constexpr size_t ALIGNED_VEC_SIZE = 16; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
break; \
|
break; \
|
||||||
} \
|
} \
|
||||||
case 8: { \
|
case 8: { \
|
||||||
constexpr size_t ALIGNED_VEC_SIZE = 8; \
|
constexpr size_t ALIGNED_VEC_SIZE = 8; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
break; \
|
break; \
|
||||||
} \
|
} \
|
||||||
case 4: { \
|
case 4: { \
|
||||||
constexpr size_t ALIGNED_VEC_SIZE = 4; \
|
constexpr size_t ALIGNED_VEC_SIZE = 4; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
break; \
|
break; \
|
||||||
} \
|
} \
|
||||||
case 2: { \
|
case 2: { \
|
||||||
constexpr size_t ALIGNED_VEC_SIZE = 2; \
|
constexpr size_t ALIGNED_VEC_SIZE = 2; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
break; \
|
break; \
|
||||||
} \
|
} \
|
||||||
case 1: { \
|
case 1: { \
|
||||||
constexpr size_t ALIGNED_VEC_SIZE = 1; \
|
constexpr size_t ALIGNED_VEC_SIZE = 1; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
break; \
|
break; \
|
||||||
} \
|
} \
|
||||||
default: { \
|
default: { \
|
||||||
std::ostringstream err_msg; \
|
std::ostringstream err_msg; \
|
||||||
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
|
err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \
|
||||||
throw std::invalid_argument(err_msg.str()); \
|
throw std::invalid_argument(err_msg.str()); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
/******************* vec_t<float> *******************/
|
/******************* vec_t<float> *******************/
|
||||||
#define SAMPLING_INLINE inline __attribute__((always_inline)) __device__
|
#define SAMPLING_INLINE inline __attribute__((always_inline)) __device__
|
||||||
template <typename float_t, size_t vec_size> struct vec_t {
|
template <typename float_t, size_t vec_size>
|
||||||
SAMPLING_INLINE float_t &operator[](size_t i);
|
struct vec_t {
|
||||||
SAMPLING_INLINE const float_t &operator[](size_t i) const;
|
SAMPLING_INLINE float_t& operator[](size_t i);
|
||||||
|
SAMPLING_INLINE const float_t& operator[](size_t i) const;
|
||||||
SAMPLING_INLINE void fill(float_t val);
|
SAMPLING_INLINE void fill(float_t val);
|
||||||
SAMPLING_INLINE void load(const float_t *ptr);
|
SAMPLING_INLINE void load(const float_t* ptr);
|
||||||
SAMPLING_INLINE void store(float_t *ptr) const;
|
SAMPLING_INLINE void store(float_t* ptr) const;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src);
|
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size>& src);
|
||||||
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr);
|
template <typename T>
|
||||||
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const;
|
SAMPLING_INLINE void cast_load(const T* ptr);
|
||||||
SAMPLING_INLINE static void memcpy(float_t *dst, const float_t *src);
|
template <typename T>
|
||||||
SAMPLING_INLINE float_t *ptr();
|
SAMPLING_INLINE void cast_store(T* ptr) const;
|
||||||
|
SAMPLING_INLINE static void memcpy(float_t* dst, const float_t* src);
|
||||||
|
SAMPLING_INLINE float_t* ptr();
|
||||||
};
|
};
|
||||||
|
|
||||||
// float x 1
|
// float x 1
|
||||||
template <> struct vec_t<float, 1> {
|
template <>
|
||||||
|
struct vec_t<float, 1> {
|
||||||
float data;
|
float data;
|
||||||
|
|
||||||
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
|
SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
|
||||||
SAMPLING_INLINE const float &operator[](size_t i) const {
|
SAMPLING_INLINE const float& operator[](size_t i) const {
|
||||||
return ((const float *)(&data))[i];
|
return ((const float*)(&data))[i];
|
||||||
}
|
}
|
||||||
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
|
SAMPLING_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
|
||||||
SAMPLING_INLINE void fill(float val);
|
SAMPLING_INLINE void fill(float val);
|
||||||
SAMPLING_INLINE void load(const float *ptr);
|
SAMPLING_INLINE void load(const float* ptr);
|
||||||
SAMPLING_INLINE void store(float *ptr) const;
|
SAMPLING_INLINE void store(float* ptr) const;
|
||||||
template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 1> &src) {
|
template <typename T>
|
||||||
|
SAMPLING_INLINE void cast_from(const vec_t<T, 1>& src) {
|
||||||
cast_from_impl(*this, src);
|
cast_from_impl(*this, src);
|
||||||
}
|
}
|
||||||
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
|
template <typename T>
|
||||||
|
SAMPLING_INLINE void cast_load(const T* ptr) {
|
||||||
cast_load_impl(*this, ptr);
|
cast_load_impl(*this, ptr);
|
||||||
}
|
}
|
||||||
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
|
template <typename T>
|
||||||
|
SAMPLING_INLINE void cast_store(T* ptr) const {
|
||||||
cast_store_impl(ptr, *this);
|
cast_store_impl(ptr, *this);
|
||||||
}
|
}
|
||||||
SAMPLING_INLINE static void memcpy(float *dst, const float *src);
|
SAMPLING_INLINE static void memcpy(float* dst, const float* src);
|
||||||
};
|
};
|
||||||
|
|
||||||
SAMPLING_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
|
SAMPLING_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
|
||||||
|
|
||||||
SAMPLING_INLINE void vec_t<float, 1>::load(const float *ptr) { data = *ptr; }
|
SAMPLING_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }
|
||||||
|
|
||||||
SAMPLING_INLINE void vec_t<float, 1>::store(float *ptr) const { *ptr = data; }
|
SAMPLING_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }
|
||||||
|
|
||||||
SAMPLING_INLINE void vec_t<float, 1>::memcpy(float *dst, const float *src) {
|
SAMPLING_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) {
|
||||||
*dst = *src;
|
*dst = *src;
|
||||||
}
|
}
|
||||||
|
|
||||||
// float x 2
|
// float x 2
|
||||||
template <> struct vec_t<float, 2> {
|
template <>
|
||||||
|
struct vec_t<float, 2> {
|
||||||
float2 data;
|
float2 data;
|
||||||
|
|
||||||
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; }
|
SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
|
||||||
SAMPLING_INLINE const float &operator[](size_t i) const {
|
SAMPLING_INLINE const float& operator[](size_t i) const {
|
||||||
return ((const float *)(&data))[i];
|
return ((const float*)(&data))[i];
|
||||||
}
|
}
|
||||||
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
|
SAMPLING_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
|
||||||
SAMPLING_INLINE void fill(float val);
|
SAMPLING_INLINE void fill(float val);
|
||||||
SAMPLING_INLINE void load(const float *ptr);
|
SAMPLING_INLINE void load(const float* ptr);
|
||||||
SAMPLING_INLINE void store(float *ptr) const;
|
SAMPLING_INLINE void store(float* ptr) const;
|
||||||
template <typename T> SAMPLING_INLINE void cast_from(const vec_t<T, 2> &src) {
|
template <typename T>
|
||||||
|
SAMPLING_INLINE void cast_from(const vec_t<T, 2>& src) {
|
||||||
cast_from_impl(*this, src);
|
cast_from_impl(*this, src);
|
||||||
}
|
}
|
||||||
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
|
template <typename T>
|
||||||
|
SAMPLING_INLINE void cast_load(const T* ptr) {
|
||||||
cast_load_impl(*this, ptr);
|
cast_load_impl(*this, ptr);
|
||||||
}
|
}
|
||||||
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
|
template <typename T>
|
||||||
|
SAMPLING_INLINE void cast_store(T* ptr) const {
|
||||||
cast_store_impl(ptr, *this);
|
cast_store_impl(ptr, *this);
|
||||||
}
|
}
|
||||||
SAMPLING_INLINE static void memcpy(float *dst, const float *src);
|
SAMPLING_INLINE static void memcpy(float* dst, const float* src);
|
||||||
};
|
};
|
||||||
|
|
||||||
SAMPLING_INLINE void vec_t<float, 2>::fill(float val) {
|
SAMPLING_INLINE void vec_t<float, 2>::fill(float val) {
|
||||||
data = make_float2(val, val);
|
data = make_float2(val, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
SAMPLING_INLINE void vec_t<float, 2>::load(const float *ptr) {
|
SAMPLING_INLINE void vec_t<float, 2>::load(const float* ptr) {
|
||||||
data = *((float2 *)ptr);
|
data = *((float2*)ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
SAMPLING_INLINE void vec_t<float, 2>::store(float *ptr) const {
|
SAMPLING_INLINE void vec_t<float, 2>::store(float* ptr) const {
|
||||||
*((float2 *)ptr) = data;
|
*((float2*)ptr) = data;
|
||||||
}
|
}
|
||||||
|
|
||||||
SAMPLING_INLINE void vec_t<float, 2>::memcpy(float *dst, const float *src) {
|
SAMPLING_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
|
||||||
*((float2 *)dst) = *((float2 *)src);
|
*((float2*)dst) = *((float2*)src);
|
||||||
}
|
}
|
||||||
|
|
||||||
// float x 4 or more
|
// float x 4 or more
|
||||||
template <size_t vec_size> struct vec_t<float, vec_size> {
|
template <size_t vec_size>
|
||||||
|
struct vec_t<float, vec_size> {
|
||||||
float4 data[vec_size / 4];
|
float4 data[vec_size / 4];
|
||||||
|
|
||||||
SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; }
|
SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
|
||||||
SAMPLING_INLINE const float &operator[](size_t i) const {
|
SAMPLING_INLINE const float& operator[](size_t i) const {
|
||||||
return ((const float *)(data))[i];
|
return ((const float*)(data))[i];
|
||||||
}
|
}
|
||||||
SAMPLING_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
|
SAMPLING_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
|
||||||
SAMPLING_INLINE void fill(float val) {
|
SAMPLING_INLINE void fill(float val) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||||
data[i] = make_float4(val, val, val, val);
|
data[i] = make_float4(val, val, val, val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
SAMPLING_INLINE void load(const float *ptr) {
|
SAMPLING_INLINE void load(const float* ptr) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||||
data[i] = ((float4 *)ptr)[i];
|
data[i] = ((float4*)ptr)[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
SAMPLING_INLINE void store(float *ptr) const {
|
SAMPLING_INLINE void store(float* ptr) const {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||||
((float4 *)ptr)[i] = data[i];
|
((float4*)ptr)[i] = data[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size> &src) {
|
SAMPLING_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
||||||
cast_from_impl(*this, src);
|
cast_from_impl(*this, src);
|
||||||
}
|
}
|
||||||
template <typename T> SAMPLING_INLINE void cast_load(const T *ptr) {
|
template <typename T>
|
||||||
|
SAMPLING_INLINE void cast_load(const T* ptr) {
|
||||||
cast_load_impl(*this, ptr);
|
cast_load_impl(*this, ptr);
|
||||||
}
|
}
|
||||||
template <typename T> SAMPLING_INLINE void cast_store(T *ptr) const {
|
template <typename T>
|
||||||
|
SAMPLING_INLINE void cast_store(T* ptr) const {
|
||||||
cast_store_impl(ptr, *this);
|
cast_store_impl(ptr, *this);
|
||||||
}
|
}
|
||||||
SAMPLING_INLINE static void memcpy(float *dst, const float *src) {
|
SAMPLING_INLINE static void memcpy(float* dst, const float* src) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (size_t i = 0; i < vec_size / 4; ++i) {
|
for (size_t i = 0; i < vec_size / 4; ++i) {
|
||||||
((float4 *)dst)[i] = ((float4 *)src)[i];
|
((float4*)dst)[i] = ((float4*)src)[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
|
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
|
||||||
SAMPLING_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
|
SAMPLING_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
|
||||||
const src_float_t* src_ptr) {
|
const src_float_t* src_ptr) {
|
||||||
if constexpr (std::is_same_v<src_float_t, tgt_float_t>) {
|
if constexpr (std::is_same_v<src_float_t, tgt_float_t>) {
|
||||||
dst.load(src_ptr);
|
dst.load(src_ptr);
|
||||||
} else {
|
} else {
|
||||||
@@ -260,11 +274,16 @@ inline std::pair<int, int> GetCudaComputeCapability() {
|
|||||||
__forceinline__ __device__ float ptx_rcp(float x) {
|
__forceinline__ __device__ float ptx_rcp(float x) {
|
||||||
#ifdef PADDLE_WITH_COREX
|
#ifdef PADDLE_WITH_COREX
|
||||||
return __ivcorex_rcpf(x);
|
return __ivcorex_rcpf(x);
|
||||||
|
#else
|
||||||
|
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
|
||||||
|
return __frcp_rn(x);
|
||||||
#else
|
#else
|
||||||
float y;
|
float y;
|
||||||
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
|
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
|
||||||
return y;
|
return y;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T1, typename T2>
|
template <typename T1, typename T2>
|
||||||
|
|||||||
@@ -1,291 +0,0 @@
|
|||||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include <paddle/extension.h>
|
|
||||||
#include <algorithm>
|
|
||||||
#include "helper.h"
|
|
||||||
|
|
||||||
#define THREADS_PER_BLOCK 128
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct Converter;
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct Converter<__half> {
|
|
||||||
// __half -> float
|
|
||||||
__device__ static float to_float(__half val) { return __half2float(val); }
|
|
||||||
// float -> __half
|
|
||||||
__device__ static __half from_float(float val) {
|
|
||||||
return __float2half_rn(val);
|
|
||||||
}
|
|
||||||
// int -> __half
|
|
||||||
__device__ static __half from_int(float val) { return __int2half_rn(val); }
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct Converter<__nv_bfloat16> {
|
|
||||||
// __nv_bfloat16 -> float
|
|
||||||
__device__ static float to_float(__nv_bfloat16 val) {
|
|
||||||
return __bfloat162float(val);
|
|
||||||
}
|
|
||||||
// float -> __nv_bfloat16
|
|
||||||
__device__ static __nv_bfloat16 from_float(float val) {
|
|
||||||
return __float2bfloat16_rn(val);
|
|
||||||
}
|
|
||||||
// int -> __nv_bfloat16
|
|
||||||
__device__ static __nv_bfloat16 from_int(int val) {
|
|
||||||
return __int2bfloat16_rn(val);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ void RotateQKVec4(const T* qk_ptr,
|
|
||||||
const T* rot_cos_ptr,
|
|
||||||
const T* rot_sin_ptr,
|
|
||||||
const int head_num,
|
|
||||||
const int base_idx,
|
|
||||||
const int rot_base_idx,
|
|
||||||
T* out) {
|
|
||||||
using VecT = AlignedVector<T, 4>;
|
|
||||||
|
|
||||||
VecT qk_vec;
|
|
||||||
Load(qk_ptr + base_idx, &qk_vec);
|
|
||||||
VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]};
|
|
||||||
VecT cos_vec, sin_vec;
|
|
||||||
Load(rot_cos_ptr + rot_base_idx, &cos_vec);
|
|
||||||
Load(rot_sin_ptr + rot_base_idx, &sin_vec);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
|
||||||
*(out + base_idx + i) =
|
|
||||||
qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ void RotateQKVec4(const T* qk_ptr,
|
|
||||||
const float* rot_cos_ptr,
|
|
||||||
const float* rot_sin_ptr,
|
|
||||||
const int head_num,
|
|
||||||
const int base_idx,
|
|
||||||
const int rot_base_idx,
|
|
||||||
T* out) {
|
|
||||||
using VecT = AlignedVector<T, 4>;
|
|
||||||
using VecF = AlignedVector<float, 4>;
|
|
||||||
auto to_float = [] __device__(T val) -> float {
|
|
||||||
return Converter<T>::to_float(val);
|
|
||||||
};
|
|
||||||
auto from_float = [] __device__(float val) -> T {
|
|
||||||
return Converter<T>::from_float(val);
|
|
||||||
};
|
|
||||||
|
|
||||||
VecT qk_vec;
|
|
||||||
Load(qk_ptr + base_idx, &qk_vec);
|
|
||||||
VecF rot_half_vec = {-to_float(qk_vec[1]),
|
|
||||||
to_float(qk_vec[0]),
|
|
||||||
-to_float(qk_vec[3]),
|
|
||||||
to_float(qk_vec[2])};
|
|
||||||
VecF cos_vec, sin_vec;
|
|
||||||
Load(rot_cos_ptr + rot_base_idx, &cos_vec);
|
|
||||||
Load(rot_sin_ptr + rot_base_idx, &sin_vec);
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
|
||||||
*(out + base_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] +
|
|
||||||
rot_half_vec[i] * sin_vec[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// qk and rope have a same type
|
|
||||||
template <typename T>
|
|
||||||
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
|
|
||||||
const T* k,
|
|
||||||
const T* rot_cos,
|
|
||||||
const T* rot_sin,
|
|
||||||
const int q_num_elements,
|
|
||||||
const int k_num_elements,
|
|
||||||
const int q_head_num,
|
|
||||||
const int k_head_num,
|
|
||||||
const int head_dim,
|
|
||||||
T* q_out,
|
|
||||||
T* k_out) {
|
|
||||||
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
|
|
||||||
int head_dim_idx = idx % head_dim;
|
|
||||||
|
|
||||||
if (idx < q_num_elements) {
|
|
||||||
int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
|
|
||||||
RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (idx < k_num_elements) {
|
|
||||||
int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
|
|
||||||
RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// rope dtype is float32
|
|
||||||
template <typename T>
|
|
||||||
__global__ void DispatchApplyRopeVec4Kernel(const T* q,
|
|
||||||
const T* k,
|
|
||||||
const float* rot_cos,
|
|
||||||
const float* rot_sin,
|
|
||||||
const int q_num_elements,
|
|
||||||
const int k_num_elements,
|
|
||||||
const int q_head_num,
|
|
||||||
const int k_head_num,
|
|
||||||
const int head_dim,
|
|
||||||
T* q_out,
|
|
||||||
T* k_out) {
|
|
||||||
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
|
|
||||||
int head_dim_idx = idx % head_dim;
|
|
||||||
|
|
||||||
if (idx < q_num_elements) {
|
|
||||||
int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx;
|
|
||||||
RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (idx < k_num_elements) {
|
|
||||||
int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx;
|
|
||||||
RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <paddle::DataType D>
|
|
||||||
void ApplyRopeKernel(const paddle::Tensor& q,
|
|
||||||
const paddle::Tensor& k,
|
|
||||||
const paddle::Tensor& rot_cos,
|
|
||||||
const paddle::Tensor& rot_sin,
|
|
||||||
paddle::Tensor& q_out,
|
|
||||||
paddle::Tensor& k_out) {
|
|
||||||
typedef PDTraits<D> traits_;
|
|
||||||
typedef typename traits_::DataType DataType_;
|
|
||||||
typedef typename traits_::data_t data_t;
|
|
||||||
|
|
||||||
const auto q_num_elements = q.numel();
|
|
||||||
const auto k_num_elements = k.numel();
|
|
||||||
const auto q_shape = q.shape();
|
|
||||||
const auto k_shape = k.shape();
|
|
||||||
const auto dims = q_shape.size();
|
|
||||||
const auto q_head_num = q_shape[dims - 2];
|
|
||||||
const auto k_head_num = k_shape[dims - 2];
|
|
||||||
const auto head_dim = q_shape.back();
|
|
||||||
int block_num =
|
|
||||||
(std::max(q_num_elements, k_num_elements) + (THREADS_PER_BLOCK * 4) - 1) /
|
|
||||||
(THREADS_PER_BLOCK * 4);
|
|
||||||
auto stream = q.stream();
|
|
||||||
|
|
||||||
if (q.dtype() == rot_cos.dtype()) {
|
|
||||||
DispatchApplyRopeVec4Kernel<DataType_>
|
|
||||||
<<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
|
|
||||||
reinterpret_cast<const DataType_*>(q.data<data_t>()),
|
|
||||||
reinterpret_cast<const DataType_*>(k.data<data_t>()),
|
|
||||||
reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()),
|
|
||||||
reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()),
|
|
||||||
q_num_elements,
|
|
||||||
k_num_elements,
|
|
||||||
q_head_num,
|
|
||||||
k_head_num,
|
|
||||||
head_dim,
|
|
||||||
reinterpret_cast<DataType_*>(q_out.data<data_t>()),
|
|
||||||
reinterpret_cast<DataType_*>(k_out.data<data_t>()));
|
|
||||||
} else if (rot_cos.dtype() == paddle::DataType::FLOAT32) {
|
|
||||||
DispatchApplyRopeVec4Kernel<DataType_>
|
|
||||||
<<<block_num, THREADS_PER_BLOCK, 0, stream>>>(
|
|
||||||
reinterpret_cast<const DataType_*>(q.data<data_t>()),
|
|
||||||
reinterpret_cast<const DataType_*>(k.data<data_t>()),
|
|
||||||
reinterpret_cast<const float*>(rot_cos.data<float>()),
|
|
||||||
reinterpret_cast<const float*>(rot_sin.data<float>()),
|
|
||||||
q_num_elements,
|
|
||||||
k_num_elements,
|
|
||||||
q_head_num,
|
|
||||||
k_head_num,
|
|
||||||
head_dim,
|
|
||||||
reinterpret_cast<DataType_*>(q_out.data<data_t>()),
|
|
||||||
reinterpret_cast<DataType_*>(k_out.data<data_t>()));
|
|
||||||
} else {
|
|
||||||
PD_THROW("Unsupported qk dtype and rope dtype.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<paddle::Tensor> ApplyRope(const paddle::Tensor& q,
|
|
||||||
const paddle::Tensor& k,
|
|
||||||
const paddle::Tensor& rot_cos,
|
|
||||||
const paddle::Tensor& rot_sin) {
|
|
||||||
auto q_shape = q.shape();
|
|
||||||
auto cos_shape = rot_cos.shape();
|
|
||||||
|
|
||||||
auto q_out = paddle::empty_like(q);
|
|
||||||
auto k_out = paddle::empty_like(k);
|
|
||||||
|
|
||||||
if (q.numel() == 0 || k.numel() == 0) {
|
|
||||||
return {q_out, k_out};
|
|
||||||
}
|
|
||||||
|
|
||||||
PADDLE_ENFORCE_EQ(
|
|
||||||
q_shape.back() % 2,
|
|
||||||
0,
|
|
||||||
"The last dimension (head_dim) of qk must be an even number "
|
|
||||||
"for RoPE, but got %d",
|
|
||||||
q_shape.back());
|
|
||||||
PADDLE_ENFORCE_EQ(q_shape.size(),
|
|
||||||
cos_shape.size(),
|
|
||||||
"The shape size of cos mismatches the shape size of q, "
|
|
||||||
"expect %d but got %d",
|
|
||||||
q_shape.size(),
|
|
||||||
cos_shape.size());
|
|
||||||
PADDLE_ENFORCE_EQ(q_shape.back(),
|
|
||||||
cos_shape.back(),
|
|
||||||
"The shape.back() of cos mismatches the shape.back() of q, "
|
|
||||||
"expect %d but got %d",
|
|
||||||
q_shape.back(),
|
|
||||||
cos_shape.back());
|
|
||||||
|
|
||||||
auto input_type = q.dtype();
|
|
||||||
switch (input_type) {
|
|
||||||
case paddle::DataType::BFLOAT16:
|
|
||||||
ApplyRopeKernel<paddle::DataType::BFLOAT16>(
|
|
||||||
q, k, rot_cos, rot_sin, q_out, k_out);
|
|
||||||
break;
|
|
||||||
case paddle::DataType::FLOAT16:
|
|
||||||
ApplyRopeKernel<paddle::DataType::FLOAT16>(
|
|
||||||
q, k, rot_cos, rot_sin, q_out, k_out);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
PD_THROW("Only support qk dtype of BF16 and F16");
|
|
||||||
}
|
|
||||||
|
|
||||||
return {q_out, k_out};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> ApplyRopeInferShape(
|
|
||||||
const std::vector<int64_t>& q_shape,
|
|
||||||
const std::vector<int64_t>& k_shape,
|
|
||||||
const std::vector<int64_t>& cos_shape,
|
|
||||||
const std::vector<int64_t>& sin_shape) {
|
|
||||||
return {q_shape, k_shape, cos_shape, sin_shape};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<paddle::DataType> ApplyRopeInferDtype(
|
|
||||||
const paddle::DataType& q_dtype,
|
|
||||||
const paddle::DataType& k_dtype,
|
|
||||||
const paddle::DataType& cos_dtype,
|
|
||||||
const paddle::DataType& sin_dtype) {
|
|
||||||
return {q_dtype, k_dtype, cos_dtype, sin_dtype};
|
|
||||||
}
|
|
||||||
|
|
||||||
PD_BUILD_OP(apply_rope)
|
|
||||||
.Inputs({"q", "k", "rot_cos", "rot_sin"})
|
|
||||||
.Outputs({"q_out", "k_out"})
|
|
||||||
.SetKernelFn(PD_KERNEL(ApplyRope))
|
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeInferShape))
|
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeInferDtype));
|
|
||||||
329
custom_ops/metax_ops/apply_rope_qkv.cu
Normal file
329
custom_ops/metax_ops/apply_rope_qkv.cu
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <paddle/extension.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct Converter;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct Converter<__half> {
|
||||||
|
// __half -> float
|
||||||
|
__device__ static float to_float(__half val) { return __half2float(val); }
|
||||||
|
// float -> __half
|
||||||
|
__device__ static __half from_float(float val) {
|
||||||
|
return __float2half_rn(val);
|
||||||
|
}
|
||||||
|
// int -> __half
|
||||||
|
__device__ static __half from_int(float val) { return __int2half_rn(val); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct Converter<__nv_bfloat16> {
|
||||||
|
// __nv_bfloat16 -> float
|
||||||
|
__device__ static float to_float(__nv_bfloat16 val) {
|
||||||
|
return __bfloat162float(val);
|
||||||
|
}
|
||||||
|
// float -> __nv_bfloat16
|
||||||
|
__device__ static __nv_bfloat16 from_float(float val) {
|
||||||
|
return __float2bfloat16_rn(val);
|
||||||
|
}
|
||||||
|
// int -> __nv_bfloat16
|
||||||
|
__device__ static __nv_bfloat16 from_int(int val) {
|
||||||
|
return __int2bfloat16_rn(val);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ApplyRopeQKVParams {
|
||||||
|
int head_dim;
|
||||||
|
int token_stride;
|
||||||
|
int head_stride;
|
||||||
|
int q_stride;
|
||||||
|
int kv_stride;
|
||||||
|
int q_head_offset;
|
||||||
|
int k_head_offset;
|
||||||
|
int v_head_offset;
|
||||||
|
int q_head_num;
|
||||||
|
int kv_head_num;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ void RotateQKVec4(const T* qkv_ptr,
|
||||||
|
const T* rot_cos_ptr,
|
||||||
|
const T* rot_sin_ptr,
|
||||||
|
const int load_idx,
|
||||||
|
const int store_idx,
|
||||||
|
const int rot_base_idx,
|
||||||
|
T* out) {
|
||||||
|
using VecT = AlignedVector<T, 4>;
|
||||||
|
|
||||||
|
VecT qk_vec;
|
||||||
|
Load(qkv_ptr + load_idx, &qk_vec);
|
||||||
|
VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]};
|
||||||
|
VecT cos_vec, sin_vec;
|
||||||
|
Load(rot_cos_ptr + rot_base_idx, &cos_vec);
|
||||||
|
Load(rot_sin_ptr + rot_base_idx, &sin_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
*(out + store_idx + i) =
|
||||||
|
qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ void RotateQKVec4(const T* qkv_ptr,
|
||||||
|
const float* rot_cos_ptr,
|
||||||
|
const float* rot_sin_ptr,
|
||||||
|
const int load_idx,
|
||||||
|
const int store_idx,
|
||||||
|
const int rot_base_idx,
|
||||||
|
T* out) {
|
||||||
|
using VecT = AlignedVector<T, 4>;
|
||||||
|
using VecF = AlignedVector<float, 4>;
|
||||||
|
auto to_float = [] __device__(T val) -> float {
|
||||||
|
return Converter<T>::to_float(val);
|
||||||
|
};
|
||||||
|
auto from_float = [] __device__(float val) -> T {
|
||||||
|
return Converter<T>::from_float(val);
|
||||||
|
};
|
||||||
|
|
||||||
|
VecT qk_vec;
|
||||||
|
Load(qkv_ptr + load_idx, &qk_vec);
|
||||||
|
VecF rot_half_vec = {-to_float(qk_vec[1]),
|
||||||
|
to_float(qk_vec[0]),
|
||||||
|
-to_float(qk_vec[3]),
|
||||||
|
to_float(qk_vec[2])};
|
||||||
|
VecF cos_vec, sin_vec;
|
||||||
|
Load(rot_cos_ptr + rot_base_idx, &cos_vec);
|
||||||
|
Load(rot_sin_ptr + rot_base_idx, &sin_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; ++i) {
|
||||||
|
*(out + store_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] +
|
||||||
|
rot_half_vec[i] * sin_vec[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ void StoreValue(const T* qkv_ptr,
|
||||||
|
const int load_idx,
|
||||||
|
const int store_idx,
|
||||||
|
T* out) {
|
||||||
|
using VecT = AlignedVector<T, 4>;
|
||||||
|
VecT v_vec;
|
||||||
|
Load(qkv_ptr + load_idx, &v_vec);
|
||||||
|
Store(v_vec, out + store_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename WeightType>
|
||||||
|
__global__ void DispatchApplyRopeQKVVec4Kernel(const T* qkv,
|
||||||
|
const WeightType* rot_cos,
|
||||||
|
const WeightType* rot_sin,
|
||||||
|
ApplyRopeQKVParams param,
|
||||||
|
T* q_out,
|
||||||
|
T* k_out,
|
||||||
|
T* v_out) {
|
||||||
|
const int token_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const int head_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||||
|
const int head_dim_idx = (blockIdx.z * blockDim.z + threadIdx.z) * 4;
|
||||||
|
int rot_idx = token_idx * param.head_dim + head_dim_idx;
|
||||||
|
int load_idx, store_idx;
|
||||||
|
|
||||||
|
if (head_idx < param.q_head_num && head_dim_idx < param.head_dim) { // q
|
||||||
|
load_idx = token_idx * param.token_stride +
|
||||||
|
(head_idx + param.q_head_offset) * param.head_stride +
|
||||||
|
head_dim_idx;
|
||||||
|
store_idx =
|
||||||
|
token_idx * param.q_stride + head_idx * param.head_dim + head_dim_idx;
|
||||||
|
RotateQKVec4(qkv, rot_cos, rot_sin, load_idx, store_idx, rot_idx, q_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (head_idx < param.kv_head_num && head_dim_idx < param.head_dim) { // kv
|
||||||
|
load_idx = token_idx * param.token_stride +
|
||||||
|
(head_idx + param.k_head_offset) * param.head_stride +
|
||||||
|
head_dim_idx;
|
||||||
|
store_idx =
|
||||||
|
token_idx * param.kv_stride + head_idx * param.head_dim + head_dim_idx;
|
||||||
|
RotateQKVec4(qkv, rot_cos, rot_sin, load_idx, store_idx, rot_idx, k_out);
|
||||||
|
load_idx = token_idx * param.token_stride +
|
||||||
|
(head_idx + param.v_head_offset) * param.head_stride +
|
||||||
|
head_dim_idx;
|
||||||
|
StoreValue(qkv, load_idx, store_idx, v_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <paddle::DataType D>
|
||||||
|
void ApplyRopeQKVKernel(const paddle::Tensor& qkv,
|
||||||
|
const paddle::Tensor& rot_cos,
|
||||||
|
const paddle::Tensor& rot_sin,
|
||||||
|
const int q_head_num,
|
||||||
|
const int kv_head_num,
|
||||||
|
const int head_dim,
|
||||||
|
paddle::Tensor& q_out,
|
||||||
|
paddle::Tensor& k_out,
|
||||||
|
paddle::Tensor& v_out) {
|
||||||
|
typedef PDTraits<D> traits_;
|
||||||
|
typedef typename traits_::DataType DataType_;
|
||||||
|
typedef typename traits_::data_t data_t;
|
||||||
|
|
||||||
|
const int all_num_elements = qkv.numel();
|
||||||
|
const int all_num_head = q_head_num + 2 * kv_head_num;
|
||||||
|
auto stream = qkv.stream();
|
||||||
|
|
||||||
|
dim3 block_dims(1, 4, 32);
|
||||||
|
dim3 grid_dims(all_num_elements / (all_num_head * head_dim), // token
|
||||||
|
(std::max(q_head_num, kv_head_num) + block_dims.y - 1) /
|
||||||
|
block_dims.y, // head
|
||||||
|
(head_dim + (block_dims.z * 4) - 1) /
|
||||||
|
(block_dims.z * 4) // dim: load vec4 at a time
|
||||||
|
);
|
||||||
|
|
||||||
|
// printf("grid: (%d, %d, %d)\n", grid_dims.x, grid_dims.y, grid_dims.z);
|
||||||
|
// printf("block: (%d, %d, %d)\n", block_dims.x, block_dims.y, block_dims.z);
|
||||||
|
|
||||||
|
ApplyRopeQKVParams param;
|
||||||
|
param.head_dim = head_dim;
|
||||||
|
param.token_stride = all_num_head * head_dim;
|
||||||
|
param.head_stride = head_dim;
|
||||||
|
param.q_stride = q_head_num * head_dim;
|
||||||
|
param.kv_stride = kv_head_num * head_dim;
|
||||||
|
param.q_head_offset = 0;
|
||||||
|
param.k_head_offset = q_head_num;
|
||||||
|
param.v_head_offset = q_head_num + kv_head_num;
|
||||||
|
param.q_head_num = q_head_num;
|
||||||
|
param.kv_head_num = kv_head_num;
|
||||||
|
|
||||||
|
if (qkv.dtype() == rot_cos.dtype()) {
|
||||||
|
DispatchApplyRopeQKVVec4Kernel<DataType_, DataType_>
|
||||||
|
<<<grid_dims, block_dims, 0, stream>>>(
|
||||||
|
reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
|
||||||
|
reinterpret_cast<const DataType_*>(rot_cos.data<data_t>()),
|
||||||
|
reinterpret_cast<const DataType_*>(rot_sin.data<data_t>()),
|
||||||
|
param,
|
||||||
|
reinterpret_cast<DataType_*>(q_out.data<data_t>()),
|
||||||
|
reinterpret_cast<DataType_*>(k_out.data<data_t>()),
|
||||||
|
reinterpret_cast<DataType_*>(v_out.data<data_t>()));
|
||||||
|
} else if (rot_cos.dtype() == paddle::DataType::FLOAT32) {
|
||||||
|
DispatchApplyRopeQKVVec4Kernel<DataType_, float>
|
||||||
|
<<<grid_dims, block_dims, 0, stream>>>(
|
||||||
|
reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
|
||||||
|
reinterpret_cast<const float*>(rot_cos.data<float>()),
|
||||||
|
reinterpret_cast<const float*>(rot_sin.data<float>()),
|
||||||
|
param,
|
||||||
|
reinterpret_cast<DataType_*>(q_out.data<data_t>()),
|
||||||
|
reinterpret_cast<DataType_*>(k_out.data<data_t>()),
|
||||||
|
reinterpret_cast<DataType_*>(v_out.data<data_t>()));
|
||||||
|
} else {
|
||||||
|
PD_THROW("Unsupported qk dtype and rope dtype.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<paddle::Tensor> ApplyRopeQKV(const paddle::Tensor& qkv,
|
||||||
|
const paddle::Tensor& rot_cos,
|
||||||
|
const paddle::Tensor& rot_sin,
|
||||||
|
const int q_head_num,
|
||||||
|
const int kv_head_num,
|
||||||
|
const int head_dim) {
|
||||||
|
auto qkv_shape = qkv.shape();
|
||||||
|
auto token_num = qkv_shape[0];
|
||||||
|
auto place = qkv.place();
|
||||||
|
auto dtype = qkv.dtype();
|
||||||
|
common::DDim q_out_shape, kv_out_shape;
|
||||||
|
if (rot_cos.shape().size() == 3) {
|
||||||
|
q_out_shape = {token_num, q_head_num, head_dim};
|
||||||
|
kv_out_shape = {token_num, kv_head_num, head_dim};
|
||||||
|
} else {
|
||||||
|
q_out_shape = {token_num, 1, q_head_num, head_dim};
|
||||||
|
kv_out_shape = {token_num, 1, kv_head_num, head_dim};
|
||||||
|
}
|
||||||
|
auto q_out = GetEmptyTensor(q_out_shape, dtype, place);
|
||||||
|
auto k_out = GetEmptyTensor(kv_out_shape, dtype, place);
|
||||||
|
auto v_out = GetEmptyTensor(kv_out_shape, dtype, place);
|
||||||
|
|
||||||
|
if (token_num == 0) {
|
||||||
|
return {q_out, k_out, v_out};
|
||||||
|
}
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(qkv_shape.back(),
|
||||||
|
((q_head_num + 2 * kv_head_num) * head_dim),
|
||||||
|
"The last dimension of qkv [%d] must equal to {(q_head_num "
|
||||||
|
"+ 2 * kv_head_num) * head_dim [%d].",
|
||||||
|
qkv_shape.back(),
|
||||||
|
((q_head_num + 2 * kv_head_num) * head_dim));
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
head_dim % 2,
|
||||||
|
0,
|
||||||
|
"The last dimension (head_dim) of qkv must be an even number "
|
||||||
|
"for RoPE, but got %d",
|
||||||
|
head_dim);
|
||||||
|
PADDLE_ENFORCE_EQ(q_out.shape().back(),
|
||||||
|
rot_cos.shape().back(),
|
||||||
|
"The last dimension of cos mismatches that of q, "
|
||||||
|
"expect %d but got %d",
|
||||||
|
q_out.shape().back(),
|
||||||
|
rot_cos.shape().back());
|
||||||
|
|
||||||
|
switch (dtype) {
|
||||||
|
case paddle::DataType::BFLOAT16:
|
||||||
|
ApplyRopeQKVKernel<paddle::DataType::BFLOAT16>(qkv,
|
||||||
|
rot_cos,
|
||||||
|
rot_sin,
|
||||||
|
q_head_num,
|
||||||
|
kv_head_num,
|
||||||
|
head_dim,
|
||||||
|
q_out,
|
||||||
|
k_out,
|
||||||
|
v_out);
|
||||||
|
break;
|
||||||
|
case paddle::DataType::FLOAT16:
|
||||||
|
ApplyRopeQKVKernel<paddle::DataType::FLOAT16>(qkv,
|
||||||
|
rot_cos,
|
||||||
|
rot_sin,
|
||||||
|
q_head_num,
|
||||||
|
kv_head_num,
|
||||||
|
head_dim,
|
||||||
|
q_out,
|
||||||
|
k_out,
|
||||||
|
v_out);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
PD_THROW("Only support qk dtype of BF16 and F16");
|
||||||
|
}
|
||||||
|
|
||||||
|
return {q_out, k_out, v_out};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<int64_t>> ApplyRopeQKVInferShape(
|
||||||
|
const std::vector<int64_t>& qkv_shape,
|
||||||
|
const std::vector<int64_t>& cos_shape,
|
||||||
|
const std::vector<int64_t>& sin_shape) {
|
||||||
|
return {qkv_shape, cos_shape, sin_shape};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<paddle::DataType> ApplyRopeQKVInferDtype(
|
||||||
|
const paddle::DataType& qkv_dtype,
|
||||||
|
const paddle::DataType& cos_dtype,
|
||||||
|
const paddle::DataType& sin_dtype) {
|
||||||
|
return {qkv_dtype, cos_dtype, sin_dtype};
|
||||||
|
}
|
||||||
|
|
||||||
|
PD_BUILD_OP(apply_rope_qkv)
|
||||||
|
.Inputs({"qkv", "rot_cos", "rot_sin"})
|
||||||
|
.Outputs({"q_out", "k_out", "v_out"})
|
||||||
|
.Attrs({"q_head_num:int", "kv_head_num:int", "head_dim:int"})
|
||||||
|
.SetKernelFn(PD_KERNEL(ApplyRopeQKV))
|
||||||
|
.SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeQKVInferShape))
|
||||||
|
.SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeQKVInferDtype));
|
||||||
477
custom_ops/metax_ops/cache_kv_with_rope.cu
Normal file
477
custom_ops/metax_ops/cache_kv_with_rope.cu
Normal file
@@ -0,0 +1,477 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <paddle/extension.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "helper.h"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct Converter;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct Converter<__half> {
|
||||||
|
// __half -> float
|
||||||
|
__device__ static float to_float(__half val) { return __half2float(val); }
|
||||||
|
// float -> __half
|
||||||
|
__device__ static __half from_float(float val) {
|
||||||
|
return __float2half_rn(val);
|
||||||
|
}
|
||||||
|
// int -> __half
|
||||||
|
__device__ static __half from_int(float val) { return __int2half_rn(val); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct Converter<__nv_bfloat16> {
|
||||||
|
// __nv_bfloat16 -> float
|
||||||
|
__device__ static float to_float(__nv_bfloat16 val) {
|
||||||
|
return __bfloat162float(val);
|
||||||
|
}
|
||||||
|
// float -> __nv_bfloat16
|
||||||
|
__device__ static __nv_bfloat16 from_float(float val) {
|
||||||
|
return __float2bfloat16_rn(val);
|
||||||
|
}
|
||||||
|
// int -> __nv_bfloat16
|
||||||
|
__device__ static __nv_bfloat16 from_int(int val) {
|
||||||
|
return __int2bfloat16_rn(val);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CacheKVWithRopeParams {
|
||||||
|
int head_dim;
|
||||||
|
int block_size;
|
||||||
|
int block_num;
|
||||||
|
int cache_stride;
|
||||||
|
int token_stride;
|
||||||
|
int head_stride;
|
||||||
|
int q_stride;
|
||||||
|
int kv_stride;
|
||||||
|
int q_head_offset;
|
||||||
|
int k_head_offset;
|
||||||
|
int v_head_offset;
|
||||||
|
int q_head_num;
|
||||||
|
int kv_head_num;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Applies interleaved (NeoX-style pairwise) rotary position embedding to one
// VecSize-wide slice of Q or K, where cos/sin are stored in the same dtype T.
//
// For each even/odd pair (x0, x1) it computes
//   out0 = x0*cos0 - x1*sin0,  out1 = x1*cos1 + x0*sin1
// via rot_half_vec[i] = -qk_vec[i +/- 1] * (+/-1).
// Requires VecSize to be even so pairs never straddle a vector boundary.
//
// When WriteCache is true the rotated result is additionally written to
// `caches` at cache_store_idx; when false, `caches` may be nullptr and
// cache_store_idx is ignored (callers pass -1).
template <typename T, int VecSize = 4, bool WriteCache = true>
__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr,
                                            const T* rotary_cos_ptr,
                                            const T* rotary_sin_ptr,
                                            const int load_idx,
                                            const int store_idx,
                                            const int cache_store_idx,
                                            const int rot_base_idx,
                                            T* caches,
                                            T* out) {
  using VecT = AlignedVector<T, VecSize>;

  VecT qk_vec;
  Load(qkv_ptr + load_idx, &qk_vec);
  VecT rot_half_vec;
  int flag;
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    // flag = +1 for even lanes, -1 for odd lanes: picks the pair partner
    // and the sign of the "rotate half" term in one expression.
    flag = 1 - 2 * (i % 2);
    rot_half_vec[i] = -qk_vec[i + flag] * Converter<T>::from_int(flag);
  }
  VecT cos_vec, sin_vec;
  Load(rotary_cos_ptr + rot_base_idx, &cos_vec);
  Load(rotary_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    T result = qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i];
    *(out + store_idx + i) = result;
    // Compile-time-constant branch; dead code is eliminated when
    // WriteCache == false.
    if (WriteCache) {
      *(caches + cache_store_idx + i) = result;
    }
  }
}
|
||||||
|
|
||||||
|
// Overload of RotateQKVec for cos/sin tables stored in float32 while Q/K are
// stored in T (half / bfloat16). The rotation is accumulated in float and
// converted back to T once per element, which preserves precision relative
// to the all-T overload. Same pairing scheme: VecSize must be even.
template <typename T, int VecSize = 4, bool WriteCache = true>
__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr,
                                            const float* rotary_cos_ptr,
                                            const float* rotary_sin_ptr,
                                            const int load_idx,
                                            const int store_idx,
                                            const int cache_store_idx,
                                            const int rot_base_idx,
                                            T* caches,
                                            T* out) {
  using VecT = AlignedVector<T, VecSize>;
  using VecF = AlignedVector<float, VecSize>;
  // Small device lambdas so the loop bodies read like scalar math.
  auto to_float = [] __device__(T val) -> float {
    return Converter<T>::to_float(val);
  };
  auto from_float = [] __device__(float val) -> T {
    return Converter<T>::from_float(val);
  };

  VecT qk_vec;
  Load(qkv_ptr + load_idx, &qk_vec);
  VecF rot_half_vec;
  int flag;
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    // flag = +1 for even lanes, -1 for odd: selects the pair partner and
    // the sign of the rotate-half term.
    flag = 1 - 2 * (i % 2);
    rot_half_vec[i] = -to_float(qk_vec[i + flag]) * static_cast<float>(flag);
  }
  VecF cos_vec, sin_vec;
  Load(rotary_cos_ptr + rot_base_idx, &cos_vec);
  Load(rotary_sin_ptr + rot_base_idx, &sin_vec);
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    // Accumulate in fp32, round back to T once.
    T result = from_float(to_float(qk_vec[i]) * cos_vec[i] +
                          rot_half_vec[i] * sin_vec[i]);
    *(out + store_idx + i) = result;
    if (WriteCache) {
      *(caches + cache_store_idx + i) = result;
    }
  }
}
|
||||||
|
|
||||||
|
// Copies one VecSize-wide slice of V (no rotation needed for values) from
// the fused qkv buffer into both the contiguous output tensor and the paged
// KV cache.
template <typename T, int VecSize = 4>
__device__ __forceinline__ void StoreValue(const T* qkv_ptr,
                                           const int load_idx,
                                           const int store_idx,
                                           const int cache_store_idx,
                                           T* caches,
                                           T* out) {
  using VecT = AlignedVector<T, VecSize>;
  VecT value_vec;
  Load(qkv_ptr + load_idx, &value_vec);
  // Same vector goes to the dense V output and to the paged cache.
  Store(value_vec, out + store_idx);
  Store(value_vec, caches + cache_store_idx);
}
|
||||||
|
|
||||||
|
// Fused kernel: applies RoPE to Q and K, splits fused qkv into q_out / k_out
// / v_out, and writes the rotated K and raw V into the paged KV cache.
//
// Launch layout (set by CacheKVWithRopeKernel):
//   grid.x * block.x == token_num           (one x-lane per token)
//   grid.y * block.y >= max(q_heads, kv_heads)
//   grid.z * block.z covers head_dim in VecSize-wide chunks
// WeightType is either T (cos/sin in model dtype) or float; overload
// resolution of RotateQKVec picks the matching implementation.
//
// NOTE(review): vectorized loads assume head_dim % VecSize == 0; the guard
// `head_dim_idx < param.head_dim` does not prevent a partial final vector
// from reading into the next head — confirm head_dim is a multiple of
// VecSize at the call site.
template <typename T, typename WeightType, int VecSize>
__global__ void DispatchCacheKVWithRopeVecKernel(const T* qkv,
                                                 T* caches_k,
                                                 T* caches_v,
                                                 const int* block_tables,
                                                 const WeightType* rotary_cos,
                                                 const WeightType* rotary_sin,
                                                 const int* cu_seqlens_q,
                                                 const int* batch_ids_q,
                                                 CacheKVWithRopeParams param,
                                                 T* q_out,
                                                 T* k_out,
                                                 T* v_out) {
  const int token_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int head_idx = blockIdx.y * blockDim.y + threadIdx.y;
  const int head_dim_idx = (blockIdx.z * blockDim.z + threadIdx.z) * VecSize;

  int load_idx, store_idx, cache_store_idx;
  // cos/sin tables are indexed per token and per head-dim position.
  int rot_idx = token_idx * param.head_dim + head_dim_idx;

  // Map the flat token index to (batch, position-in-sequence), then through
  // the block table to a physical cache block and an offset inside it.
  const int batch_idx = *(batch_ids_q + token_idx);
  const int inter_batch_token_offset = token_idx - *(cu_seqlens_q + batch_idx);
  const int inter_batch_block_idx = inter_batch_token_offset / param.block_size;
  const int inter_block_offset = inter_batch_token_offset % param.block_size;
  const int block_idx =
      *(block_tables + batch_idx * param.block_num + inter_batch_block_idx);

  // -1 marks an unallocated cache block; reaching one here is a scheduler bug.
  assert(block_idx != -1);

  if (head_dim_idx < param.head_dim) {
    if (head_idx < param.q_head_num) {  // q: rotate, write q_out only
      load_idx = token_idx * param.token_stride +
                 (head_idx + param.q_head_offset) * param.head_stride +
                 head_dim_idx;
      store_idx =
          token_idx * param.q_stride + head_idx * param.head_dim + head_dim_idx;
      // WriteCache=false: Q is never cached, so caches=nullptr and
      // cache_store_idx=-1 are never dereferenced.
      RotateQKVec<T, VecSize, false>(qkv,
                                     rotary_cos,
                                     rotary_sin,
                                     load_idx,
                                     store_idx,
                                     -1,
                                     rot_idx,
                                     static_cast<T*>(nullptr),
                                     q_out);
    }

    if (head_idx < param.kv_head_num) {  // kv: rotate K, copy V, fill cache
      load_idx = token_idx * param.token_stride +
                 (head_idx + param.k_head_offset) * param.head_stride +
                 head_dim_idx;
      store_idx = token_idx * param.kv_stride + head_idx * param.head_dim +
                  head_dim_idx;
      // Cache layout: [block, position-in-block, kv_head, head_dim].
      cache_store_idx = block_idx * param.cache_stride +
                        inter_block_offset * param.kv_stride +
                        head_idx * param.head_dim + head_dim_idx;
      RotateQKVec<T, VecSize, true>(qkv,
                                    rotary_cos,
                                    rotary_sin,
                                    load_idx,
                                    store_idx,
                                    cache_store_idx,
                                    rot_idx,
                                    caches_k,
                                    k_out);

      // V reuses store_idx / cache_store_idx (same kv layout), only the
      // source offset inside the fused qkv buffer differs.
      load_idx = token_idx * param.token_stride +
                 (head_idx + param.v_head_offset) * param.head_stride +
                 head_dim_idx;
      StoreValue<T, VecSize>(
          qkv, load_idx, store_idx, cache_store_idx, caches_v, v_out);
    }
  }
}
|
||||||
|
|
||||||
|
// Host-side launcher: sizes the grid, fills CacheKVWithRopeParams, and
// dispatches DispatchCacheKVWithRopeVecKernel for either a same-dtype
// (T cos/sin) or fp32 cos/sin rotary table.
//
// Grid/block: block = (1, 4, ceil(head_dim/VecSize)); grid.x = token count
// (derived from qkv.numel()), grid.y covers max(q_head_num, kv_head_num)
// in steps of block.y, grid.z covers head_dim in VecSize-wide chunks.
template <paddle::DataType D, int VecSize = 4>
void CacheKVWithRopeKernel(
    const paddle::Tensor& qkv,  // token_num, head_num * head_dim
    paddle::Tensor&
        caches_k,  // max_block_num, block_size, kv_head_num, head_dim
    paddle::Tensor&
        caches_v,  // max_block_num, block_size, kv_head_num, head_dim
    const paddle::Tensor& block_tables,  // bs, block_num
    const paddle::Tensor& rotary_cos,
    const paddle::Tensor& rotary_sin,
    const paddle::Tensor& cu_seqlens_q,  // bs + 1
    const paddle::Tensor& batch_ids_q,   // token_num
    const int q_head_num,
    const int kv_head_num,
    const int head_dim,
    const int block_size,
    paddle::Tensor& q_out,
    paddle::Tensor& k_out,
    paddle::Tensor& v_out) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int all_num_elements = qkv.numel();
  const int all_num_heads = q_head_num + 2 * kv_head_num;
  auto stream = qkv.stream();

  dim3 block_dims(1, 4, (head_dim + VecSize - 1) / VecSize);
  dim3 grid_dims(all_num_elements / (all_num_heads * head_dim),  // token
                 (std::max(q_head_num, kv_head_num) + block_dims.y - 1) /
                     block_dims.y,  // head
                 (head_dim + (block_dims.z * VecSize) - 1) /
                     (block_dims.z * VecSize)  // dim: load Vec at a time
  );

  CacheKVWithRopeParams param;
  param.head_dim = head_dim;
  param.block_size = block_size;
  param.block_num = static_cast<int>(block_tables.shape().back());
  param.cache_stride = block_size * kv_head_num * head_dim;
  param.token_stride = all_num_heads * head_dim;
  param.head_stride = head_dim;
  param.q_stride = q_head_num * head_dim;
  param.kv_stride = kv_head_num * head_dim;
  // Fused qkv layout per token: [Q heads | K heads | V heads].
  param.q_head_offset = 0;
  param.k_head_offset = q_head_num;
  param.v_head_offset = q_head_num + kv_head_num;
  param.q_head_num = q_head_num;
  param.kv_head_num = kv_head_num;

  if (qkv.dtype() == rotary_cos.dtype()) {
    // cos/sin stored in the model dtype.
    DispatchCacheKVWithRopeVecKernel<DataType_, DataType_, VecSize>
        <<<grid_dims, block_dims, 0, stream>>>(
            reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_k.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_v.data<data_t>()),
            reinterpret_cast<const int*>(block_tables.data<int>()),
            reinterpret_cast<const DataType_*>(rotary_cos.data<data_t>()),
            reinterpret_cast<const DataType_*>(rotary_sin.data<data_t>()),
            reinterpret_cast<const int*>(cu_seqlens_q.data<int>()),
            reinterpret_cast<const int*>(batch_ids_q.data<int>()),
            param,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()),
            reinterpret_cast<DataType_*>(v_out.data<data_t>()));
  } else if (rotary_cos.dtype() == paddle::DataType::FLOAT32) {
    // fp32 cos/sin: higher-precision rotation path.
    DispatchCacheKVWithRopeVecKernel<DataType_, float, VecSize>
        <<<grid_dims, block_dims, 0, stream>>>(
            reinterpret_cast<const DataType_*>(qkv.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_k.data<data_t>()),
            reinterpret_cast<DataType_*>(caches_v.data<data_t>()),
            reinterpret_cast<const int*>(block_tables.data<int>()),
            reinterpret_cast<const float*>(rotary_cos.data<float>()),
            reinterpret_cast<const float*>(rotary_sin.data<float>()),
            reinterpret_cast<const int*>(cu_seqlens_q.data<int>()),
            reinterpret_cast<const int*>(batch_ids_q.data<int>()),
            param,
            reinterpret_cast<DataType_*>(q_out.data<data_t>()),
            reinterpret_cast<DataType_*>(k_out.data<data_t>()),
            reinterpret_cast<DataType_*>(v_out.data<data_t>()));
  } else {
    PD_THROW("Unsupported qk dtype and rope dtype.");
  }

  // NOTE(review): launch errors are only printed here, not raised; callers
  // cannot observe the failure. Consider PD_THROW / a CUDA-check macro.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("CUDA Error: %s\n", cudaGetErrorString(err));
  }
}
|
||||||
|
|
||||||
|
// Custom-op entry point: allocates q/k/v output tensors, validates shapes,
// and dispatches the typed launcher for BF16 / FP16 activations.
//
// Returns {q_out, k_out, v_out}. Output rank follows rotary_cos: rank-3
// cos/sin produce (token_num, heads, head_dim) outputs, otherwise a
// singleton axis is inserted: (token_num, 1, heads, head_dim).
std::vector<paddle::Tensor> CacheKVWithRope(
    const paddle::Tensor& qkv,  // token_num, head_num * head_dim
    paddle::Tensor&
        caches_k,  // max_block_num, block_size, kv_head_num, head_dim
    paddle::Tensor&
        caches_v,  // max_block_num, block_size, kv_head_num, head_dim
    const paddle::Tensor& block_tables,  // bs, block_num
    const paddle::Tensor& rotary_cos,
    const paddle::Tensor& rotary_sin,
    const paddle::Tensor& cu_seqlens_q,  // bs + 1
    const paddle::Tensor& batch_ids_q,   // token_num
    const int q_head_num,
    const int kv_head_num,
    const int head_dim,
    const int block_size) {
  auto qkv_shape = qkv.shape();
  auto token_num = qkv_shape[0];
  auto place = qkv.place();
  auto dtype = qkv.dtype();
  common::DDim q_out_shape, kv_out_shape;
  if (rotary_cos.shape().size() == 3) {
    q_out_shape = {token_num, q_head_num, head_dim};
    kv_out_shape = {token_num, kv_head_num, head_dim};
  } else {
    q_out_shape = {token_num, 1, q_head_num, head_dim};
    kv_out_shape = {token_num, 1, kv_head_num, head_dim};
  }
  auto q_out = GetEmptyTensor(q_out_shape, dtype, place);
  auto k_out = GetEmptyTensor(kv_out_shape, dtype, place);
  auto v_out = GetEmptyTensor(kv_out_shape, dtype, place);

  // Empty batch: nothing to do. NOTE: this returns before the shape checks
  // below, so invalid configs pass silently when token_num == 0.
  if (token_num == 0) {
    return {q_out, k_out, v_out};
  }

  // Fixed error message: previously contained a stray '{' with no closing
  // brace ("must equal to {(q_head_num ...").
  PADDLE_ENFORCE_EQ(qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim),
                    "The last dimension of qkv [%d] must equal to (q_head_num "
                    "+ 2 * kv_head_num) * head_dim [%d].",
                    qkv_shape.back(),
                    ((q_head_num + 2 * kv_head_num) * head_dim));
  PADDLE_ENFORCE_EQ(
      head_dim % 2,
      0,
      "The last dimension (head_dim) of qkv must be an even number "
      "for RoPE, but got %d",
      head_dim);
  PADDLE_ENFORCE_EQ(q_out.shape().back(),
                    rotary_cos.shape().back(),
                    "The last dimension of cos mismatches that of q, "
                    "expect %d but got %d",
                    q_out.shape().back(),
                    rotary_cos.shape().back());

  switch (dtype) {
    case paddle::DataType::BFLOAT16:
      CacheKVWithRopeKernel<paddle::DataType::BFLOAT16>(qkv,
                                                        caches_k,
                                                        caches_v,
                                                        block_tables,
                                                        rotary_cos,
                                                        rotary_sin,
                                                        cu_seqlens_q,
                                                        batch_ids_q,
                                                        q_head_num,
                                                        kv_head_num,
                                                        head_dim,
                                                        block_size,
                                                        q_out,
                                                        k_out,
                                                        v_out);
      break;
    case paddle::DataType::FLOAT16:
      CacheKVWithRopeKernel<paddle::DataType::FLOAT16>(qkv,
                                                       caches_k,
                                                       caches_v,
                                                       block_tables,
                                                       rotary_cos,
                                                       rotary_sin,
                                                       cu_seqlens_q,
                                                       batch_ids_q,
                                                       q_head_num,
                                                       kv_head_num,
                                                       head_dim,
                                                       block_size,
                                                       q_out,
                                                       k_out,
                                                       v_out);
      break;
    default:
      PD_THROW("Only support qk dtype of BF16 and F16");
  }

  return {q_out, k_out, v_out};
}
|
||||||
|
|
||||||
|
// Static shape inference for cache_kv_with_rope: echoes every input shape
// back unchanged (placeholder behavior).
// NOTE(review): this returns 8 shapes while the op registers only 3 outputs
// (q_out, k_out, v_out) — verify against the framework's expectations.
std::vector<std::vector<int64_t>> CacheKVWithRopeInferShape(
    const std::vector<int64_t>& qkv_shape,
    const std::vector<int64_t>& caches_k_shape,
    const std::vector<int64_t>& caches_v_shape,
    const std::vector<int64_t>& block_tables_shape,
    const std::vector<int64_t>& cos_shape,
    const std::vector<int64_t>& sin_shape,
    const std::vector<int64_t>& cu_seqlens_q_shape,
    const std::vector<int64_t>& batch_ids_q_shape) {
  std::vector<std::vector<int64_t>> shapes;
  shapes.reserve(8);
  shapes.push_back(qkv_shape);
  shapes.push_back(caches_k_shape);
  shapes.push_back(caches_v_shape);
  shapes.push_back(block_tables_shape);
  shapes.push_back(cos_shape);
  shapes.push_back(sin_shape);
  shapes.push_back(cu_seqlens_q_shape);
  shapes.push_back(batch_ids_q_shape);
  return shapes;
}
|
||||||
|
|
||||||
|
// Static dtype inference for cache_kv_with_rope: echoes every input dtype
// back unchanged (placeholder behavior, mirrors CacheKVWithRopeInferShape).
std::vector<paddle::DataType> CacheKVWithRopeInferDtype(
    const paddle::DataType& qkv_dtype,
    const paddle::DataType& caches_k_dtype,
    const paddle::DataType& caches_v_dtype,
    const paddle::DataType& block_tables_dtype,
    const paddle::DataType& cos_dtype,
    const paddle::DataType& sin_dtype,
    const paddle::DataType& cu_seqlens_q_dtype,
    const paddle::DataType& batch_ids_q_dtype) {
  std::vector<paddle::DataType> dtypes;
  dtypes.reserve(8);
  dtypes.push_back(qkv_dtype);
  dtypes.push_back(caches_k_dtype);
  dtypes.push_back(caches_v_dtype);
  dtypes.push_back(block_tables_dtype);
  dtypes.push_back(cos_dtype);
  dtypes.push_back(sin_dtype);
  dtypes.push_back(cu_seqlens_q_dtype);
  dtypes.push_back(batch_ids_q_dtype);
  return dtypes;
}
|
||||||
|
|
||||||
|
// Registers the cache_kv_with_rope custom op: 8 inputs, 3 outputs, 4 int
// attributes, dispatched to CacheKVWithRope.
// NOTE(review): the input is registered as "cu_seqlen_q" (singular) while
// the kernel argument/comments use cu_seqlens_q — confirm the Python-side
// binding uses the registered (singular) name.
PD_BUILD_OP(cache_kv_with_rope)
    .Inputs({"qkv",
             "caches_k",
             "caches_v",
             "block_tables",
             "rotary_cos",
             "rotary_sin",
             "cu_seqlen_q",
             "batch_ids_q"})
    .Outputs({"q_out", "k_out", "v_out"})
    .Attrs(
        {"q_head_num:int", "kv_head_num:int", "head_dim:int", "block_size:int"})
    .SetKernelFn(PD_KERNEL(CacheKVWithRope))
    .SetInferShapeFn(PD_INFER_SHAPE(CacheKVWithRopeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(CacheKVWithRopeInferDtype));
|
||||||
@@ -14,9 +14,10 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "fused_moe_op.h"
|
#include "fused_moe_helper.h"
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "mc_fused_moe_helper.h"
|
|
||||||
|
namespace phi {
|
||||||
|
|
||||||
__global__ void compute_total_rows_before_expert_kernel(
|
__global__ void compute_total_rows_before_expert_kernel(
|
||||||
int* sorted_experts,
|
int* sorted_experts,
|
||||||
@@ -42,58 +43,61 @@ void compute_total_rows_before_expert(int* sorted_indices,
|
|||||||
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
sorted_indices, total_indices, num_experts, total_rows_before_expert);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <paddle::DataType T,
|
} // namespace phi
|
||||||
typename ElementA,
|
|
||||||
typename ElementB,
|
template <paddle::DataType T>
|
||||||
typename ElementC>
|
|
||||||
void FusedMoeKernel(const paddle::Tensor& input,
|
void FusedMoeKernel(const paddle::Tensor& input,
|
||||||
const paddle::Tensor& gate_weight,
|
const paddle::Tensor& gate_weight,
|
||||||
const paddle::Tensor& ffn1_weight,
|
const paddle::Tensor& up_gate_proj_weight,
|
||||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||||
const paddle::Tensor& ffn2_weight,
|
const paddle::Tensor& down_proj_weight,
|
||||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||||
const std::string& quant_method,
|
const std::string& quant_method,
|
||||||
const int moe_topk,
|
const int moe_topk,
|
||||||
const bool group_moe,
|
const bool group_moe,
|
||||||
const bool norm_topk_prob,
|
const bool norm_topk_prob,
|
||||||
paddle::Tensor* output) {
|
paddle::Tensor* output) {
|
||||||
|
using namespace phi;
|
||||||
typedef PDTraits<T> traits_;
|
typedef PDTraits<T> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
|
|
||||||
auto* output_data = output->data<data_t>();
|
auto* output_data = output->data<data_t>();
|
||||||
|
|
||||||
auto moe_compute =
|
auto int8_moe_gemm_runner = McMoeGemmRunner<DataType_, int8_t>();
|
||||||
McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);
|
|
||||||
|
|
||||||
moe_compute.computeFFN(&input,
|
auto moe_compute =
|
||||||
&gate_weight,
|
McMoeHelper<data_t, DataType_>(quant_method, &int8_moe_gemm_runner);
|
||||||
&ffn1_weight,
|
|
||||||
ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
|
moe_compute.computeFFN(
|
||||||
ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
|
&input,
|
||||||
&ffn2_weight,
|
&gate_weight,
|
||||||
ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
|
&up_gate_proj_weight,
|
||||||
ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
|
up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
|
||||||
nullptr,
|
up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
|
||||||
moe_topk,
|
&down_proj_weight,
|
||||||
group_moe,
|
down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
|
||||||
norm_topk_prob,
|
down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
|
||||||
1.0, // ComputeFFN
|
nullptr,
|
||||||
"ffn",
|
moe_topk,
|
||||||
output);
|
group_moe,
|
||||||
|
norm_topk_prob,
|
||||||
|
1.0, // ComputeFFN
|
||||||
|
"ffn",
|
||||||
|
output);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::Tensor> FusedExpertMoe(
|
std::vector<paddle::Tensor> FusedExpertMoe(
|
||||||
const paddle::Tensor& input,
|
const paddle::Tensor& input,
|
||||||
const paddle::Tensor& gate_weight,
|
const paddle::Tensor& gate_weight,
|
||||||
const paddle::Tensor& ffn1_weight,
|
const paddle::Tensor& up_gate_proj_weight,
|
||||||
const paddle::Tensor& ffn2_weight,
|
const paddle::Tensor& down_proj_weight,
|
||||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
const std::string& quant_method,
|
const std::string& quant_method,
|
||||||
const int moe_topk,
|
const int moe_topk,
|
||||||
const bool norm_topk_prob,
|
const bool norm_topk_prob,
|
||||||
@@ -107,40 +111,22 @@ std::vector<paddle::Tensor> FusedExpertMoe(
|
|||||||
|
|
||||||
switch (input_type) {
|
switch (input_type) {
|
||||||
case paddle::DataType::BFLOAT16:
|
case paddle::DataType::BFLOAT16:
|
||||||
FusedMoeKernel<paddle::DataType::BFLOAT16,
|
FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
|
||||||
maca_bfloat16,
|
gate_weight,
|
||||||
int8_t,
|
up_gate_proj_weight,
|
||||||
maca_bfloat16>(input,
|
up_gate_proj_scale,
|
||||||
gate_weight,
|
up_gate_proj_bias,
|
||||||
ffn1_weight,
|
down_proj_weight,
|
||||||
ffn1_scale,
|
down_proj_scale,
|
||||||
ffn1_bias,
|
down_proj_bias,
|
||||||
ffn2_weight,
|
quant_method,
|
||||||
ffn2_scale,
|
moe_topk,
|
||||||
ffn2_bias,
|
group_moe,
|
||||||
quant_method,
|
norm_topk_prob,
|
||||||
moe_topk,
|
&output);
|
||||||
group_moe,
|
|
||||||
norm_topk_prob,
|
|
||||||
&output);
|
|
||||||
break;
|
break;
|
||||||
// case paddle::DataType::FLOAT16:
|
|
||||||
// FusedMoeKernel<paddle::DataType::FLOAT16>(input,
|
|
||||||
// gate_weight,
|
|
||||||
// ffn1_weight,
|
|
||||||
// ffn1_scale,
|
|
||||||
// ffn1_bias,
|
|
||||||
// ffn2_weight,
|
|
||||||
// ffn2_scale,
|
|
||||||
// ffn2_bias,
|
|
||||||
// quant_method,
|
|
||||||
// moe_topk,
|
|
||||||
// group_moe,
|
|
||||||
// norm_topk_prob,
|
|
||||||
// &output);
|
|
||||||
// break;
|
|
||||||
default:
|
default:
|
||||||
PD_THROW("Only support bf16 for FusedMoeKernel");
|
PD_THROW("Unsupported data type for FusedMoeKernel");
|
||||||
}
|
}
|
||||||
return {output};
|
return {output};
|
||||||
}
|
}
|
||||||
@@ -148,36 +134,36 @@ std::vector<paddle::Tensor> FusedExpertMoe(
|
|||||||
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
|
std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
|
||||||
const std::vector<int64_t>& input_shape,
|
const std::vector<int64_t>& input_shape,
|
||||||
const std::vector<int64_t>& gate_weight_shape,
|
const std::vector<int64_t>& gate_weight_shape,
|
||||||
const std::vector<int64_t>& ffn1_weight_shape,
|
const std::vector<int64_t>& up_gate_proj_weight_shape,
|
||||||
const std::vector<int64_t>& ffn2_weight_shape,
|
const std::vector<int64_t>& down_proj_weight_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
|
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
|
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
|
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
|
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
|
||||||
return {input_shape};
|
return {input_shape};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
|
std::vector<paddle::DataType> FusedExpertMoeInferDtype(
|
||||||
const paddle::DataType& input_dtype,
|
const paddle::DataType& input_dtype,
|
||||||
const paddle::DataType& gate_weight_dtype,
|
const paddle::DataType& gate_weight_dtype,
|
||||||
const paddle::DataType& ffn1_weight_dtype,
|
const paddle::DataType& up_gate_proj_weight_dtype,
|
||||||
const paddle::DataType& ffn2_weight_dtype,
|
const paddle::DataType& down_proj_weight_dtype,
|
||||||
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
|
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
|
||||||
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
|
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
|
||||||
const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
|
const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
|
||||||
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
|
const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
|
||||||
return {input_dtype};
|
return {input_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_OP(fused_expert_moe)
|
PD_BUILD_STATIC_OP(fused_expert_moe)
|
||||||
.Inputs({"input",
|
.Inputs({"input",
|
||||||
"gate_weight",
|
"gate_weight",
|
||||||
"ffn1_weight",
|
"up_gate_proj_weight",
|
||||||
"ffn2_weight",
|
"down_proj_weight",
|
||||||
paddle::Optional("ffn1_bias"),
|
paddle::Optional("up_gate_proj_bias"),
|
||||||
paddle::Optional("ffn1_scale"),
|
paddle::Optional("up_gate_proj_scale"),
|
||||||
paddle::Optional("ffn2_bias"),
|
paddle::Optional("down_proj_bias"),
|
||||||
paddle::Optional("ffn2_scale")})
|
paddle::Optional("down_proj_scale")})
|
||||||
.Outputs({"output"})
|
.Outputs({"output"})
|
||||||
.Attrs({"quant_method:std::string",
|
.Attrs({"quant_method:std::string",
|
||||||
"moe_topk:int",
|
"moe_topk:int",
|
||||||
|
|||||||
199
custom_ops/metax_ops/fused_moe_gemm_kernels.h
Normal file
199
custom_ops/metax_ops/fused_moe_gemm_kernels.h
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mctlass/numeric_conversion.h"
|
||||||
|
#include "mctlassEx/mctlassEx.h"
|
||||||
|
|
||||||
|
namespace phi {

// Maps a C++ element type to the matching mctlassEx enum value.
template <typename T>
struct mctlassExDataTraits;

template <>
struct mctlassExDataTraits<maca_bfloat16> {
  static constexpr mctlassExDataType type =
      mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
};

template <>
struct mctlassExDataTraits<int8_t> {
  static constexpr mctlassExDataType type =
      mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
};

// Thin wrapper around mctlassEx contiguous grouped GEMM for MoE expert
// FFNs: activations of type T (bf16), weights of type WeightType (int8,
// dequantized via per-channel B-scales), fp32 accumulation, optional bias
// epilogue.
template <typename T, typename WeightType>
class McMoeGemmRunner {
 public:
  McMoeGemmRunner() {}

  // Runs one grouped GEMM over `numExperts` segments of the row dimension.
  //   ptrA:          (m, k) activations, rows grouped by expert
  //   ptrB:          (numExperts, n, k) quantized weights
  //   ptrScale:      per-channel dequant scales in T
  //   ptrBias:       optional (numExperts, n) bias; enables BIAS epilogue
  //   ptrC:          (m, n) output
  //   ptrSegInd:     per-expert row-segment indptr (expert boundaries in A)
  //   ptrMNumTilesInd: scratch indptr of M-tiles per expert, filled on device
  // NOTE(review): an mctlassEx handle is created and destroyed on every
  // call, and none of the mctlassEx return codes are checked — consider
  // caching the handle and adding status checks.
  void mc_grouped_gemm_basic_kernel(const T* ptrA,
                                    mctlassExOrder_t majorA,
                                    const WeightType* ptrB,
                                    mctlassExOrder_t majorB,
                                    const T* ptrScale,
                                    const T* ptrBias,
                                    T* ptrC,
                                    mctlassExOrder_t majorC,
                                    const int* ptrSegInd,
                                    int* ptrMNumTilesInd,
                                    int numExperts,
                                    int m,  // expanded_active_expert_rows
                                    int n,  // inter_dim
                                    int k,  // hidden_size
                                    mcStream_t stream) {
    mctlassExHandle_t handle;
    mctlassExHandleCreate(&handle);

    mctlassExDataType DataType_ = mctlassExDataTraits<T>::type;
    mctlassExDataType WeightType_ = mctlassExDataTraits<WeightType>::type;

    mctlassExMatrixLayout_t matLayoutA;
    mctlassExMatrixLayout_t matLayoutB;
    mctlassExMatrixLayout_t matLayoutC;

    // mat A: (m, k), leading dimension k
    mctlassExMatrixLayoutCreate(&matLayoutA, DataType_, m, k, k);
    mctlassExMatrixLayoutSetAttribute(
        matLayoutA,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
        &majorA,
        sizeof(mctlassExOrder_t));
    mctlassExMatrixLayoutSetAttribute(
        matLayoutA,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
        &numExperts,
        sizeof(int));
    // mat B: (num_experts, n, k)
    mctlassExMatrixLayoutCreate(&matLayoutB, WeightType_, k, n, k);
    mctlassExMatrixLayoutSetAttribute(
        matLayoutB,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
        &majorB,
        sizeof(mctlassExOrder_t));
    mctlassExMatrixLayoutSetAttribute(
        matLayoutB,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
        &numExperts,
        sizeof(int));
    // mat C: (m, n), leading dimension n
    mctlassExMatrixLayoutCreate(&matLayoutC, DataType_, m, n, n);
    mctlassExMatrixLayoutSetAttribute(
        matLayoutC,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
        &majorC,
        sizeof(mctlassExOrder_t));
    mctlassExMatrixLayoutSetAttribute(
        matLayoutC,
        mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
        &numExperts,
        sizeof(int));
    // bias: (num_experts, n)
    // scale: (num, n)

    mctlassExDesc_t mctlass_desc;
    mctlassExCreateDesc(&mctlass_desc);
    mctlassExDataType input_type = DataType_;
    // NOTE(review): scale_type is computed but never passed to the desc;
    // B_SCALE_TYPE below is set from input_type (scales are stored in T).
    mctlassExDataType scale_type = WeightType_;
    mctlassExDataType compute_type =
        mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
    mctlassExEpilogueType epilogue_type =
        mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
    if (ptrBias) {
      epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
    }
    // set scale
    mctlassExDescSetAttribute(
        mctlass_desc,
        mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
        &ptrScale,
        sizeof(ptrScale));
    mctlassExDescSetAttribute(
        mctlass_desc,
        mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
        &input_type,
        sizeof(mctlassExDataType));
    // set bias
    if (ptrBias) {
      mctlassExDescSetAttribute(
          mctlass_desc,
          mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
          &ptrBias,
          sizeof(ptrBias));
    }
    // set compute type
    mctlassExDescSetAttribute(
        mctlass_desc,
        mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
        &compute_type,
        sizeof(mctlassExDataType));
    // set epilogue type
    mctlassExDescSetAttribute(
        mctlass_desc,
        mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
        &epilogue_type,
        sizeof(mctlassExEpilogueType));

    const mctlassExContiguousGroupedGemmAlgo_t algo =
        mctlassExContiguousGroupedGemmAlgo_t::
            MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
    mctlassExContiguousGroupedDesc_t contiguous_group_desc;
    mctlassExContiguousGroupedDescCreate(
        &contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);
    // Query the algo's M tile size, then fill the per-expert M-tile indptr
    // on the stream before launching the grouped GEMM itself.
    int blocksizeM;
    mctlassExContiguousGroupedGemmGetBlocksizeM(handle,
                                                mctlass_desc,
                                                matLayoutA,
                                                matLayoutB,
                                                matLayoutC,
                                                &algo,
                                                &blocksizeM);
    mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle,
                                                         mctlass_desc,
                                                         matLayoutA,
                                                         matLayoutB,
                                                         matLayoutC,
                                                         &algo,
                                                         contiguous_group_desc,
                                                         numExperts,
                                                         blocksizeM,
                                                         stream);

    mctlassExContiguousGroupedGemmBasic(handle,
                                        mctlass_desc,
                                        ptrA,
                                        matLayoutA,
                                        ptrB,
                                        matLayoutB,
                                        ptrC,
                                        matLayoutC,
                                        contiguous_group_desc,
                                        &algo,
                                        nullptr,
                                        0,
                                        stream);

    mctlassExHandleDestroy(handle);
    mctlassExMatrixLayoutDestroy(matLayoutA);
    mctlassExMatrixLayoutDestroy(matLayoutB);
    mctlassExMatrixLayoutDestroy(matLayoutC);
    mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
    mctlassExDestroyDesc(mctlass_desc);
  }
};

// Explicit instantiation used by the MoE FFN path (bf16 act, int8 weights).
template class McMoeGemmRunner<maca_bfloat16, int8_t>;

}  // namespace phi
|
||||||
@@ -14,14 +14,17 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
|
#include "fused_moe_gemm_kernels.h"
|
||||||
|
#include "fused_moe_imp_op.h"
|
||||||
#include "fused_moe_op.h"
|
#include "fused_moe_op.h"
|
||||||
|
#include "mctlass/numeric_conversion.h"
|
||||||
|
#include "mctlassEx/mctlassEx.h"
|
||||||
|
|
||||||
using namespace phi;
|
namespace phi {
|
||||||
|
|
||||||
template <typename T, int VecSize>
|
template <typename T, int VecSize>
|
||||||
__global__ void moe_token_type_ids_kernel(T *gating_output,
|
__global__ void moe_token_type_ids_kernel(T* gating_output,
|
||||||
const int *moe_token_type_ids_out,
|
const int* moe_token_type_ids_out,
|
||||||
const int num_rows,
|
const int num_rows,
|
||||||
const int num_experts,
|
const int num_experts,
|
||||||
const int k) {
|
const int k) {
|
||||||
@@ -40,8 +43,8 @@ __global__ void moe_token_type_ids_kernel(T *gating_output,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void moe_token_type_ids_kernelLauncher(T *gating_output,
|
void moe_token_type_ids_kernelLauncher(T* gating_output,
|
||||||
const int *moe_token_type_ids_out,
|
const int* moe_token_type_ids_out,
|
||||||
const int num_rows,
|
const int num_rows,
|
||||||
const int num_experts,
|
const int num_experts,
|
||||||
const int k,
|
const int k,
|
||||||
@@ -51,3 +54,338 @@ void moe_token_type_ids_kernelLauncher(T *gating_output,
|
|||||||
moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
|
moe_token_type_ids_kernel<T, 1><<<blocks, 512, 0, stream>>>(
|
||||||
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename MacaType>
|
||||||
|
class McMoeHelper {
|
||||||
|
public:
|
||||||
|
McMoeHelper(const std::string gemm_method,
|
||||||
|
McMoeGemmRunner<MacaType, int8_t>* int8_moe_gemm_runner)
|
||||||
|
: gemm_method_(gemm_method),
|
||||||
|
int8_moe_gemm_runner_(int8_moe_gemm_runner) {}
|
||||||
|
|
||||||
|
// -------- getWorkspaceSize -------- //
|
||||||
|
template <typename KeyT>
|
||||||
|
size_t getWorkspaceSize(const int64_t num_rows,
|
||||||
|
const int64_t hidden_size,
|
||||||
|
const int64_t inter_size,
|
||||||
|
const int64_t num_experts,
|
||||||
|
const int64_t k) {
|
||||||
|
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||||
|
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||||
|
const size_t padded_experts = AlignTo16(num_experts);
|
||||||
|
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||||
|
// softmax output, permuted_rows and permuted_experts have moved to outside
|
||||||
|
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
||||||
|
// FfnLayer forward.
|
||||||
|
size_t total_ws_bytes =
|
||||||
|
5 * num_moe_inputs *
|
||||||
|
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
||||||
|
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
||||||
|
total_ws_bytes +=
|
||||||
|
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
||||||
|
|
||||||
|
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
||||||
|
const size_t sorter_ws_size_bytes =
|
||||||
|
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
||||||
|
sorter_.update_num_experts(num_experts);
|
||||||
|
|
||||||
|
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
||||||
|
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
||||||
|
int64_t remaining_bytes =
|
||||||
|
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
||||||
|
bytes_for_intermediate_and_sorting += remaining_bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
total_ws_bytes +=
|
||||||
|
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
||||||
|
// sorting workspace
|
||||||
|
|
||||||
|
int64_t num_softmax_outs = 0;
|
||||||
|
const bool is_pow_2 =
|
||||||
|
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||||
|
if (!is_pow_2 || num_experts > 256) {
|
||||||
|
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
||||||
|
}
|
||||||
|
|
||||||
|
total_ws_bytes += num_softmax_outs * sizeof(float);
|
||||||
|
|
||||||
|
return total_ws_bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
void computeFFN(const paddle::Tensor* input,
|
||||||
|
const paddle::Tensor* gate_weight,
|
||||||
|
const paddle::Tensor* up_gate_proj_weight,
|
||||||
|
const paddle::Tensor* up_gate_proj_scale,
|
||||||
|
const paddle::Tensor* up_gate_proj_bias,
|
||||||
|
const paddle::Tensor* down_proj_weight,
|
||||||
|
const paddle::Tensor* down_proj_scale,
|
||||||
|
const paddle::Tensor* down_proj_bias,
|
||||||
|
const paddle::Tensor* moe_token_type_ids,
|
||||||
|
const int moe_topk,
|
||||||
|
const bool group_moe,
|
||||||
|
const bool norm_topk_prob,
|
||||||
|
const float routed_scaling_factor,
|
||||||
|
const std::string moe_type,
|
||||||
|
paddle::Tensor* output) {
|
||||||
|
auto* input_activations = input->data<T>();
|
||||||
|
auto* gating_weights = gate_weight->data<float>();
|
||||||
|
const T* fc1_expert_biases =
|
||||||
|
up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
|
||||||
|
const T* fc2_expert_biases =
|
||||||
|
down_proj_bias ? down_proj_bias->data<T>() : nullptr;
|
||||||
|
|
||||||
|
auto* output_ = output->data<T>();
|
||||||
|
auto stream = input->stream();
|
||||||
|
auto place = input->place();
|
||||||
|
auto input_type = input->dtype();
|
||||||
|
|
||||||
|
auto input_dims = input->dims();
|
||||||
|
auto up_gate_proj_dims = up_gate_proj_weight->dims();
|
||||||
|
int64_t token_num = 0;
|
||||||
|
if (input_dims.size() == 3) {
|
||||||
|
token_num = input_dims[0] * input_dims[1];
|
||||||
|
} else {
|
||||||
|
token_num = input_dims[0];
|
||||||
|
}
|
||||||
|
const int64_t num_rows = token_num;
|
||||||
|
|
||||||
|
const int64_t hidden_size = up_gate_proj_dims[2];
|
||||||
|
int64_t inter_dim = 0;
|
||||||
|
if (moe_type == "qkv") {
|
||||||
|
inter_dim =
|
||||||
|
up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
|
||||||
|
} else {
|
||||||
|
inter_dim = up_gate_proj_dims[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// if (gemm_method_ == "weight_only_int4") {
|
||||||
|
// inter_dim = inter_dim * 2;
|
||||||
|
// }
|
||||||
|
|
||||||
|
const int64_t inter_size = inter_dim;
|
||||||
|
const int64_t num_experts = up_gate_proj_dims[0];
|
||||||
|
const int64_t k = moe_topk;
|
||||||
|
|
||||||
|
int64_t bytes =
|
||||||
|
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
||||||
|
|
||||||
|
// Pointers
|
||||||
|
int* expert_for_source_row;
|
||||||
|
int* source_rows_;
|
||||||
|
int* permuted_rows_;
|
||||||
|
int* permuted_experts_;
|
||||||
|
int* expanded_source_row_to_expanded_dest_row;
|
||||||
|
|
||||||
|
T* permuted_data_;
|
||||||
|
int32_t* total_rows_before_expert_;
|
||||||
|
T* fc1_result_;
|
||||||
|
float* softmax_out_;
|
||||||
|
|
||||||
|
paddle::Tensor ws_ptr_tensor =
|
||||||
|
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
||||||
|
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
|
||||||
|
|
||||||
|
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
||||||
|
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
||||||
|
const int64_t padded_experts = AlignTo16(num_experts);
|
||||||
|
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
||||||
|
|
||||||
|
expert_for_source_row = reinterpret_cast<int*>(ws_ptr);
|
||||||
|
source_rows_ = expert_for_source_row + num_moe_inputs;
|
||||||
|
permuted_rows_ = source_rows_ + num_moe_inputs;
|
||||||
|
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
||||||
|
expanded_source_row_to_expanded_dest_row =
|
||||||
|
permuted_experts_ + num_moe_inputs;
|
||||||
|
permuted_data_ = reinterpret_cast<T*>(
|
||||||
|
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
||||||
|
total_rows_before_expert_ =
|
||||||
|
reinterpret_cast<int32_t*>(permuted_data_ + buf_size);
|
||||||
|
fc1_result_ =
|
||||||
|
reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
|
||||||
|
|
||||||
|
const bool is_pow_2 =
|
||||||
|
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||||
|
if (!is_pow_2 || num_experts > 256) {
|
||||||
|
softmax_out_ = reinterpret_cast<float*>(fc1_result_ + interbuf_size);
|
||||||
|
} else {
|
||||||
|
softmax_out_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
paddle::Tensor expert_scales_float_tensor =
|
||||||
|
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||||
|
float* expert_scales_float = expert_scales_float_tensor.data<float>();
|
||||||
|
|
||||||
|
float* softmax_max_prob = nullptr;
|
||||||
|
if (group_moe) {
|
||||||
|
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
||||||
|
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||||
|
// (TODO: check fill success ?)
|
||||||
|
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
||||||
|
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
||||||
|
}
|
||||||
|
|
||||||
|
paddle::Tensor fc1_out_tensor =
|
||||||
|
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
||||||
|
T* fc1_out = fc1_out_tensor.data<T>();
|
||||||
|
|
||||||
|
auto input_cast_tensor =
|
||||||
|
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
||||||
|
auto gate_tensor =
|
||||||
|
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
||||||
|
float* gating_output = gate_tensor.data<float>();
|
||||||
|
|
||||||
|
if (moe_token_type_ids) {
|
||||||
|
auto* moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
||||||
|
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
||||||
|
moe_token_type_ids_out,
|
||||||
|
num_rows,
|
||||||
|
num_experts,
|
||||||
|
k,
|
||||||
|
stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
topk_gating_softmax_kernelLauncher<float, int>(gating_output,
|
||||||
|
nullptr,
|
||||||
|
expert_scales_float,
|
||||||
|
softmax_out_,
|
||||||
|
expert_for_source_row,
|
||||||
|
source_rows_,
|
||||||
|
softmax_max_prob,
|
||||||
|
num_rows,
|
||||||
|
num_experts,
|
||||||
|
k,
|
||||||
|
group_moe,
|
||||||
|
stream);
|
||||||
|
|
||||||
|
const int64_t sorter_ws_size_bytes =
|
||||||
|
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
||||||
|
|
||||||
|
sorter_.run(fc1_result_,
|
||||||
|
sorter_ws_size_bytes,
|
||||||
|
expert_for_source_row,
|
||||||
|
permuted_experts_,
|
||||||
|
source_rows_,
|
||||||
|
permuted_rows_,
|
||||||
|
k * num_rows,
|
||||||
|
false,
|
||||||
|
stream);
|
||||||
|
|
||||||
|
initialize_moe_routing_kernelLauncher(
|
||||||
|
input_activations,
|
||||||
|
permuted_data_,
|
||||||
|
permuted_rows_,
|
||||||
|
nullptr,
|
||||||
|
nullptr,
|
||||||
|
expanded_source_row_to_expanded_dest_row,
|
||||||
|
num_rows,
|
||||||
|
num_rows,
|
||||||
|
hidden_size,
|
||||||
|
k,
|
||||||
|
stream);
|
||||||
|
|
||||||
|
const int64_t expanded_active_expert_rows = k * num_rows;
|
||||||
|
|
||||||
|
compute_total_rows_before_expert(permuted_experts_,
|
||||||
|
expanded_active_expert_rows,
|
||||||
|
num_experts,
|
||||||
|
total_rows_before_expert_,
|
||||||
|
stream);
|
||||||
|
|
||||||
|
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
||||||
|
mctlassExOrder_t column_major =
|
||||||
|
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
||||||
|
auto m_num_tile =
|
||||||
|
GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
|
||||||
|
int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());
|
||||||
|
|
||||||
|
if (gemm_method_ == "weight_only_int8") {
|
||||||
|
int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel(
|
||||||
|
reinterpret_cast<const MacaType*>(permuted_data_),
|
||||||
|
row_major,
|
||||||
|
reinterpret_cast<const int8_t*>(up_gate_proj_weight->data<int8_t>()),
|
||||||
|
column_major,
|
||||||
|
reinterpret_cast<const MacaType*>(up_gate_proj_scale->data<T>()),
|
||||||
|
reinterpret_cast<const MacaType*>(fc1_expert_biases),
|
||||||
|
reinterpret_cast<MacaType*>(fc1_out),
|
||||||
|
row_major,
|
||||||
|
total_rows_before_expert_,
|
||||||
|
m_num_tile_ptr,
|
||||||
|
num_experts,
|
||||||
|
expanded_active_expert_rows,
|
||||||
|
inter_size,
|
||||||
|
hidden_size,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Unsupported gemm method: " + gemm_method_);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (moe_type == "ffn") {
|
||||||
|
auto act_out_tensor =
|
||||||
|
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||||
|
auto act_out = act_out_tensor.data<T>();
|
||||||
|
|
||||||
|
paddle::Tensor fc2_output_tensor =
|
||||||
|
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
||||||
|
T* fc2_result = fc2_output_tensor.data<T>();
|
||||||
|
|
||||||
|
if (gemm_method_ == "weight_only_int8") {
|
||||||
|
int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel(
|
||||||
|
reinterpret_cast<const MacaType*>(act_out),
|
||||||
|
row_major,
|
||||||
|
reinterpret_cast<const int8_t*>(down_proj_weight->data<int8_t>()),
|
||||||
|
column_major,
|
||||||
|
reinterpret_cast<const MacaType*>(down_proj_scale->data<T>()),
|
||||||
|
nullptr,
|
||||||
|
reinterpret_cast<MacaType*>(fc2_result),
|
||||||
|
row_major,
|
||||||
|
total_rows_before_expert_,
|
||||||
|
m_num_tile_ptr,
|
||||||
|
num_experts,
|
||||||
|
expanded_active_expert_rows,
|
||||||
|
hidden_size,
|
||||||
|
inter_size / 2,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Unsupported gemm method: " + gemm_method_);
|
||||||
|
}
|
||||||
|
|
||||||
|
finalize_moe_routing_kernelLauncher(
|
||||||
|
fc2_result,
|
||||||
|
output_,
|
||||||
|
fc2_expert_biases,
|
||||||
|
reinterpret_cast<float*>(expert_scales_float),
|
||||||
|
expanded_source_row_to_expanded_dest_row,
|
||||||
|
expert_for_source_row,
|
||||||
|
num_rows,
|
||||||
|
hidden_size,
|
||||||
|
k,
|
||||||
|
static_cast<int>(1),
|
||||||
|
norm_topk_prob,
|
||||||
|
routed_scaling_factor,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
finalize_moe_routing_kernelLauncher(
|
||||||
|
// fc2_result,
|
||||||
|
fc1_out,
|
||||||
|
output_,
|
||||||
|
fc1_expert_biases, // fc2_expert_biases,
|
||||||
|
reinterpret_cast<float*>(expert_scales_float),
|
||||||
|
expanded_source_row_to_expanded_dest_row,
|
||||||
|
expert_for_source_row,
|
||||||
|
num_rows,
|
||||||
|
inter_size,
|
||||||
|
k,
|
||||||
|
static_cast<int>(0),
|
||||||
|
norm_topk_prob,
|
||||||
|
routed_scaling_factor,
|
||||||
|
stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
McMoeGemmRunner<MacaType, int8_t>* int8_moe_gemm_runner_;
|
||||||
|
std::string gemm_method_;
|
||||||
|
CubKeyValueSorter sorter_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace phi
|
||||||
|
|||||||
@@ -20,6 +20,8 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include "cub/cub.cuh"
|
#include "cub/cub.cuh"
|
||||||
|
|
||||||
|
namespace phi {
|
||||||
|
|
||||||
static const float HALF_FLT_MAX = 65504.F;
|
static const float HALF_FLT_MAX = 65504.F;
|
||||||
static const float HALF_FLT_MIN = -65504.F;
|
static const float HALF_FLT_MIN = -65504.F;
|
||||||
static inline size_t AlignTo16(const size_t& input) {
|
static inline size_t AlignTo16(const size_t& input) {
|
||||||
@@ -121,3 +123,5 @@ class CubKeyValueSorter {
|
|||||||
int num_experts_;
|
int num_experts_;
|
||||||
int num_bits_;
|
int num_bits_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // namespace phi
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,486 +0,0 @@
|
|||||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "fused_moe_helper.h"
|
|
||||||
#include "mctlass/numeric_conversion.h"
|
|
||||||
#include "mctlassEx/mctlassEx.h"
|
|
||||||
|
|
||||||
template <typename ElementA, typename ElementB, typename ElementC>
|
|
||||||
void mc_grouped_gemm_basic_kernel(const ElementA* ptrA,
|
|
||||||
mctlassExOrder_t majorA,
|
|
||||||
const ElementB* ptrB,
|
|
||||||
mctlassExOrder_t majorB,
|
|
||||||
const ElementA* ptrScale,
|
|
||||||
const ElementA* ptrBias,
|
|
||||||
ElementC* ptrC,
|
|
||||||
mctlassExOrder_t majorC,
|
|
||||||
const int* ptrSegInd,
|
|
||||||
int* ptrMNumTilesInd,
|
|
||||||
int numExperts,
|
|
||||||
int m, // expanded_active_expert_rows
|
|
||||||
int n, // inter_dim
|
|
||||||
int k, // hidden_size
|
|
||||||
mcStream_t stream) {
|
|
||||||
mctlassExHandle_t handle;
|
|
||||||
mctlassExHandleCreate(&handle);
|
|
||||||
|
|
||||||
mctlassExMatrixLayout_t matLayoutA;
|
|
||||||
mctlassExMatrixLayout_t matLayoutB;
|
|
||||||
mctlassExMatrixLayout_t matLayoutC;
|
|
||||||
|
|
||||||
// mat A: (m, k)
|
|
||||||
mctlassExMatrixLayoutCreate(
|
|
||||||
&matLayoutA, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, k, k);
|
|
||||||
mctlassExMatrixLayoutSetAttribute(
|
|
||||||
matLayoutA,
|
|
||||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
|
||||||
&majorA,
|
|
||||||
sizeof(mctlassExOrder_t));
|
|
||||||
mctlassExMatrixLayoutSetAttribute(
|
|
||||||
matLayoutA,
|
|
||||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
|
||||||
&numExperts,
|
|
||||||
sizeof(int));
|
|
||||||
// mat B: (num_experts, n, k)
|
|
||||||
mctlassExMatrixLayoutCreate(
|
|
||||||
&matLayoutB, mctlassExDataType::MCTLASS_EX_DATATYPE_INT8, k, n, k);
|
|
||||||
mctlassExMatrixLayoutSetAttribute(
|
|
||||||
matLayoutB,
|
|
||||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
|
||||||
&majorB,
|
|
||||||
sizeof(mctlassExOrder_t));
|
|
||||||
mctlassExMatrixLayoutSetAttribute(
|
|
||||||
matLayoutB,
|
|
||||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
|
||||||
&numExperts,
|
|
||||||
sizeof(int));
|
|
||||||
// mat C: (m, n)
|
|
||||||
mctlassExMatrixLayoutCreate(
|
|
||||||
&matLayoutC, mctlassExDataType::MCTLASS_EX_DATATYPE_BF16, m, n, n);
|
|
||||||
mctlassExMatrixLayoutSetAttribute(
|
|
||||||
matLayoutC,
|
|
||||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER,
|
|
||||||
&majorC,
|
|
||||||
sizeof(mctlassExOrder_t));
|
|
||||||
mctlassExMatrixLayoutSetAttribute(
|
|
||||||
matLayoutC,
|
|
||||||
mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT,
|
|
||||||
&numExperts,
|
|
||||||
sizeof(int));
|
|
||||||
// bias: (num_experts, n)
|
|
||||||
// scale: (num, n)
|
|
||||||
|
|
||||||
mctlassExDesc_t mctlass_desc;
|
|
||||||
mctlassExCreateDesc(&mctlass_desc);
|
|
||||||
mctlassExDataType input_type = mctlassExDataType::MCTLASS_EX_DATATYPE_BF16;
|
|
||||||
mctlassExDataType scale_type = mctlassExDataType::MCTLASS_EX_DATATYPE_INT8;
|
|
||||||
mctlassExDataType compute_type = mctlassExDataType::MCTLASS_EX_DATATYPE_FP32;
|
|
||||||
mctlassExEpilogueType epilogue_type =
|
|
||||||
mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT;
|
|
||||||
if (ptrBias) {
|
|
||||||
epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS;
|
|
||||||
}
|
|
||||||
// set scale
|
|
||||||
mctlassExDescSetAttribute(
|
|
||||||
mctlass_desc,
|
|
||||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER,
|
|
||||||
&ptrScale,
|
|
||||||
sizeof(ptrScale));
|
|
||||||
mctlassExDescSetAttribute(
|
|
||||||
mctlass_desc,
|
|
||||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE,
|
|
||||||
&input_type,
|
|
||||||
sizeof(mctlassExDataType));
|
|
||||||
// set bias
|
|
||||||
if (ptrBias) {
|
|
||||||
mctlassExDescSetAttribute(
|
|
||||||
mctlass_desc,
|
|
||||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER,
|
|
||||||
&ptrBias,
|
|
||||||
sizeof(ptrBias));
|
|
||||||
}
|
|
||||||
// set coumpute type
|
|
||||||
mctlassExDescSetAttribute(
|
|
||||||
mctlass_desc,
|
|
||||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE,
|
|
||||||
&compute_type,
|
|
||||||
sizeof(mctlassExDataType));
|
|
||||||
// set epilogue type
|
|
||||||
mctlassExDescSetAttribute(
|
|
||||||
mctlass_desc,
|
|
||||||
mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE,
|
|
||||||
&epilogue_type,
|
|
||||||
sizeof(mctlassExEpilogueType));
|
|
||||||
|
|
||||||
const mctlassExContiguousGroupedGemmAlgo_t algo =
|
|
||||||
mctlassExContiguousGroupedGemmAlgo_t::
|
|
||||||
MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT;
|
|
||||||
mctlassExContiguousGroupedDesc_t contiguous_group_desc;
|
|
||||||
mctlassExContiguousGroupedDescCreate(
|
|
||||||
&contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1);
|
|
||||||
int blocksizeM;
|
|
||||||
mctlassExContiguousGroupedGemmGetBlocksizeM(handle,
|
|
||||||
mctlass_desc,
|
|
||||||
matLayoutA,
|
|
||||||
matLayoutB,
|
|
||||||
matLayoutC,
|
|
||||||
&algo,
|
|
||||||
&blocksizeM);
|
|
||||||
mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle,
|
|
||||||
mctlass_desc,
|
|
||||||
matLayoutA,
|
|
||||||
matLayoutB,
|
|
||||||
matLayoutC,
|
|
||||||
&algo,
|
|
||||||
contiguous_group_desc,
|
|
||||||
numExperts,
|
|
||||||
blocksizeM,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
mctlassExContiguousGroupedGemmBasic(handle,
|
|
||||||
mctlass_desc,
|
|
||||||
ptrA,
|
|
||||||
matLayoutA,
|
|
||||||
ptrB,
|
|
||||||
matLayoutB,
|
|
||||||
ptrC,
|
|
||||||
matLayoutC,
|
|
||||||
contiguous_group_desc,
|
|
||||||
&algo,
|
|
||||||
nullptr,
|
|
||||||
0,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
mctlassExHandleDestroy(handle);
|
|
||||||
mctlassExMatrixLayoutDestroy(matLayoutA);
|
|
||||||
mctlassExMatrixLayoutDestroy(matLayoutB);
|
|
||||||
mctlassExMatrixLayoutDestroy(matLayoutC);
|
|
||||||
mctlassExContiguousGroupedDescDestroy(contiguous_group_desc);
|
|
||||||
mctlassExDestroyDesc(mctlass_desc);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T, typename ElementA, typename ElementB, typename ElementC>
|
|
||||||
class McMoeHelper {
|
|
||||||
public:
|
|
||||||
McMoeHelper(const std::string gemm_method) : gemm_method_(gemm_method) {}
|
|
||||||
|
|
||||||
// -------- getWorkspaceSize -------- //
|
|
||||||
template <typename KeyT>
|
|
||||||
size_t getWorkspaceSize(const int64_t num_rows,
|
|
||||||
const int64_t hidden_size,
|
|
||||||
const int64_t inter_size,
|
|
||||||
const int64_t num_experts,
|
|
||||||
const int64_t k) {
|
|
||||||
const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
|
||||||
const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
|
||||||
const size_t padded_experts = AlignTo16(num_experts);
|
|
||||||
const size_t num_moe_inputs = AlignTo16(k * num_rows);
|
|
||||||
// softmax output, permuted_rows and permuted_experts have moved to outside
|
|
||||||
// of moe kernel, allocate them in Encoder or Decoder before invoking
|
|
||||||
// FfnLayer forward.
|
|
||||||
size_t total_ws_bytes =
|
|
||||||
5 * num_moe_inputs *
|
|
||||||
sizeof(int); // source_rows_, permuted_rows_, permuted_experts_
|
|
||||||
total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data
|
|
||||||
total_ws_bytes +=
|
|
||||||
padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_
|
|
||||||
|
|
||||||
const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
|
|
||||||
const size_t sorter_ws_size_bytes =
|
|
||||||
AlignTo16(sorter_.getWorkspaceSize(num_rows));
|
|
||||||
sorter_.update_num_experts(num_experts);
|
|
||||||
|
|
||||||
int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
|
|
||||||
if (sorter_ws_size_bytes > bytes_for_fc1_result) {
|
|
||||||
int64_t remaining_bytes =
|
|
||||||
AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
|
|
||||||
bytes_for_intermediate_and_sorting += remaining_bytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
total_ws_bytes +=
|
|
||||||
bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub
|
|
||||||
// sorting workspace
|
|
||||||
|
|
||||||
int64_t num_softmax_outs = 0;
|
|
||||||
const bool is_pow_2 =
|
|
||||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
|
||||||
if (!is_pow_2 || num_experts > 256) {
|
|
||||||
num_softmax_outs = AlignTo16(num_rows * num_experts);
|
|
||||||
}
|
|
||||||
|
|
||||||
total_ws_bytes += num_softmax_outs * sizeof(float);
|
|
||||||
|
|
||||||
return total_ws_bytes;
|
|
||||||
}
|
|
||||||
|
|
||||||
void computeFFN(const paddle::Tensor* input,
|
|
||||||
const paddle::Tensor* gate_weight,
|
|
||||||
const paddle::Tensor* ffn1_weight,
|
|
||||||
const paddle::Tensor* ffn1_scale,
|
|
||||||
const paddle::Tensor* ffn1_bias,
|
|
||||||
const paddle::Tensor* ffn2_weight,
|
|
||||||
const paddle::Tensor* ffn2_scale,
|
|
||||||
const paddle::Tensor* ffn2_bias,
|
|
||||||
const paddle::Tensor* moe_token_type_ids,
|
|
||||||
const int moe_topk,
|
|
||||||
const bool group_moe,
|
|
||||||
const bool norm_topk_prob,
|
|
||||||
const float routed_scaling_factor,
|
|
||||||
const std::string moe_type,
|
|
||||||
paddle::Tensor* output) {
|
|
||||||
auto* input_activations = input->data<T>();
|
|
||||||
auto* gating_weights = gate_weight->data<float>();
|
|
||||||
const T* fc1_expert_biases = ffn1_bias ? ffn1_bias->data<T>() : nullptr;
|
|
||||||
const T* fc2_expert_biases = ffn2_bias ? ffn2_bias->data<T>() : nullptr;
|
|
||||||
|
|
||||||
auto* output_ = output->data<T>();
|
|
||||||
auto stream = input->stream();
|
|
||||||
auto place = input->place();
|
|
||||||
auto input_type = input->dtype();
|
|
||||||
|
|
||||||
auto input_dims = input->dims();
|
|
||||||
auto ffn1_dims = ffn1_weight->dims();
|
|
||||||
int64_t token_num = 0;
|
|
||||||
if (input_dims.size() == 3) {
|
|
||||||
token_num = input_dims[0] * input_dims[1];
|
|
||||||
} else {
|
|
||||||
token_num = input_dims[0];
|
|
||||||
}
|
|
||||||
const int64_t num_rows = token_num;
|
|
||||||
|
|
||||||
const int64_t hidden_size = ffn1_dims[2];
|
|
||||||
int64_t inter_dim = 0;
|
|
||||||
if (moe_type == "qkv") {
|
|
||||||
inter_dim = ffn1_dims[2] * ffn1_dims[3] * ffn1_dims[4];
|
|
||||||
} else {
|
|
||||||
inter_dim = ffn1_dims[1];
|
|
||||||
}
|
|
||||||
|
|
||||||
// if (gemm_method == "weight_only_int4") {
|
|
||||||
// inter_dim = inter_dim * 2;
|
|
||||||
// }
|
|
||||||
|
|
||||||
const int64_t inter_size = inter_dim;
|
|
||||||
const int64_t num_experts = ffn1_dims[0];
|
|
||||||
const int64_t k = moe_topk;
|
|
||||||
|
|
||||||
int64_t bytes =
|
|
||||||
getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
|
|
||||||
|
|
||||||
// Pointers
|
|
||||||
int* expert_for_source_row;
|
|
||||||
int* source_rows_;
|
|
||||||
int* permuted_rows_;
|
|
||||||
int* permuted_experts_;
|
|
||||||
int* expanded_source_row_to_expanded_dest_row;
|
|
||||||
|
|
||||||
T* permuted_data_;
|
|
||||||
int32_t* total_rows_before_expert_;
|
|
||||||
T* fc1_result_;
|
|
||||||
float* softmax_out_;
|
|
||||||
|
|
||||||
paddle::Tensor ws_ptr_tensor =
|
|
||||||
GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
|
|
||||||
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
|
|
||||||
|
|
||||||
const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
|
|
||||||
const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
|
|
||||||
const int64_t padded_experts = AlignTo16(num_experts);
|
|
||||||
const int64_t num_moe_inputs = AlignTo16(k * num_rows);
|
|
||||||
|
|
||||||
expert_for_source_row = reinterpret_cast<int*>(ws_ptr);
|
|
||||||
source_rows_ = expert_for_source_row + num_moe_inputs;
|
|
||||||
permuted_rows_ = source_rows_ + num_moe_inputs;
|
|
||||||
permuted_experts_ = permuted_rows_ + num_moe_inputs;
|
|
||||||
expanded_source_row_to_expanded_dest_row =
|
|
||||||
permuted_experts_ + num_moe_inputs;
|
|
||||||
permuted_data_ = reinterpret_cast<T*>(
|
|
||||||
expanded_source_row_to_expanded_dest_row + num_moe_inputs);
|
|
||||||
total_rows_before_expert_ =
|
|
||||||
reinterpret_cast<int32_t*>(permuted_data_ + buf_size);
|
|
||||||
fc1_result_ =
|
|
||||||
reinterpret_cast<T*>(total_rows_before_expert_ + padded_experts);
|
|
||||||
|
|
||||||
const bool is_pow_2 =
|
|
||||||
(num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
|
||||||
if (!is_pow_2 || num_experts > 256) {
|
|
||||||
softmax_out_ = reinterpret_cast<float*>(fc1_result_ + interbuf_size);
|
|
||||||
} else {
|
|
||||||
softmax_out_ = nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
paddle::Tensor expert_scales_float_tensor =
|
|
||||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
|
||||||
float* expert_scales_float = expert_scales_float_tensor.data<float>();
|
|
||||||
|
|
||||||
float* softmax_max_prob = nullptr;
|
|
||||||
if (group_moe) {
|
|
||||||
paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
|
|
||||||
{num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
|
||||||
// (TODO: check fill success ?)
|
|
||||||
paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
|
|
||||||
softmax_max_prob = softmax_max_prob_tensor.data<float>();
|
|
||||||
}
|
|
||||||
|
|
||||||
paddle::Tensor fc1_out_tensor =
|
|
||||||
GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
|
|
||||||
T* fc1_out = fc1_out_tensor.data<T>();
|
|
||||||
|
|
||||||
auto input_cast_tensor =
|
|
||||||
paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
|
|
||||||
auto gate_tensor =
|
|
||||||
paddle::experimental::matmul(input_cast_tensor, *gate_weight);
|
|
||||||
float* gating_output = gate_tensor.data<float>();
|
|
||||||
|
|
||||||
if (moe_token_type_ids) {
|
|
||||||
auto* moe_token_type_ids_out = moe_token_type_ids->data<int>();
|
|
||||||
moe_token_type_ids_kernelLauncher<float>(gating_output,
|
|
||||||
moe_token_type_ids_out,
|
|
||||||
num_rows,
|
|
||||||
num_experts,
|
|
||||||
k,
|
|
||||||
stream);
|
|
||||||
}
|
|
||||||
|
|
||||||
topk_gating_softmax_kernelLauncher<float>(gating_output,
|
|
||||||
expert_scales_float,
|
|
||||||
softmax_out_,
|
|
||||||
expert_for_source_row,
|
|
||||||
source_rows_,
|
|
||||||
softmax_max_prob,
|
|
||||||
num_rows,
|
|
||||||
num_experts,
|
|
||||||
k,
|
|
||||||
group_moe,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
const int64_t sorter_ws_size_bytes =
|
|
||||||
AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
|
|
||||||
|
|
||||||
sorter_.run(fc1_result_,
|
|
||||||
sorter_ws_size_bytes,
|
|
||||||
expert_for_source_row,
|
|
||||||
permuted_experts_,
|
|
||||||
source_rows_,
|
|
||||||
permuted_rows_,
|
|
||||||
k * num_rows,
|
|
||||||
false,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
initialize_moe_routing_kernelLauncher(
|
|
||||||
input_activations,
|
|
||||||
permuted_data_,
|
|
||||||
permuted_rows_,
|
|
||||||
expanded_source_row_to_expanded_dest_row,
|
|
||||||
num_rows,
|
|
||||||
num_rows,
|
|
||||||
hidden_size,
|
|
||||||
k,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
const int64_t expanded_active_expert_rows = k * num_rows;
|
|
||||||
|
|
||||||
compute_total_rows_before_expert(permuted_experts_,
|
|
||||||
expanded_active_expert_rows,
|
|
||||||
num_experts,
|
|
||||||
total_rows_before_expert_,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR;
|
|
||||||
mctlassExOrder_t column_major =
|
|
||||||
mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR;
|
|
||||||
auto m_num_tile =
|
|
||||||
GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
|
|
||||||
int* m_num_tile_ptr = reinterpret_cast<int*>(m_num_tile.data<int>());
|
|
||||||
|
|
||||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
|
||||||
reinterpret_cast<const ElementA*>(permuted_data_),
|
|
||||||
row_major,
|
|
||||||
reinterpret_cast<const ElementB*>(ffn1_weight->data<ElementB>()),
|
|
||||||
column_major,
|
|
||||||
reinterpret_cast<const ElementA*>(ffn1_scale->data<T>()),
|
|
||||||
reinterpret_cast<const ElementA*>(fc1_expert_biases),
|
|
||||||
reinterpret_cast<ElementC*>(fc1_out),
|
|
||||||
row_major,
|
|
||||||
total_rows_before_expert_,
|
|
||||||
m_num_tile_ptr,
|
|
||||||
num_experts,
|
|
||||||
expanded_active_expert_rows,
|
|
||||||
inter_size,
|
|
||||||
hidden_size,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
if (moe_type == "ffn") {
|
|
||||||
auto act_out_tensor =
|
|
||||||
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
|
||||||
auto act_out = act_out_tensor.data<T>();
|
|
||||||
|
|
||||||
paddle::Tensor fc2_output_tensor =
|
|
||||||
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
|
|
||||||
T* fc2_result = fc2_output_tensor.data<T>();
|
|
||||||
|
|
||||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
|
||||||
reinterpret_cast<const ElementA*>(act_out),
|
|
||||||
row_major,
|
|
||||||
reinterpret_cast<const ElementB*>(ffn2_weight->data<ElementB>()),
|
|
||||||
column_major,
|
|
||||||
reinterpret_cast<const ElementA*>(ffn2_scale->data<T>()),
|
|
||||||
nullptr,
|
|
||||||
reinterpret_cast<ElementC*>(fc2_result),
|
|
||||||
row_major,
|
|
||||||
total_rows_before_expert_,
|
|
||||||
m_num_tile_ptr,
|
|
||||||
num_experts,
|
|
||||||
expanded_active_expert_rows,
|
|
||||||
hidden_size,
|
|
||||||
inter_size / 2,
|
|
||||||
stream);
|
|
||||||
|
|
||||||
finalize_moe_routing_kernelLauncher(
|
|
||||||
fc2_result,
|
|
||||||
output_,
|
|
||||||
fc2_expert_biases,
|
|
||||||
reinterpret_cast<float*>(expert_scales_float),
|
|
||||||
expanded_source_row_to_expanded_dest_row,
|
|
||||||
expert_for_source_row,
|
|
||||||
num_rows,
|
|
||||||
hidden_size,
|
|
||||||
k,
|
|
||||||
static_cast<int>(1),
|
|
||||||
norm_topk_prob,
|
|
||||||
routed_scaling_factor,
|
|
||||||
stream);
|
|
||||||
} else {
|
|
||||||
finalize_moe_routing_kernelLauncher(
|
|
||||||
// fc2_result,
|
|
||||||
fc1_out,
|
|
||||||
output_,
|
|
||||||
fc1_expert_biases, // fc2_expert_biases,
|
|
||||||
reinterpret_cast<float*>(expert_scales_float),
|
|
||||||
expanded_source_row_to_expanded_dest_row,
|
|
||||||
expert_for_source_row,
|
|
||||||
num_rows,
|
|
||||||
inter_size,
|
|
||||||
k,
|
|
||||||
static_cast<int>(0),
|
|
||||||
norm_topk_prob,
|
|
||||||
routed_scaling_factor,
|
|
||||||
stream);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::string gemm_method_;
|
|
||||||
CubKeyValueSorter sorter_;
|
|
||||||
};
|
|
||||||
@@ -17,26 +17,35 @@
|
|||||||
#pragma GCC diagnostic ignored "-Wunused-function"
|
#pragma GCC diagnostic ignored "-Wunused-function"
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "fused_moe_helper.h"
|
#include "fused_moe_imp_op.h"
|
||||||
#include "fused_moe_op.h"
|
#include "fused_moe_op.h"
|
||||||
#pragma GCC diagnostic pop
|
#pragma GCC diagnostic pop
|
||||||
|
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
|
|
||||||
template <paddle::DataType T>
|
template <paddle::DataType T>
|
||||||
void MoeDispatchKernel(const paddle::Tensor& input,
|
void MoeDispatchKernel(
|
||||||
const paddle::Tensor& gating_output,
|
const paddle::Tensor& input,
|
||||||
const int moe_topk,
|
const paddle::Tensor& gating_output,
|
||||||
const bool group_moe,
|
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||||
const bool topk_only_mode,
|
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
|
||||||
const int num_rows,
|
const int moe_topk,
|
||||||
const int hidden_size,
|
const bool group_moe,
|
||||||
const int expert_num,
|
const bool topk_only_mode,
|
||||||
paddle::Tensor* permute_input,
|
const int num_rows,
|
||||||
paddle::Tensor* tokens_expert_prefix_sum,
|
const int hidden_size,
|
||||||
paddle::Tensor* permute_indices_per_token,
|
const int expert_num,
|
||||||
paddle::Tensor* top_k_weight,
|
paddle::Tensor* permute_input,
|
||||||
paddle::Tensor* top_k_indices) {
|
paddle::Tensor* tokens_expert_prefix_sum,
|
||||||
|
paddle::Tensor* permute_indices_per_token,
|
||||||
|
paddle::Tensor* topk_weight,
|
||||||
|
paddle::Tensor* topk_idx,
|
||||||
|
paddle::Tensor* expert_idx_per_token) {
|
||||||
|
using namespace phi;
|
||||||
|
|
||||||
|
if (num_rows == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
typedef PDTraits<T> traits_;
|
typedef PDTraits<T> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
@@ -78,7 +87,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
|||||||
reinterpret_cast<int*>(sorter_ws_ptr + sorter_ws_size_bytes);
|
reinterpret_cast<int*>(sorter_ws_ptr + sorter_ws_size_bytes);
|
||||||
int* permuted_rows_ = permuted_experts_ + num_moe_inputs;
|
int* permuted_rows_ = permuted_experts_ + num_moe_inputs;
|
||||||
|
|
||||||
int* expert_for_source_row = top_k_indices->data<int>();
|
int* topk_idx_ptr = topk_idx->data<int>();
|
||||||
|
|
||||||
float* softmax_max_prob = nullptr;
|
float* softmax_max_prob = nullptr;
|
||||||
if (group_moe) {
|
if (group_moe) {
|
||||||
@@ -103,23 +112,25 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
|||||||
softmax_out_ = nullptr;
|
softmax_out_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
topk_gating_softmax_kernelLauncher<float>(gating_output.data<float>(),
|
topk_gating_softmax_kernelLauncher(
|
||||||
top_k_weight->data<float>(),
|
gating_output.data<float>(),
|
||||||
softmax_out_,
|
static_cast<const float*>(nullptr), // no gating_correction_bias
|
||||||
expert_for_source_row,
|
topk_weight->data<float>(),
|
||||||
source_rows_,
|
softmax_out_,
|
||||||
softmax_max_prob,
|
topk_idx_ptr,
|
||||||
num_rows,
|
source_rows_,
|
||||||
expert_num,
|
softmax_max_prob,
|
||||||
moe_topk,
|
num_rows,
|
||||||
group_moe,
|
expert_num,
|
||||||
stream,
|
moe_topk,
|
||||||
topk_only_mode);
|
group_moe,
|
||||||
|
stream,
|
||||||
|
topk_only_mode);
|
||||||
|
|
||||||
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
|
sorter_.run(reinterpret_cast<void*>(sorter_ws_ptr),
|
||||||
sorter_ws_size_bytes,
|
sorter_ws_size_bytes,
|
||||||
expert_for_source_row,
|
topk_idx_ptr,
|
||||||
permuted_experts_,
|
expert_idx_per_token->data<int32_t>(),
|
||||||
source_rows_,
|
source_rows_,
|
||||||
permuted_rows_,
|
permuted_rows_,
|
||||||
moe_topk * num_rows,
|
moe_topk * num_rows,
|
||||||
@@ -130,6 +141,8 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
|||||||
input.data<data_t>(),
|
input.data<data_t>(),
|
||||||
permute_input->data<data_t>(),
|
permute_input->data<data_t>(),
|
||||||
permuted_rows_,
|
permuted_rows_,
|
||||||
|
expert_idx_per_token->data<int32_t>(),
|
||||||
|
nullptr,
|
||||||
permute_indices_per_token->data<int32_t>(),
|
permute_indices_per_token->data<int32_t>(),
|
||||||
num_rows,
|
num_rows,
|
||||||
num_rows,
|
num_rows,
|
||||||
@@ -137,7 +150,7 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
|||||||
moe_topk,
|
moe_topk,
|
||||||
stream);
|
stream);
|
||||||
|
|
||||||
compute_total_rows_before_expert(permuted_experts_,
|
compute_total_rows_before_expert(expert_idx_per_token->data<int32_t>(),
|
||||||
moe_topk * num_rows,
|
moe_topk * num_rows,
|
||||||
expert_num,
|
expert_num,
|
||||||
tokens_expert_prefix_sum->data<int32_t>(),
|
tokens_expert_prefix_sum->data<int32_t>(),
|
||||||
@@ -147,8 +160,11 @@ void MoeDispatchKernel(const paddle::Tensor& input,
|
|||||||
std::vector<paddle::Tensor> MoeExpertDispatch(
|
std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||||
const paddle::Tensor& input,
|
const paddle::Tensor& input,
|
||||||
const paddle::Tensor& gating_output,
|
const paddle::Tensor& gating_output,
|
||||||
|
const paddle::optional<paddle::Tensor>& gating_correction_bias,
|
||||||
|
const paddle::optional<paddle::Tensor>& w4a8_in_scale,
|
||||||
const int moe_topk,
|
const int moe_topk,
|
||||||
const bool group_moe,
|
const bool group_moe,
|
||||||
|
const std::string& moe_quant_type,
|
||||||
const bool topk_only_mode) {
|
const bool topk_only_mode) {
|
||||||
const auto input_type = input.dtype();
|
const auto input_type = input.dtype();
|
||||||
auto place = input.place();
|
auto place = input.place();
|
||||||
@@ -168,9 +184,9 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
|||||||
auto permute_input =
|
auto permute_input =
|
||||||
GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place);
|
GetEmptyTensor({moe_topk * num_rows, hidden_size}, input_type, place);
|
||||||
// correspond to the weighted coefficients of the results from each expert.
|
// correspond to the weighted coefficients of the results from each expert.
|
||||||
auto top_k_weight =
|
auto topk_weight =
|
||||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
|
||||||
auto top_k_indices =
|
auto topk_idx =
|
||||||
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);
|
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);
|
||||||
|
|
||||||
auto tokens_expert_prefix_sum =
|
auto tokens_expert_prefix_sum =
|
||||||
@@ -178,18 +194,24 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
|||||||
auto permute_indices_per_token =
|
auto permute_indices_per_token =
|
||||||
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
|
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
|
||||||
|
|
||||||
|
auto expert_idx_per_token =
|
||||||
|
GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);
|
||||||
|
|
||||||
if (token_rows == 0) {
|
if (token_rows == 0) {
|
||||||
return {permute_input,
|
return {permute_input,
|
||||||
tokens_expert_prefix_sum,
|
tokens_expert_prefix_sum,
|
||||||
permute_indices_per_token,
|
permute_indices_per_token,
|
||||||
top_k_weight,
|
topk_weight,
|
||||||
top_k_indices};
|
topk_idx,
|
||||||
|
expert_idx_per_token};
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (input_type) {
|
switch (input_type) {
|
||||||
case paddle::DataType::BFLOAT16:
|
case paddle::DataType::BFLOAT16:
|
||||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
||||||
gating_output,
|
gating_output,
|
||||||
|
gating_correction_bias,
|
||||||
|
w4a8_in_scale,
|
||||||
moe_topk,
|
moe_topk,
|
||||||
group_moe,
|
group_moe,
|
||||||
topk_only_mode,
|
topk_only_mode,
|
||||||
@@ -199,37 +221,25 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
|||||||
&permute_input,
|
&permute_input,
|
||||||
&tokens_expert_prefix_sum,
|
&tokens_expert_prefix_sum,
|
||||||
&permute_indices_per_token,
|
&permute_indices_per_token,
|
||||||
&top_k_weight,
|
&topk_weight,
|
||||||
&top_k_indices);
|
&topk_idx,
|
||||||
|
&expert_idx_per_token);
|
||||||
break;
|
break;
|
||||||
// case paddle::DataType::FLOAT16:
|
|
||||||
// MoeDispatchKernel<paddle::DataType::FLOAT16>(input,
|
|
||||||
// gating_output,
|
|
||||||
// moe_topk,
|
|
||||||
// group_moe,
|
|
||||||
// topk_only_mode,
|
|
||||||
// num_rows,
|
|
||||||
// hidden_size,
|
|
||||||
// expert_num,
|
|
||||||
// &permute_input,
|
|
||||||
// &tokens_expert_prefix_sum,
|
|
||||||
// &permute_indices_per_token,
|
|
||||||
// &top_k_weight,
|
|
||||||
// &top_k_indices);
|
|
||||||
// break;
|
|
||||||
default:
|
default:
|
||||||
PD_THROW("Only support bf16 for MoeDispatchKernel");
|
PD_THROW("Unsupported data type for MoeDispatchKernel");
|
||||||
}
|
}
|
||||||
return {permute_input,
|
return {permute_input,
|
||||||
tokens_expert_prefix_sum,
|
tokens_expert_prefix_sum,
|
||||||
permute_indices_per_token,
|
permute_indices_per_token,
|
||||||
top_k_weight,
|
topk_weight,
|
||||||
top_k_indices};
|
topk_idx,
|
||||||
|
expert_idx_per_token};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||||
const std::vector<int64_t>& input_shape,
|
const std::vector<int64_t>& input_shape,
|
||||||
const std::vector<int64_t>& gating_output_shape,
|
const std::vector<int64_t>& gating_output_shape,
|
||||||
|
const paddle::optional<std::vector<int64_t>>& bias_shape,
|
||||||
const int moe_topk) {
|
const int moe_topk) {
|
||||||
int token_rows = -1;
|
int token_rows = -1;
|
||||||
|
|
||||||
@@ -241,33 +251,44 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
|||||||
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
|
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
|
||||||
const int num_rows = token_rows;
|
const int num_rows = token_rows;
|
||||||
const int hidden_size = input_shape[input_shape.size() - 1];
|
const int hidden_size = input_shape[input_shape.size() - 1];
|
||||||
|
const int permuted_rows = num_rows == -1 ? -1 : moe_topk * num_rows;
|
||||||
|
|
||||||
return {{moe_topk * num_rows, hidden_size},
|
return {{permuted_rows, hidden_size},
|
||||||
{expert_num},
|
{expert_num},
|
||||||
{moe_topk, num_rows},
|
{moe_topk, num_rows},
|
||||||
{num_rows, moe_topk},
|
{num_rows, moe_topk},
|
||||||
{num_rows, moe_topk}};
|
{num_rows, moe_topk},
|
||||||
|
{permuted_rows}};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
|
||||||
const paddle::DataType& input_dtype,
|
const paddle::DataType& input_dtype,
|
||||||
const paddle::DataType& gating_output_dtype,
|
const paddle::DataType& gating_output_dtype,
|
||||||
|
const paddle::optional<paddle::DataType>& bias_type,
|
||||||
const int moe_topk) {
|
const int moe_topk) {
|
||||||
return {input_dtype,
|
return {input_dtype,
|
||||||
paddle::DataType::INT64,
|
paddle::DataType::INT64,
|
||||||
paddle::DataType::INT32,
|
paddle::DataType::INT32,
|
||||||
paddle::DataType::FLOAT32,
|
paddle::DataType::FLOAT32,
|
||||||
|
paddle::DataType::INT32,
|
||||||
paddle::DataType::INT32};
|
paddle::DataType::INT32};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_OP(moe_expert_dispatch)
|
PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
||||||
.Inputs({"input", "gating_output"})
|
.Inputs({"input",
|
||||||
|
"gating_output",
|
||||||
|
paddle::Optional("gating_correction_bias"),
|
||||||
|
paddle::Optional("w4a8_in_scale")})
|
||||||
.Outputs({"permute_input",
|
.Outputs({"permute_input",
|
||||||
"tokens_expert_prefix_sum",
|
"tokens_expert_prefix_sum",
|
||||||
"permute_indices_per_token",
|
"permute_indices_per_token",
|
||||||
"top_k_weight",
|
"topk_weight",
|
||||||
"top_k_indices"})
|
"topk_idx",
|
||||||
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
|
"expert_idx_per_token"})
|
||||||
|
.Attrs({"moe_topk:int",
|
||||||
|
"group_moe:bool",
|
||||||
|
"moe_quant_type:std::string",
|
||||||
|
"topk_only_mode:bool"})
|
||||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
|
||||||
|
|||||||
@@ -12,23 +12,21 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
// BUILD_MARK
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
#include "fused_moe_helper.h"
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include "mc_fused_moe_helper.h"
|
|
||||||
|
|
||||||
template <paddle::DataType T,
|
template <paddle::DataType T>
|
||||||
typename ElementA,
|
void MoeFFNKernel(paddle::Tensor& permute_input,
|
||||||
typename ElementB,
|
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||||
typename ElementC>
|
const paddle::Tensor& up_gate_proj_weight,
|
||||||
void McMoeFFNKernel(paddle::Tensor& permute_input,
|
const paddle::Tensor& down_proj_weight,
|
||||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||||
const paddle::Tensor& ffn1_weight,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||||
const paddle::Tensor& ffn2_weight,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
const std::string& quant_method) {
|
||||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
using namespace phi;
|
||||||
const std::string& quant_method) {
|
|
||||||
typedef PDTraits<T> traits_;
|
typedef PDTraits<T> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
@@ -38,11 +36,13 @@ void McMoeFFNKernel(paddle::Tensor& permute_input,
|
|||||||
auto input_type = permute_input.dtype();
|
auto input_type = permute_input.dtype();
|
||||||
auto stream = permute_input.stream();
|
auto stream = permute_input.stream();
|
||||||
|
|
||||||
|
auto int8_moe_gemm_runner = McMoeGemmRunner<DataType_, int8_t>();
|
||||||
|
|
||||||
const int expanded_active_expert_rows =
|
const int expanded_active_expert_rows =
|
||||||
permute_input.dims()[0]; // permute_input.dims(): m, k
|
permute_input.dims()[0]; // permute_input.dims(): m, k
|
||||||
const int num_experts = ffn1_weight.dims()[0]; // batchsize
|
const int num_experts = up_gate_proj_weight.dims()[0]; // batchsize
|
||||||
const int hidden_size = ffn1_weight.dims()[2]; // n
|
const int hidden_size = up_gate_proj_weight.dims()[2]; // n
|
||||||
int inter_dim = ffn1_weight.dims()[1]; // k
|
int inter_dim = up_gate_proj_weight.dims()[1]; // k
|
||||||
|
|
||||||
const int64_t inter_size = inter_dim; // since weight_only_int_8
|
const int64_t inter_size = inter_dim; // since weight_only_int_8
|
||||||
paddle::Tensor fc1_out_tensor = GetEmptyTensor(
|
paddle::Tensor fc1_out_tensor = GetEmptyTensor(
|
||||||
@@ -58,60 +58,71 @@ void McMoeFFNKernel(paddle::Tensor& permute_input,
|
|||||||
|
|
||||||
// ffn1
|
// ffn1
|
||||||
auto fc1_expert_biases =
|
auto fc1_expert_biases =
|
||||||
ffn1_bias
|
up_gate_proj_bias
|
||||||
? const_cast<paddle::Tensor*>(ffn1_bias.get_ptr())->data<data_t>()
|
? const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr())
|
||||||
|
->data<data_t>()
|
||||||
: nullptr;
|
: nullptr;
|
||||||
auto fc1_expert_scales =
|
auto fc1_expert_scales =
|
||||||
const_cast<paddle::Tensor*>(ffn1_scale.get_ptr())->data<data_t>();
|
const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())->data<data_t>();
|
||||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
if (quant_method == "weight_only_int8") {
|
||||||
reinterpret_cast<const ElementA*>(permuted_input_ptr),
|
int8_moe_gemm_runner.mc_grouped_gemm_basic_kernel(
|
||||||
row_major,
|
reinterpret_cast<const DataType_*>(permuted_input_ptr),
|
||||||
reinterpret_cast<const ElementB*>(ffn1_weight.data<ElementB>()),
|
row_major,
|
||||||
column_major,
|
reinterpret_cast<const int8_t*>(up_gate_proj_weight.data<int8_t>()),
|
||||||
reinterpret_cast<const ElementA*>(fc1_expert_scales),
|
column_major,
|
||||||
reinterpret_cast<const ElementA*>(fc1_expert_biases),
|
reinterpret_cast<const DataType_*>(fc1_expert_scales),
|
||||||
reinterpret_cast<ElementC*>(fc1_out_ptr),
|
reinterpret_cast<const DataType_*>(fc1_expert_biases),
|
||||||
row_major,
|
reinterpret_cast<DataType_*>(fc1_out_ptr),
|
||||||
tokens_expert_prefix_sum.data<int>(),
|
row_major,
|
||||||
m_num_tile_ptr,
|
tokens_expert_prefix_sum.data<int>(),
|
||||||
num_experts,
|
m_num_tile_ptr,
|
||||||
expanded_active_expert_rows,
|
num_experts,
|
||||||
inter_dim,
|
expanded_active_expert_rows,
|
||||||
hidden_size,
|
inter_dim,
|
||||||
stream);
|
hidden_size,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Unsupported gemm method: " + quant_method);
|
||||||
|
}
|
||||||
|
|
||||||
// swiglu
|
// swiglu
|
||||||
auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
auto act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
|
||||||
auto act_out = act_out_tensor.data<data_t>();
|
auto act_out = act_out_tensor.data<data_t>();
|
||||||
|
|
||||||
auto fc2_expert_scales =
|
auto fc2_expert_scales =
|
||||||
const_cast<paddle::Tensor*>(ffn2_scale.get_ptr())->data<data_t>();
|
const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<data_t>();
|
||||||
mc_grouped_gemm_basic_kernel<ElementA, ElementB, ElementC>(
|
|
||||||
reinterpret_cast<const ElementA*>(act_out),
|
if (quant_method == "weight_only_int8") {
|
||||||
row_major,
|
int8_moe_gemm_runner.mc_grouped_gemm_basic_kernel(
|
||||||
reinterpret_cast<const ElementB*>(ffn2_weight.data<ElementB>()),
|
reinterpret_cast<const DataType_*>(act_out),
|
||||||
column_major,
|
row_major,
|
||||||
reinterpret_cast<const ElementA*>(fc2_expert_scales),
|
reinterpret_cast<const int8_t*>(down_proj_weight.data<int8_t>()),
|
||||||
nullptr,
|
column_major,
|
||||||
reinterpret_cast<ElementC*>(permuted_input_ptr),
|
reinterpret_cast<const DataType_*>(fc2_expert_scales),
|
||||||
row_major,
|
nullptr,
|
||||||
tokens_expert_prefix_sum.data<int>(),
|
reinterpret_cast<DataType_*>(permuted_input_ptr),
|
||||||
m_num_tile_ptr,
|
row_major,
|
||||||
num_experts,
|
tokens_expert_prefix_sum.data<int>(),
|
||||||
expanded_active_expert_rows,
|
m_num_tile_ptr,
|
||||||
hidden_size,
|
num_experts,
|
||||||
inter_dim / 2,
|
expanded_active_expert_rows,
|
||||||
stream);
|
hidden_size,
|
||||||
|
inter_dim / 2,
|
||||||
|
stream);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Unsupported gemm method: " + quant_method);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::Tensor> MoeExpertFFN(
|
std::vector<paddle::Tensor> MoeExpertFFN(
|
||||||
paddle::Tensor& permute_input,
|
paddle::Tensor& permute_input,
|
||||||
const paddle::Tensor& tokens_expert_prefix_sum,
|
const paddle::Tensor& tokens_expert_prefix_sum,
|
||||||
const paddle::Tensor& ffn1_weight,
|
const paddle::Tensor& up_gate_proj_weight,
|
||||||
const paddle::Tensor& ffn2_weight,
|
const paddle::Tensor& down_proj_weight,
|
||||||
const paddle::optional<paddle::Tensor>& ffn1_bias,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
|
||||||
const paddle::optional<paddle::Tensor>& ffn1_scale,
|
const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
|
||||||
const paddle::optional<paddle::Tensor>& ffn2_scale,
|
const paddle::optional<paddle::Tensor>& down_proj_scale,
|
||||||
|
const paddle::optional<paddle::Tensor>& expert_idx_per_token,
|
||||||
const std::string& quant_method) {
|
const std::string& quant_method) {
|
||||||
assert(quant_method == "weight_only_int8");
|
assert(quant_method == "weight_only_int8");
|
||||||
const auto input_type = permute_input.dtype();
|
const auto input_type = permute_input.dtype();
|
||||||
@@ -122,31 +133,18 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
|||||||
|
|
||||||
switch (input_type) {
|
switch (input_type) {
|
||||||
case paddle::DataType::BFLOAT16:
|
case paddle::DataType::BFLOAT16:
|
||||||
McMoeFFNKernel<paddle::DataType::BFLOAT16,
|
MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
|
||||||
maca_bfloat16,
|
tokens_expert_prefix_sum,
|
||||||
int8_t,
|
up_gate_proj_weight,
|
||||||
maca_bfloat16>(permute_input,
|
down_proj_weight,
|
||||||
tokens_expert_prefix_sum,
|
up_gate_proj_bias,
|
||||||
ffn1_weight,
|
up_gate_proj_scale,
|
||||||
ffn2_weight,
|
down_proj_scale,
|
||||||
ffn1_bias,
|
expert_idx_per_token,
|
||||||
ffn1_scale,
|
quant_method);
|
||||||
ffn2_scale,
|
|
||||||
quant_method);
|
|
||||||
break;
|
break;
|
||||||
// case paddle::DataType::FLOAT16:
|
|
||||||
// MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
|
|
||||||
// tokens_expert_prefix_sum,
|
|
||||||
// ffn1_weight,
|
|
||||||
// ffn2_weight,
|
|
||||||
// ffn1_bias,
|
|
||||||
// ffn1_scale,
|
|
||||||
// ffn2_scale,
|
|
||||||
// quant_method,
|
|
||||||
// ffn_out);
|
|
||||||
// break;
|
|
||||||
default:
|
default:
|
||||||
PD_THROW("Unsupported data type for MoeExpertFFN");
|
PD_THROW("Unsupported data type for MoeFFNhKernel");
|
||||||
}
|
}
|
||||||
return {permute_input};
|
return {permute_input};
|
||||||
}
|
}
|
||||||
@@ -154,33 +152,37 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
|||||||
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
|
||||||
const std::vector<int64_t>& permute_input_shape,
|
const std::vector<int64_t>& permute_input_shape,
|
||||||
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
|
const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
|
||||||
const std::vector<int64_t>& ffn1_weight_shape,
|
const std::vector<int64_t>& up_gate_proj_weight_shape,
|
||||||
const std::vector<int64_t>& ffn2_weight_shape,
|
const std::vector<int64_t>& down_proj_weight_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
|
const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
|
const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
|
const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
|
||||||
|
const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
|
||||||
|
const std::string& quant_method) {
|
||||||
return {permute_input_shape};
|
return {permute_input_shape};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
std::vector<paddle::DataType> MoeExpertFFNInferDtype(
|
||||||
const paddle::DataType& permute_input_dtype,
|
const paddle::DataType& permute_input_dtype,
|
||||||
const paddle::DataType& tokens_expert_prefix_sum_dtype,
|
const paddle::DataType& tokens_expert_prefix_sum_dtype,
|
||||||
const paddle::DataType& ffn1_weight_dtype,
|
const paddle::DataType& up_gate_proj_weight_dtype,
|
||||||
const paddle::DataType& ffn2_weight_dtype,
|
const paddle::DataType& down_proj_weight_dtype,
|
||||||
const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
|
const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
|
||||||
const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
|
const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
|
||||||
const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
|
const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
|
||||||
|
const paddle::optional<paddle::DataType>& expert_idx_per_token_dtype) {
|
||||||
return {permute_input_dtype};
|
return {permute_input_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
PD_BUILD_OP(moe_expert_ffn)
|
PD_BUILD_OP(moe_expert_ffn)
|
||||||
.Inputs({"permute_input",
|
.Inputs({"permute_input",
|
||||||
"tokens_expert_prefix_sum",
|
"tokens_expert_prefix_sum",
|
||||||
"ffn1_weight",
|
"up_gate_proj_weight",
|
||||||
"ffn2_weight",
|
"down_proj_weight",
|
||||||
paddle::Optional("ffn1_bias"),
|
paddle::Optional("up_gate_proj_bias"),
|
||||||
paddle::Optional("ffn1_scale"),
|
paddle::Optional("up_gate_proj_scale"),
|
||||||
paddle::Optional("ffn2_scale")})
|
paddle::Optional("down_proj_scale"),
|
||||||
|
paddle::Optional("expert_idx_per_token")})
|
||||||
.Outputs({"output_tensor"})
|
.Outputs({"output_tensor"})
|
||||||
.Attrs({"quant_method:std::string"})
|
.Attrs({"quant_method:std::string"})
|
||||||
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
.SetKernelFn(PD_KERNEL(MoeExpertFFN))
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "fused_moe_helper.h"
|
|
||||||
#include "fused_moe_op.h"
|
#include "fused_moe_op.h"
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
|
|
||||||
@@ -23,13 +22,14 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
|||||||
const paddle::Tensor& top_k_weight,
|
const paddle::Tensor& top_k_weight,
|
||||||
const paddle::Tensor& permute_indices_per_token,
|
const paddle::Tensor& permute_indices_per_token,
|
||||||
const paddle::Tensor& top_k_indices,
|
const paddle::Tensor& top_k_indices,
|
||||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||||
const bool norm_topk_prob,
|
const bool norm_topk_prob,
|
||||||
const float routed_scaling_factor,
|
const float routed_scaling_factor,
|
||||||
const int num_rows,
|
const int num_rows,
|
||||||
const int hidden_size,
|
const int hidden_size,
|
||||||
const int topk,
|
const int topk,
|
||||||
paddle::Tensor* output) {
|
paddle::Tensor* output) {
|
||||||
|
using namespace phi;
|
||||||
typedef PDTraits<T> traits_;
|
typedef PDTraits<T> traits_;
|
||||||
typedef typename traits_::DataType DataType_;
|
typedef typename traits_::DataType DataType_;
|
||||||
typedef typename traits_::data_t data_t;
|
typedef typename traits_::data_t data_t;
|
||||||
@@ -38,7 +38,7 @@ void MoeReduceKernel(const paddle::Tensor& ffn_out,
|
|||||||
finalize_moe_routing_kernelLauncher(
|
finalize_moe_routing_kernelLauncher(
|
||||||
ffn_out.data<data_t>(),
|
ffn_out.data<data_t>(),
|
||||||
output->data<data_t>(),
|
output->data<data_t>(),
|
||||||
ffn2_bias ? ffn2_bias->data<data_t>() : nullptr,
|
down_proj_bias ? down_proj_bias->data<data_t>() : nullptr,
|
||||||
top_k_weight.data<float>(),
|
top_k_weight.data<float>(),
|
||||||
permute_indices_per_token.data<int32_t>(),
|
permute_indices_per_token.data<int32_t>(),
|
||||||
top_k_indices.data<int>(),
|
top_k_indices.data<int>(),
|
||||||
@@ -56,7 +56,7 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
|||||||
const paddle::Tensor& top_k_weight,
|
const paddle::Tensor& top_k_weight,
|
||||||
const paddle::Tensor& permute_indices_per_token,
|
const paddle::Tensor& permute_indices_per_token,
|
||||||
const paddle::Tensor& top_k_indices,
|
const paddle::Tensor& top_k_indices,
|
||||||
const paddle::optional<paddle::Tensor>& ffn2_bias,
|
const paddle::optional<paddle::Tensor>& down_proj_bias,
|
||||||
const bool norm_topk_prob,
|
const bool norm_topk_prob,
|
||||||
const float routed_scaling_factor) {
|
const float routed_scaling_factor) {
|
||||||
const auto input_type = ffn_out.dtype();
|
const auto input_type = ffn_out.dtype();
|
||||||
@@ -69,7 +69,7 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
|||||||
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place);
|
||||||
|
|
||||||
// Avoids ‘invalid configuration argument’ when we launch the kernel.
|
// Avoids ‘invalid configuration argument’ when we launch the kernel.
|
||||||
if (ffn_out.dims()[0] == 0) return {output};
|
if (num_rows == 0) return {output};
|
||||||
|
|
||||||
switch (input_type) {
|
switch (input_type) {
|
||||||
case paddle::DataType::BFLOAT16:
|
case paddle::DataType::BFLOAT16:
|
||||||
@@ -77,7 +77,7 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
|||||||
top_k_weight,
|
top_k_weight,
|
||||||
permute_indices_per_token,
|
permute_indices_per_token,
|
||||||
top_k_indices,
|
top_k_indices,
|
||||||
ffn2_bias,
|
down_proj_bias,
|
||||||
norm_topk_prob,
|
norm_topk_prob,
|
||||||
routed_scaling_factor,
|
routed_scaling_factor,
|
||||||
num_rows,
|
num_rows,
|
||||||
@@ -85,21 +85,8 @@ std::vector<paddle::Tensor> MoeExpertReduce(
|
|||||||
topk,
|
topk,
|
||||||
&output);
|
&output);
|
||||||
break;
|
break;
|
||||||
// case paddle::DataType::FLOAT16:
|
|
||||||
// MoeReduceKernel<paddle::DataType::FLOAT16>(ffn_out,
|
|
||||||
// top_k_weight,
|
|
||||||
// permute_indices_per_token,
|
|
||||||
// top_k_indices,
|
|
||||||
// ffn2_bias,
|
|
||||||
// norm_topk_prob,
|
|
||||||
// routed_scaling_factor,
|
|
||||||
// num_rows,
|
|
||||||
// hidden_size,
|
|
||||||
// topk,
|
|
||||||
// &output);
|
|
||||||
// break;
|
|
||||||
default:
|
default:
|
||||||
PD_THROW("Only support bf16 for MoeDispatchKernel");
|
PD_THROW("Unsupported data type for MoeReduceKernel");
|
||||||
}
|
}
|
||||||
return {output};
|
return {output};
|
||||||
}
|
}
|
||||||
@@ -109,7 +96,7 @@ std::vector<std::vector<int64_t>> MoeExpertReduceInferShape(
|
|||||||
const std::vector<int64_t>& top_k_weight_shape,
|
const std::vector<int64_t>& top_k_weight_shape,
|
||||||
const std::vector<int64_t>& permute_indices_per_token_shape,
|
const std::vector<int64_t>& permute_indices_per_token_shape,
|
||||||
const std::vector<int64_t>& top_k_indices_shape,
|
const std::vector<int64_t>& top_k_indices_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape) {
|
const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape) {
|
||||||
const int topk = top_k_indices_shape[1];
|
const int topk = top_k_indices_shape[1];
|
||||||
std::vector<int64_t> fused_moe_out_shape = {ffn_out_shape[0] / topk,
|
std::vector<int64_t> fused_moe_out_shape = {ffn_out_shape[0] / topk,
|
||||||
ffn_out_shape[1]};
|
ffn_out_shape[1]};
|
||||||
@@ -122,7 +109,7 @@ std::vector<paddle::DataType> MoeExpertReduceInferDtype(
|
|||||||
const paddle::DataType& top_k_weight_dtype,
|
const paddle::DataType& top_k_weight_dtype,
|
||||||
const paddle::DataType& permute_indices_per_token_dtype,
|
const paddle::DataType& permute_indices_per_token_dtype,
|
||||||
const paddle::DataType& top_k_indices_dtype,
|
const paddle::DataType& top_k_indices_dtype,
|
||||||
const paddle::optional<paddle::DataType>& ffn2_bias_dtype) {
|
const paddle::optional<paddle::DataType>& down_proj_bias_dtype) {
|
||||||
return {ffn_out_dtype};
|
return {ffn_out_dtype};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +118,7 @@ PD_BUILD_OP(moe_expert_reduce)
|
|||||||
"top_k_weight",
|
"top_k_weight",
|
||||||
"permute_indices_per_token",
|
"permute_indices_per_token",
|
||||||
"top_k_indices",
|
"top_k_indices",
|
||||||
paddle::Optional("ffn2_bias")})
|
paddle::Optional("down_proj_bias")})
|
||||||
.Outputs({"output"})
|
.Outputs({"output"})
|
||||||
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
|
.Attrs({"norm_topk_prob:bool", "routed_scaling_factor:float"})
|
||||||
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
|
.SetKernelFn(PD_KERNEL(MoeExpertReduce))
|
||||||
|
|||||||
@@ -627,11 +627,17 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
|||||||
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
|
"gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu",
|
||||||
"gpu_ops/moe/tritonmoe_preprocess.cu",
|
"gpu_ops/moe/tritonmoe_preprocess.cu",
|
||||||
"gpu_ops/moe/moe_topk_select.cu",
|
"gpu_ops/moe/moe_topk_select.cu",
|
||||||
|
"gpu_ops/get_img_boundaries.cc",
|
||||||
|
"gpu_ops/remote_cache_kv_ipc.cc",
|
||||||
|
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
|
||||||
|
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
|
||||||
|
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
|
||||||
"metax_ops/moe_dispatch.cu",
|
"metax_ops/moe_dispatch.cu",
|
||||||
"metax_ops/moe_ffn.cu",
|
"metax_ops/moe_ffn.cu",
|
||||||
"metax_ops/moe_reduce.cu",
|
"metax_ops/moe_reduce.cu",
|
||||||
"metax_ops/fused_moe.cu",
|
"metax_ops/fused_moe.cu",
|
||||||
"metax_ops/apply_rope.cu",
|
"metax_ops/apply_rope_qkv.cu",
|
||||||
|
"metax_ops/cache_kv_with_rope.cu",
|
||||||
]
|
]
|
||||||
|
|
||||||
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
|
sources += find_end_files("gpu_ops/speculate_decoding", ".cu")
|
||||||
@@ -657,6 +663,11 @@ elif paddle.device.is_compiled_with_custom_device("metax_gpu"):
|
|||||||
os.path.join(maca_path, "include"),
|
os.path.join(maca_path, "include"),
|
||||||
os.path.join(maca_path, "include/mcr"),
|
os.path.join(maca_path, "include/mcr"),
|
||||||
os.path.join(maca_path, "include/common"),
|
os.path.join(maca_path, "include/common"),
|
||||||
|
os.path.join(maca_path, "include/mcfft"),
|
||||||
|
os.path.join(maca_path, "include/mcrand"),
|
||||||
|
os.path.join(maca_path, "include/mcsparse"),
|
||||||
|
os.path.join(maca_path, "include/mcblas"),
|
||||||
|
os.path.join(maca_path, "include/mcsolver"),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from fastdeploy.model_executor.layers.backends.metax.attention.flash_attention_i
|
|||||||
flash_attn_kvcache_func,
|
flash_attn_kvcache_func,
|
||||||
flash_attn_unpadded_func,
|
flash_attn_unpadded_func,
|
||||||
)
|
)
|
||||||
from fastdeploy.model_executor.ops.gpu import apply_rope
|
from fastdeploy.model_executor.ops.gpu import apply_rope_qkv, cache_kv_with_rope
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -127,15 +127,14 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.rank, self.device_id = init_rank_and_device_id(fd_config)
|
self.rank, self.device_id = init_rank_and_device_id(fd_config)
|
||||||
self.enable_mm = fd_config.model_config.enable_mm
|
self.enable_mm = fd_config.model_config.enable_mm
|
||||||
max_num_seqs = fd_config.scheduler_config.max_num_seqs
|
max_num_seqs = fd_config.scheduler_config.max_num_seqs
|
||||||
if self.enable_mm:
|
self.attention_metadata.rotary_cos_decode = paddle.empty(
|
||||||
self.attention_metadata.rotary_cos_decode = paddle.empty(
|
shape=[max_num_seqs, 1, 1, self.head_dim],
|
||||||
shape=[max_num_seqs, 1, 1, self.head_dim],
|
dtype=self.dtype,
|
||||||
dtype="float32",
|
)
|
||||||
)
|
self.attention_metadata.rotary_sin_decode = paddle.empty(
|
||||||
self.attention_metadata.rotary_sin_decode = paddle.empty(
|
shape=[max_num_seqs, 1, 1, self.head_dim],
|
||||||
shape=[max_num_seqs, 1, 1, self.head_dim],
|
dtype=self.dtype,
|
||||||
dtype="float32",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||||
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
|
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
|
||||||
@@ -245,6 +244,12 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
seq_lens_this_time = forward_meta.seq_lens_this_time[batch_ids]
|
seq_lens_this_time = forward_meta.seq_lens_this_time[batch_ids]
|
||||||
cached_kv_lens = forward_meta.seq_lens_decoder[batch_ids, 0]
|
cached_kv_lens = forward_meta.seq_lens_decoder[batch_ids, 0]
|
||||||
|
|
||||||
|
self.block_table_prefill = forward_meta.block_tables[batch_ids, :]
|
||||||
|
# mapping token idx to batch idx
|
||||||
|
self.batch_ids_q = paddle.repeat_interleave(
|
||||||
|
paddle.arange(0, batch_ids.shape[0], dtype="int32"), repeats=seq_lens_this_time, axis=0
|
||||||
|
)
|
||||||
|
|
||||||
all_indices = []
|
all_indices = []
|
||||||
for i in range(len(batch_ids)):
|
for i in range(len(batch_ids)):
|
||||||
start_pos = cached_kv_lens[i]
|
start_pos = cached_kv_lens[i]
|
||||||
@@ -285,19 +290,25 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.attention_metadata.rotary_sin_prefill = paddle.repeat_interleave(rot_sin, repeats=2, axis=-1)
|
self.attention_metadata.rotary_sin_prefill = paddle.repeat_interleave(rot_sin, repeats=2, axis=-1)
|
||||||
|
|
||||||
def update_rotary_embs_decoder(self, forward_meta: ForwardMeta):
|
def update_rotary_embs_decoder(self, forward_meta: ForwardMeta):
|
||||||
if not self.enable_mm: # only initialize once for text-only model
|
if self.batch_ids_decode.shape[0] == 0:
|
||||||
if self.attention_metadata.rotary_cos_decode is None or self.attention_metadata.rotary_sin_decode is None:
|
return
|
||||||
self.attention_metadata.rotary_cos_decode = forward_meta.rotary_embs[0, 0, :, 0, :].astype(self.dtype)
|
|
||||||
self.attention_metadata.rotary_sin_decode = forward_meta.rotary_embs[1, 0, :, 0, :].astype(self.dtype)
|
bs = self.batch_ids_decode.shape[0]
|
||||||
elif self.batch_ids_decode.shape[0] > 0:
|
if self.enable_mm:
|
||||||
bs = self.batch_ids_decode.shape[0]
|
|
||||||
index = paddle.concat(
|
index = paddle.concat(
|
||||||
[self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1
|
[self.batch_ids_decode.view([-1, 1]), self.seq_lens_dec.to("int64").view([-1, 1])], axis=1
|
||||||
)
|
)
|
||||||
rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
|
rot_cos = paddle.gather_nd(forward_meta.rotary_embs[:, 0, 0, :, 0, :], index).view([bs, 1, 1, -1])
|
||||||
rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
|
rot_sin = paddle.gather_nd(forward_meta.rotary_embs[:, 1, 0, :, 0, :], index).view([bs, 1, 1, -1])
|
||||||
self.attention_metadata.rotary_cos_decode[:bs].copy_(paddle.repeat_interleave(rot_cos, repeats=2, axis=-1))
|
else:
|
||||||
self.attention_metadata.rotary_sin_decode[:bs].copy_(paddle.repeat_interleave(rot_sin, repeats=2, axis=-1))
|
rot_cos = paddle.gather(forward_meta.rotary_embs[0, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
|
||||||
|
rot_sin = paddle.gather(forward_meta.rotary_embs[1, 0, :, 0, :], self.seq_lens_dec).view([bs, 1, 1, -1])
|
||||||
|
self.attention_metadata.rotary_cos_decode[:bs].copy_(
|
||||||
|
paddle.repeat_interleave(rot_cos, repeats=2, axis=-1).astype(self.dtype)
|
||||||
|
)
|
||||||
|
self.attention_metadata.rotary_sin_decode[:bs].copy_(
|
||||||
|
paddle.repeat_interleave(rot_sin, repeats=2, axis=-1).astype(self.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
def get_attntion_meta(self) -> AttentionMetadata:
|
def get_attntion_meta(self) -> AttentionMetadata:
|
||||||
"""get_attntion_meta"""
|
"""get_attntion_meta"""
|
||||||
@@ -395,6 +406,25 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
}
|
}
|
||||||
# non last block: seq_lens_this_time > block_size
|
# non last block: seq_lens_this_time > block_size
|
||||||
else:
|
else:
|
||||||
|
if bool(self.num_layers_draft_model) and (
|
||||||
|
seq_len < self.block_size and i < cur_used_num_blocks - 1
|
||||||
|
):
|
||||||
|
cache_end = seq_len - cache_start
|
||||||
|
assert cache_end <= self.block_size
|
||||||
|
|
||||||
|
forward_meta.caches[k_cache_id][block_id, 0:cache_end, :, :] = slice_trans_k[
|
||||||
|
cache_start:seq_len, :, :
|
||||||
|
]
|
||||||
|
forward_meta.caches[v_cache_id][block_id, 0:cache_end, :, :] = slice_trans_v[
|
||||||
|
cache_start:seq_len, :, :
|
||||||
|
]
|
||||||
|
if layer_id == self.num_layers - 1:
|
||||||
|
self.record_block_table_metadata[batch_idx] = {
|
||||||
|
"block_id": block_id.item(),
|
||||||
|
"cache_end": cache_end,
|
||||||
|
}
|
||||||
|
break
|
||||||
|
|
||||||
assert seq_len > self.block_size
|
assert seq_len > self.block_size
|
||||||
cache_end = cache_start + self.block_size
|
cache_end = cache_start + self.block_size
|
||||||
forward_meta.caches[k_cache_id][block_id] = slice_trans_k[cache_start:cache_end, :, :]
|
forward_meta.caches[k_cache_id][block_id] = slice_trans_k[cache_start:cache_end, :, :]
|
||||||
@@ -403,9 +433,20 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
tensor_start = tensor_end
|
tensor_start = tensor_end
|
||||||
|
|
||||||
def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
|
def forward_prefill(self, prefill_qkv, layer_id, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
|
||||||
qkv = prefill_qkv.view([-1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
|
q, k, v = cache_kv_with_rope(
|
||||||
q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)
|
prefill_qkv,
|
||||||
q, k = apply_rope(q, k, self.attention_metadata.rotary_cos_prefill, self.attention_metadata.rotary_sin_prefill)
|
forward_meta.caches[k_cache_id],
|
||||||
|
forward_meta.caches[v_cache_id],
|
||||||
|
self.block_table_prefill,
|
||||||
|
self.attention_metadata.rotary_cos_prefill,
|
||||||
|
self.attention_metadata.rotary_sin_prefill,
|
||||||
|
self.prefill_info_dict["cu_seqlens_q"],
|
||||||
|
self.batch_ids_q,
|
||||||
|
self.num_heads,
|
||||||
|
self.kv_num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.block_size,
|
||||||
|
)
|
||||||
|
|
||||||
prefill_out = flash_attn_unpadded_func(
|
prefill_out = flash_attn_unpadded_func(
|
||||||
q,
|
q,
|
||||||
@@ -419,23 +460,17 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
causal=self.causal,
|
causal=self.causal,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
self.update_kv_cache(k, v, k_cache_id, v_cache_id, layer_id, forward_meta, self.batch_ids_prefill)
|
|
||||||
|
|
||||||
return prefill_out
|
return prefill_out
|
||||||
|
|
||||||
def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
|
def forward_decode(self, decode_qkv, k_cache_id, v_cache_id, forward_meta: ForwardMeta):
|
||||||
qkv = decode_qkv.view([-1, 1, self.num_heads + self.kv_num_heads * 2, self.head_dim])
|
q, k, v = apply_rope_qkv(
|
||||||
q, k, v = qkv.split(num_or_sections=[self.num_heads, self.kv_num_heads, self.kv_num_heads], axis=-2)
|
decode_qkv,
|
||||||
|
self.attention_metadata.rotary_cos_decode,
|
||||||
if self.enable_mm: # vl
|
self.attention_metadata.rotary_sin_decode,
|
||||||
q, k = apply_rope(
|
self.num_heads,
|
||||||
q, k, self.attention_metadata.rotary_cos_decode, self.attention_metadata.rotary_sin_decode
|
self.kv_num_heads,
|
||||||
)
|
self.head_dim,
|
||||||
rotary_cos = None
|
)
|
||||||
rotary_sin = None
|
|
||||||
else:
|
|
||||||
rotary_cos = self.attention_metadata.rotary_cos_decode
|
|
||||||
rotary_sin = self.attention_metadata.rotary_sin_decode
|
|
||||||
|
|
||||||
decode_out = flash_attn_kvcache_func(
|
decode_out = flash_attn_kvcache_func(
|
||||||
q,
|
q,
|
||||||
@@ -445,8 +480,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.block_table_dec,
|
self.block_table_dec,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
rotary_cos=rotary_cos,
|
rotary_cos=None,
|
||||||
rotary_sin=rotary_sin,
|
rotary_sin=None,
|
||||||
causal=self.causal,
|
causal=self.causal,
|
||||||
is_rotary_interleaved=True,
|
is_rotary_interleaved=True,
|
||||||
)[0].squeeze(1)
|
)[0].squeeze(1)
|
||||||
|
|||||||
@@ -209,7 +209,8 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
|
|||||||
None,
|
None,
|
||||||
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
|
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
|
||||||
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
|
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
|
||||||
"weight_only_int8",
|
expert_idx_per_token, # expert_idx_per_token: only for w4a8
|
||||||
|
self.moe_quant_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_ep_prefill(
|
def apply_ep_prefill(
|
||||||
@@ -262,15 +263,26 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
|
|||||||
permute_indices_per_token,
|
permute_indices_per_token,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_idx,
|
topk_idx,
|
||||||
|
expert_idx_per_token, # only for w4a8
|
||||||
) = moe_expert_dispatch(
|
) = moe_expert_dispatch(
|
||||||
x,
|
x,
|
||||||
gate_out,
|
gate_out,
|
||||||
|
None, # Use layer.gate_correction_bias in get_moe_scores.
|
||||||
|
None, # if set, permute_input will be int8_t
|
||||||
layer.top_k,
|
layer.top_k,
|
||||||
False,
|
False,
|
||||||
|
self.moe_quant_type,
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, None)
|
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
|
||||||
|
# only w4a8 need expert_idx_per_token
|
||||||
|
# Other need not this tensor, so we make it None.
|
||||||
|
expert_idx_per_token = None
|
||||||
|
else:
|
||||||
|
expert_idx_per_token = expert_idx_per_token.cast("int64")
|
||||||
|
|
||||||
|
ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token)
|
||||||
|
|
||||||
fused_moe_out = moe_expert_reduce(
|
fused_moe_out = moe_expert_reduce(
|
||||||
ffn_out,
|
ffn_out,
|
||||||
@@ -291,7 +303,7 @@ class MetaxCutlassMoEMethod(MoEMethodBase):
|
|||||||
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
|
(layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
|
||||||
None,
|
None,
|
||||||
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
|
(layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None),
|
||||||
"weight_only_int8",
|
self.moe_quant_type,
|
||||||
layer.top_k,
|
layer.top_k,
|
||||||
True,
|
True,
|
||||||
False,
|
False,
|
||||||
|
|||||||
@@ -344,21 +344,12 @@ def post_process_normal(
|
|||||||
model_output.stop_flags,
|
model_output.stop_flags,
|
||||||
)
|
)
|
||||||
|
|
||||||
if current_platform.is_cuda() or current_platform.is_iluvatar() or current_platform.is_dcu():
|
if (
|
||||||
set_stop_value_multi_ends(
|
current_platform.is_cuda()
|
||||||
sampler_output.sampled_token_ids,
|
or current_platform.is_iluvatar()
|
||||||
model_output.stop_flags,
|
or current_platform.is_dcu()
|
||||||
model_output.seq_lens_this_time,
|
or current_platform.is_maca()
|
||||||
model_output.eos_token_id,
|
):
|
||||||
model_output.next_tokens,
|
|
||||||
model_output.pre_ids,
|
|
||||||
model_output.step_idx,
|
|
||||||
model_output.stop_token_ids,
|
|
||||||
model_output.stop_seqs_len,
|
|
||||||
model_output.min_tokens,
|
|
||||||
False,
|
|
||||||
) # multi ends
|
|
||||||
elif current_platform.is_maca():
|
|
||||||
set_stop_value_multi_ends(
|
set_stop_value_multi_ends(
|
||||||
sampler_output.sampled_token_ids,
|
sampler_output.sampled_token_ids,
|
||||||
model_output.stop_flags,
|
model_output.stop_flags,
|
||||||
|
|||||||
@@ -1819,7 +1819,7 @@ class MetaxModelRunner(ModelRunnerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
post_process(
|
post_process(
|
||||||
sampler_output=sampler_output,
|
sampler_or_pooler_output=sampler_output,
|
||||||
model_output=model_output_data,
|
model_output=model_output_data,
|
||||||
share_inputs=self.share_inputs,
|
share_inputs=self.share_inputs,
|
||||||
block_size=self.cache_config.block_size,
|
block_size=self.cache_config.block_size,
|
||||||
|
|||||||
Reference in New Issue
Block a user