diff --git a/custom_ops/gpu_ops/sample_kernels/sampling.cuh b/custom_ops/gpu_ops/sample_kernels/sampling.cuh index 99ccc42bb..f5845bea9 100644 --- a/custom_ops/gpu_ops/sample_kernels/sampling.cuh +++ b/custom_ops/gpu_ops/sample_kernels/sampling.cuh @@ -48,14 +48,15 @@ constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; #define SAMPLING_CUB_SUBTRACTLEFT_DEFINED #endif -template struct Pair { +template +struct Pair { T value; int count; - __device__ Pair operator+(const Pair &other) const { + __device__ Pair operator+(const Pair& other) const { return {value + other.value, count + other.count}; } - __device__ Pair &operator+=(const Pair &other) { + __device__ Pair& operator+=(const Pair& other) { value += other.value; count += other.count; return *this; @@ -78,22 +79,25 @@ struct ValueCount { }; struct BoolDiffOp { - __device__ __forceinline__ bool operator()(const bool &lhs, - const bool &rhs) const { + __device__ __forceinline__ bool operator()(const bool& lhs, + const bool& rhs) const { return lhs != rhs; } }; -template struct SamplingTempStorage { union { float deterministic_scan[BLOCK_THREADS / 32]; typename BlockScan::TempStorage scan; - typename BlockReduce::TempStorage reduce; - typename BlockReduce::TempStorage reduce_int; - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage - reduce_value_count; + typename BlockReduce::TempStorage + reduce; + typename BlockReduce::TempStorage + reduce_int; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>:: + TempStorage reduce_value_count; typename BlockAdjacentDifference::TempStorage adj_diff; } block_prim; struct { @@ -112,14 +116,17 @@ struct SamplingTempStorage { * algorithm. \note This implementation is slower than the cub::BlockScan, but * it is deterministic. */ -template -__device__ __forceinline__ void -DeterministicInclusiveSum(const T *in_data, T *out_data, - SamplingTempStorage *temp_storage) { - T *smem_prefix_sum = temp_storage->block_prim.deterministic_scan; + BlockReduceAlgorithm REDUCE_ALGORITHM, + typename T> +__device__ __forceinline__ void DeterministicInclusiveSum( + const T* in_data, + T* out_data, + SamplingTempStorage* + temp_storage) { + T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan; T thread_data[VEC_SIZE]; T thread_sum = 0; #pragma unroll @@ -138,8 +145,8 @@ DeterministicInclusiveSum(const T *in_data, T *out_data, } } - T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum, - threadIdx.x | 0xffffffff); + T warp_sum = __shfl_sync( + 0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); if (threadIdx.x % 32 == 31) { thread_exclusive_prefix_sum = 0; } @@ -197,12 +204,21 @@ DeterministicInclusiveSum(const T *in_data, T *out_data, } } -template +template __device__ __forceinline__ void DeviceSamplingFromProb( - uint32_t i, uint32_t d, Predicate pred, float u, vec_t prob_vec, + uint32_t i, + uint32_t d, + Predicate pred, + float u, + vec_t prob_vec, float& aggregate, - SamplingTempStorage* temp_storage) { + SamplingTempStorage* + temp_storage) { const uint32_t tx = threadIdx.x; float prob_greater_than_threshold[VEC_SIZE]; float inclusive_cdf[VEC_SIZE]; @@ -212,14 +228,14 @@ __device__ __forceinline__ void DeviceSamplingFromProb( prob_greater_than_threshold[j] = pred(prob_vec[j]) ? prob_vec[j] : 0; valid[j] = pred(prob_vec[j]) && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d; } -#ifdef PADDLE_WITH_COREX - float aggregate_local = - BlockReduce(temp_storage->block_prim.reduce) - .Sum(prob_greater_than_threshold); +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + float aggregate_local = BlockReduce( + temp_storage->block_prim.reduce) + .Sum(prob_greater_than_threshold); #else - float aggregate_local = - BlockReduce(temp_storage->block_prim.reduce) - .Sum(prob_greater_than_threshold); + float aggregate_local = BlockReduce( + temp_storage->block_prim.reduce) + .Sum(prob_greater_than_threshold); #endif if (tx == 0) { temp_storage->block_aggregate.value = aggregate_local; @@ -229,14 +245,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb( if (aggregate + aggregate_local > u) { if constexpr (DETERMINISTIC) { - DeterministicInclusiveSum( + DeterministicInclusiveSum( prob_greater_than_threshold, inclusive_cdf, temp_storage); } else { -#ifdef PADDLE_WITH_COREX - BlockScan(temp_storage->block_prim.scan) +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + BlockScan( + temp_storage->block_prim.scan) .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); #else - BlockScan(temp_storage->block_prim.scan) + BlockScan( + temp_storage->block_prim.scan) .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); #endif @@ -250,28 +271,35 @@ __device__ __forceinline__ void DeviceSamplingFromProb( bool greater_than_u_diff[VEC_SIZE]; #ifdef SAMPLING_CUB_SUBTRACTLEFT_DEFINED - #ifdef PADDLE_WITH_COREX - BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); - #else - BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); - #endif +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + BlockAdjacentDifference( + temp_storage->block_prim.adj_diff) + .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); #else - #ifdef PADDLE_WITH_COREX - BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); - #else - BlockAdjacentDifference(temp_storage->block_prim.adj_diff) - .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); - #endif + BlockAdjacentDifference( + temp_storage->block_prim.adj_diff) + .SubtractLeft( + greater_than_u, greater_than_u_diff, BoolDiffOp()); +#endif +#else +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + BlockAdjacentDifference( + temp_storage->block_prim.adj_diff) + .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); +#else + BlockAdjacentDifference( + temp_storage->block_prim.adj_diff) + .FlagHeads( + greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); +#endif #endif __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { if (greater_than_u_diff[j]) { - atomicMin(&(temp_storage->sampled_id), (i * BLOCK_THREADS + tx) * VEC_SIZE + j); + atomicMin(&(temp_storage->sampled_id), + (i * BLOCK_THREADS + tx) * VEC_SIZE + j); } } __syncthreads(); @@ -287,9 +315,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb( valid_index[j] = -1; } } - int max_valid_index = - BlockReduce(temp_storage->block_prim.reduce_int) - .Reduce(valid_index, cub::Max()); + int max_valid_index = BlockReduce( + temp_storage->block_prim.reduce_int) + .Reduce(valid_index, cub::Max()); if (tx == 0 && max_valid_index != -1) { temp_storage->last_valid_id = max_valid_index; } @@ -297,15 +325,19 @@ __device__ __forceinline__ void DeviceSamplingFromProb( aggregate += aggregate_local; } - - - -template -__global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, - float* top_p_arr, IdType* top_k_arr, - uint32_t d, uint64_t philox_seed, +template +__global__ void TopKTopPSamplingFromProbKernel(DType* probs, + IdType* output, + float* top_p_arr, + IdType* top_k_arr, + uint32_t d, + uint64_t philox_seed, uint64_t philox_offset) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; @@ -315,12 +347,12 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx]; const float p = top_p_arr[row_idx]; - extern __shared__ __align__( - alignof(SamplingTempStorage)) + extern __shared__ __align__(alignof( + SamplingTempStorage)) uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>( - smem_sampling); + auto& temp_storage = reinterpret_cast< + SamplingTempStorage&>( + smem_sampling); vec_t probs_vec; float aggregate; @@ -336,12 +368,22 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + probs_vec.cast_load(probs + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); } - DeviceSamplingFromProb( - i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + i, + d, + [&](float x) { return x > low; }, + u, + probs_vec, + aggregate, + &temp_storage); if (aggregate > u) { break; } @@ -362,28 +404,29 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + probs_vec.cast_load(probs + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); } ValueCount probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot_0[j] = { - (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, - (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - probs_gt_pivot_1[j] = { - (probs_vec[j] > pivot_1) ? probs_vec[j] : 0, - (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_0[j] = {(probs_vec[j] > pivot_0) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_0 && + (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + probs_gt_pivot_1[j] = {(probs_vec[j] > pivot_1) ? probs_vec[j] : 0, + (probs_vec[j] > pivot_1 && + (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; } -#ifdef PADDLE_WITH_COREX - aggregate_gt_pivot_0 += - BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0); +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0); #else - aggregate_gt_pivot_0 += - BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0); + aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0); #endif if (tx == 0) { temp_storage.block_aggregate.pair = aggregate_gt_pivot_0; @@ -391,14 +434,14 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, __syncthreads(); aggregate_gt_pivot_0 = temp_storage.block_aggregate.pair; -#ifdef PADDLE_WITH_COREX - aggregate_gt_pivot_1 += - BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1); +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1); #else - aggregate_gt_pivot_1 += - BlockReduce, BLOCK_THREADS>(temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1); + aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1); #endif if (tx == 0) { temp_storage.block_aggregate.pair = aggregate_gt_pivot_1; @@ -427,14 +470,19 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, } } - - -template -__global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, - float* top_p_arr, uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { +template +__global__ void TopPSamplingFromProbKernel(DType* probs, + IdType* output, + float* top_p_arr, + uint32_t d, + uint64_t philox_seed, + uint64_t philox_offset) { const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; curandStatePhilox4_32_10_t state; @@ -442,12 +490,12 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, const uint32_t row_idx = bx; float top_p = top_p_arr[row_idx]; - extern __shared__ __align__( - alignof(SamplingTempStorage)) + extern __shared__ __align__(alignof( + SamplingTempStorage)) uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>( - smem_sampling); + auto& temp_storage = reinterpret_cast< + SamplingTempStorage&>( + smem_sampling); vec_t probs_vec; float aggregate; @@ -463,12 +511,22 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + probs_vec.cast_load(probs + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); } - DeviceSamplingFromProb( - i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage); + i, + d, + [&](float x) { return x > low; }, + u, + probs_vec, + aggregate, + &temp_storage); if (aggregate > u) { break; } @@ -489,7 +547,8 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + probs_vec.cast_load(probs + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); } float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE]; @@ -499,12 +558,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0; } -#ifdef PADDLE_WITH_COREX - aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_0); +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + aggregate_gt_pivot_0 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_0); #else - aggregate_gt_pivot_0 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_0); + aggregate_gt_pivot_0 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_0); #endif if (tx == 0) { temp_storage.block_aggregate.value = aggregate_gt_pivot_0; @@ -512,12 +573,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, __syncthreads(); aggregate_gt_pivot_0 = temp_storage.block_aggregate.value; -#ifdef PADDLE_WITH_COREX - aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_1); +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + aggregate_gt_pivot_1 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_1); #else - aggregate_gt_pivot_1 += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot_1); + aggregate_gt_pivot_1 += + BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot_1); #endif if (tx == 0) { temp_storage.block_aggregate.value = aggregate_gt_pivot_1; @@ -546,9 +609,13 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, } } -template -__device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d, +__device__ __forceinline__ float GetMaxValue(float* in_data, + uint32_t row_idx, + uint32_t d, TempStorage& temp_storage) { const uint32_t tx = threadIdx.x; vec_t in_data_vec; @@ -557,21 +624,24 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { in_data_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + in_data_vec.cast_load(in_data + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); } float in_data_[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { in_data_[j] = in_data_vec[j]; } -#ifdef PADDLE_WITH_COREX - max_val = max( - max_val, BlockReduce(temp_storage.block_prim.reduce) - .Reduce(in_data_, cub::Max())); +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + max_val = max(max_val, + BlockReduce( + temp_storage.block_prim.reduce) + .Reduce(in_data_, cub::Max())); #else - max_val = max( - max_val, BlockReduce(temp_storage.block_prim.reduce) - .Reduce(in_data_, cub::Max())); + max_val = max(max_val, + BlockReduce( + temp_storage.block_prim.reduce) + .Reduce(in_data_, cub::Max())); #endif __syncthreads(); } @@ -585,10 +655,12 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u template struct RenormTempStorage { union { - typename BlockReduce::TempStorage reduce; - typename BlockReduce::TempStorage reduce_int; - typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage - reduce_value_count; + typename BlockReduce::TempStorage + reduce; + typename BlockReduce::TempStorage + reduce_int; + typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>:: + TempStorage reduce_value_count; } block_prim; struct { float max_val; @@ -607,24 +679,33 @@ struct RenormTempStorage { }; }; -template -__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr, - DType* renormed_prob,uint32_t d) { +template +__global__ void MinPSamplingFromProbKernel(DType* probs, + const float* min_p_arr, + DType* renormed_prob, + uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx]; const uint32_t row_idx = bx; - extern __shared__ __align__( - alignof(SamplingTempStorage)) + extern __shared__ __align__(alignof( + SamplingTempStorage)) uint8_t smem_sampling[]; - auto& temp_storage = - reinterpret_cast&>( - smem_sampling); + auto& temp_storage = reinterpret_cast< + SamplingTempStorage&>( + smem_sampling); - float max_val = GetMaxValue>( + float max_val = GetMaxValue< + VEC_SIZE, + BLOCK_THREADS, + REDUCE_ALGORITHM, + SamplingTempStorage>( probs, row_idx, d, temp_storage); float pivot = max_val * p; @@ -633,7 +714,8 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr, for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + probs_vec.cast_load(probs + row_idx * d + + (i * BLOCK_THREADS + tx) * VEC_SIZE); } #pragma unroll @@ -641,42 +723,51 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr, probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0; } if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + probs_vec.store(renormed_prob + row_idx * d + + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); } - } } - -template -__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) { +template +__global__ void TopKRenormProbKernel(DType* probs, + DType* renormed_prob, + IdType* top_k_arr, + uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; const uint32_t k = top_k_arr[row_idx] == 0 ? d : top_k_arr[row_idx]; -#ifdef PADDLE_WITH_COREX +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) double pivot = std::numeric_limits::infinity(), normalizer = 1; #else double pivot = -cuda::std::numeric_limits::infinity(), normalizer = 1; #endif vec_t probs_vec; if (k < d) { - extern __shared__ __align__(alignof(RenormTempStorage)) - uint8_t smem_renorm[]; + extern __shared__ __align__(alignof( + RenormTempStorage)) uint8_t smem_renorm[]; auto& temp_storage = - reinterpret_cast&>(smem_renorm); + reinterpret_cast&>( + smem_renorm); temp_storage.max_val = 0; - float max_val = GetMaxValue>( - probs, row_idx, d, temp_storage); + float max_val = + GetMaxValue>( + probs, row_idx, d, temp_storage); double low = 0, high = max_val; float min_gt_low, max_le_high; float sum_low = 1; // f(x) = len(nonzero(probs > x)), f(x) is non-increasing - // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high} - // loop invariant: + // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | + // p <= high} loop invariant: // - f(low) >= k, f(high) < k // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high) // stopping condition: min_gt_low == max_le_high @@ -692,55 +783,65 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + probs_vec.cast_load(probs + row_idx * d + + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); } - ValueCount probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE]; + ValueCount probs_gt_pivot_0_pair[VEC_SIZE], + probs_gt_pivot_1_pair[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_gt_pivot_0_pair[j] = { (probs_vec[j] > pivot_0) ? probs_vec[j] : 0, - (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + (probs_vec[j] > pivot_0 && + (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; probs_gt_pivot_1_pair[j] = { (probs_vec[j] > pivot_1) ? probs_vec[j] : 0, - (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; + (probs_vec[j] > pivot_1 && + (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; - if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + if (probs_vec[j] > low && + (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { min_gt_low = min(min_gt_low, probs_vec[j]); } - if (probs_vec[j] <= high && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { + if (probs_vec[j] <= high && + (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) { max_le_high = max(max_le_high, probs_vec[j]); } } -#ifdef PADDLE_WITH_COREX - aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0_pair); +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + aggregate_gt_pivot_0 += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0_pair); #else - aggregate_gt_pivot_0 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_0_pair); + aggregate_gt_pivot_0 += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_0_pair); #endif __syncthreads(); -#ifdef PADDLE_WITH_COREX - aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1_pair); +#if defined(PADDLE_WITH_COREX) || defined(PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU) + aggregate_gt_pivot_1 += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1_pair); #else - aggregate_gt_pivot_1 += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( - temp_storage.block_prim.reduce_value_count) - .Sum(probs_gt_pivot_1_pair); + aggregate_gt_pivot_1 += + BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_value_count) + .Sum(probs_gt_pivot_1_pair); #endif __syncthreads(); } - min_gt_low = - BlockReduce(temp_storage.block_prim.reduce) - .Reduce(min_gt_low, cub::Min()); + min_gt_low = BlockReduce( + temp_storage.block_prim.reduce) + .Reduce(min_gt_low, cub::Min()); __syncthreads(); - max_le_high = - BlockReduce(temp_storage.block_prim.reduce) - .Reduce(max_le_high, cub::Max()); + max_le_high = BlockReduce( + temp_storage.block_prim.reduce) + .Reduce(max_le_high, cub::Max()); if (tx == 0) { temp_storage.block_aggregate.pairs[0] = aggregate_gt_pivot_0; temp_storage.block_aggregate.pairs[1] = aggregate_gt_pivot_1; @@ -774,23 +875,29 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + + tx * VEC_SIZE); } #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_vec[j] = (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : 0; } if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); + probs_vec.store(renormed_prob + row_idx * d + + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); } } } template -cudaError_t TopPSamplingFromProb(T *probs, IdType *output, - uint32_t batch_size, const T *top_p_val, - uint32_t d, bool deterministic, - uint64_t philox_seed, uint64_t philox_offset, +cudaError_t TopPSamplingFromProb(T* probs, + IdType* output, + uint32_t batch_size, + const T* top_p_val, + uint32_t d, + bool deterministic, + uint64_t philox_seed, + uint64_t philox_offset, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -799,99 +906,139 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output, sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &top_p_val, - &d, &philox_seed, &philox_offset}; + void* args[] = { + &probs, &output, &top_p_val, &d, &philox_seed, &philox_offset}; DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, + vec_size, + VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = - TopPSamplingFromProbKernel; + auto kernel = TopPSamplingFromProbKernel; CUDA_CALL(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args, - smem_size, stream)); + CUDA_CALL(cudaLaunchKernel( + (void*)kernel, nblks, nthrs, args, smem_size, stream)); })}); return cudaSuccess; } -template -cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,T *renormed_prob, +template +cudaError_t MinPSamplingFromProb(T* probs, + const T* min_p_arr, + T* renormed_prob, uint32_t batch_size, - uint32_t d, bool deterministic, - cudaStream_t stream = 0){ + uint32_t d, + bool deterministic, + cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - const uint32_t smem_size = sizeof(SamplingTempStorage); + const uint32_t smem_size = + sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &min_p_arr,&renormed_prob,&d}; + void* args[] = {&probs, &min_p_arr, &renormed_prob, &d}; DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, + vec_size, + VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = - MinPSamplingFromProbKernel; + auto kernel = MinPSamplingFromProbKernel; CUDA_CALL(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args, - smem_size, stream)); + CUDA_CALL(cudaLaunchKernel( + (void*)kernel, nblks, nthrs, args, smem_size, stream)); })}); return cudaSuccess; } - template -cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output, - uint32_t batch_size, const T *top_p_val, const IdType *top_k_val, - uint32_t d, bool deterministic, - uint64_t philox_seed, uint64_t philox_offset, +cudaError_t TopKTopPSamplingFromProb(T* probs, + IdType* output, + uint32_t batch_size, + const T* top_p_val, + const IdType* top_k_val, + uint32_t d, + bool deterministic, + uint64_t philox_seed, + uint64_t philox_offset, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(T), d); auto compute_capacity = GetCudaComputeCapability(); DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { - const uint32_t smem_size = sizeof(SamplingTempStorage); + const uint32_t smem_size = + sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &top_p_val, &top_k_val, - &d, &philox_seed, &philox_offset}; + void* args[] = {&probs, + &output, + &top_p_val, + &top_k_val, + &d, + &philox_seed, + &philox_offset}; DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKTopPSamplingFromProbKernel; - CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + vec_size, + VEC_SIZE, + {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopKTopPSamplingFromProbKernel; + CUDA_CALL(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + CUDA_CALL(cudaLaunchKernel( + (void*)kernel, nblks, nthrs, args, smem_size, stream)); })}); return cudaSuccess; }); } template -cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, - uint32_t batch_size, uint32_t d, +cudaError_t TopKRenormProb(DType* probs, + DType* renormed_prob, + IdType* top_k_arr, + uint32_t batch_size, + uint32_t d, cudaStream_t stream = 0) { const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); auto compute_capacity = GetCudaComputeCapability(); DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { - const uint32_t smem_size = sizeof(RenormTempStorage); + const uint32_t smem_size = + sizeof(RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); void* args[] = {&probs, &renormed_prob, &top_k_arr, &d}; DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKRenormProbKernel; - CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + auto kernel = TopKRenormProbKernel; + CUDA_CALL(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + CUDA_CALL(cudaLaunchKernel( + (void*)kernel, nblks, nthrs, args, smem_size, stream)); }); return cudaSuccess; }); } -} // namespace sampling +} // namespace sampling diff --git a/custom_ops/gpu_ops/sample_kernels/utils.cuh b/custom_ops/gpu_ops/sample_kernels/utils.cuh index 1de480ab8..3488eb42b 100644 --- a/custom_ops/gpu_ops/sample_kernels/utils.cuh +++ b/custom_ops/gpu_ops/sample_kernels/utils.cuh @@ -23,221 +23,235 @@ #include #include +#include +#include +#include #include #include #include #include #include -#include -#include -#include /******************* utils *******************/ #define STR_HELPER(x) #x #define STR(x) STR_HELPER(x) #ifndef NDEBUG -#define CUDA_CALL(func, ...) \ - { \ - cudaError_t e = (func); \ - if (e != cudaSuccess) { \ - std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \ - << ") " << __FILE__ << ": line " << __LINE__ \ - << " at function " << STR(func) << std::endl; \ - return e; \ - } \ +#define CUDA_CALL(func, ...) \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \ + << ") " << __FILE__ << ": line " << __LINE__ \ + << " at function " << STR(func) << std::endl; \ + return e; \ + } \ } #else -#define CUDA_CALL(func, ...) \ - { \ - cudaError_t e = (func); \ - if (e != cudaSuccess) { \ - return e; \ - } \ +#define CUDA_CALL(func, ...) \ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + return e; \ + } \ } #endif -#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \ - if (deterministic) { \ - constexpr bool DETERMINISTIC = true; \ - __VA_ARGS__ \ - } else { \ - constexpr bool DETERMINISTIC = false; \ - __VA_ARGS__ \ +#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \ + if (deterministic) { \ + constexpr bool DETERMINISTIC = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool DETERMINISTIC = false; \ + __VA_ARGS__ \ } -#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ - switch (aligned_vec_size) { \ - case 16: { \ - constexpr size_t ALIGNED_VEC_SIZE = 16; \ - __VA_ARGS__ \ - break; \ - } \ - case 8: { \ - constexpr size_t ALIGNED_VEC_SIZE = 8; \ - __VA_ARGS__ \ - break; \ - } \ - case 4: { \ - constexpr size_t ALIGNED_VEC_SIZE = 4; \ - __VA_ARGS__ \ - break; \ - } \ - case 2: { \ - constexpr size_t ALIGNED_VEC_SIZE = 2; \ - __VA_ARGS__ \ - break; \ - } \ - case 1: { \ - constexpr size_t ALIGNED_VEC_SIZE = 1; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ - throw std::invalid_argument(err_msg.str()); \ - } \ +#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ + switch (aligned_vec_size) { \ + case 16: { \ + constexpr size_t ALIGNED_VEC_SIZE = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: { \ + constexpr size_t ALIGNED_VEC_SIZE = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 4: { \ + constexpr size_t ALIGNED_VEC_SIZE = 4; \ + __VA_ARGS__ \ + break; \ + } \ + case 2: { \ + constexpr size_t ALIGNED_VEC_SIZE = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 1: { \ + constexpr size_t ALIGNED_VEC_SIZE = 1; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ + throw std::invalid_argument(err_msg.str()); \ + } \ } /******************* vec_t *******************/ #define SAMPLING_INLINE inline __attribute__((always_inline)) __device__ -template struct vec_t { - SAMPLING_INLINE float_t &operator[](size_t i); - SAMPLING_INLINE const float_t &operator[](size_t i) const; +template +struct vec_t { + SAMPLING_INLINE float_t& operator[](size_t i); + SAMPLING_INLINE const float_t& operator[](size_t i) const; SAMPLING_INLINE void fill(float_t val); - SAMPLING_INLINE void load(const float_t *ptr); - SAMPLING_INLINE void store(float_t *ptr) const; + SAMPLING_INLINE void load(const float_t* ptr); + SAMPLING_INLINE void store(float_t* ptr) const; template - SAMPLING_INLINE void cast_from(const vec_t &src); - template SAMPLING_INLINE void cast_load(const T *ptr); - template SAMPLING_INLINE void cast_store(T *ptr) const; - SAMPLING_INLINE static void memcpy(float_t *dst, const float_t *src); - SAMPLING_INLINE float_t *ptr(); + SAMPLING_INLINE void cast_from(const vec_t& src); + template + SAMPLING_INLINE void cast_load(const T* ptr); + template + SAMPLING_INLINE void cast_store(T* ptr) const; + SAMPLING_INLINE static void memcpy(float_t* dst, const float_t* src); + SAMPLING_INLINE float_t* ptr(); }; // float x 1 -template <> struct vec_t { +template <> +struct vec_t { float data; - SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; } - SAMPLING_INLINE const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; + SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + SAMPLING_INLINE const float& operator[](size_t i) const { + return ((const float*)(&data))[i]; } - SAMPLING_INLINE float *ptr() { return reinterpret_cast(&data); } + SAMPLING_INLINE float* ptr() { return reinterpret_cast(&data); } SAMPLING_INLINE void fill(float val); - SAMPLING_INLINE void load(const float *ptr); - SAMPLING_INLINE void store(float *ptr) const; - template SAMPLING_INLINE void cast_from(const vec_t &src) { + SAMPLING_INLINE void load(const float* ptr); + SAMPLING_INLINE void store(float* ptr) const; + template + SAMPLING_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); } - template SAMPLING_INLINE void cast_load(const T *ptr) { + template + SAMPLING_INLINE void cast_load(const T* ptr) { cast_load_impl(*this, ptr); } - template SAMPLING_INLINE void cast_store(T *ptr) const { + template + SAMPLING_INLINE void cast_store(T* ptr) const { cast_store_impl(ptr, *this); } - SAMPLING_INLINE static void memcpy(float *dst, const float *src); + SAMPLING_INLINE static void memcpy(float* dst, const float* src); }; SAMPLING_INLINE void vec_t::fill(float val) { data = val; } -SAMPLING_INLINE void vec_t::load(const float *ptr) { data = *ptr; } +SAMPLING_INLINE void vec_t::load(const float* ptr) { data = *ptr; } -SAMPLING_INLINE void vec_t::store(float *ptr) const { *ptr = data; } +SAMPLING_INLINE void vec_t::store(float* ptr) const { *ptr = data; } -SAMPLING_INLINE void vec_t::memcpy(float *dst, const float *src) { +SAMPLING_INLINE void vec_t::memcpy(float* dst, const float* src) { *dst = *src; } // float x 2 -template <> struct vec_t { +template <> +struct vec_t { float2 data; - SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(&data))[i]; } - SAMPLING_INLINE const float &operator[](size_t i) const { - return ((const float *)(&data))[i]; + SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + SAMPLING_INLINE const float& operator[](size_t i) const { + return ((const float*)(&data))[i]; } - SAMPLING_INLINE float *ptr() { return reinterpret_cast(&data); } + SAMPLING_INLINE float* ptr() { return reinterpret_cast(&data); } SAMPLING_INLINE void fill(float val); - SAMPLING_INLINE void load(const float *ptr); - SAMPLING_INLINE void store(float *ptr) const; - template SAMPLING_INLINE void cast_from(const vec_t &src) { + SAMPLING_INLINE void load(const float* ptr); + SAMPLING_INLINE void store(float* ptr) const; + template + SAMPLING_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); } - template SAMPLING_INLINE void cast_load(const T *ptr) { + template + SAMPLING_INLINE void cast_load(const T* ptr) { cast_load_impl(*this, ptr); } - template SAMPLING_INLINE void cast_store(T *ptr) const { + template + SAMPLING_INLINE void cast_store(T* ptr) const { cast_store_impl(ptr, *this); } - SAMPLING_INLINE static void memcpy(float *dst, const float *src); + SAMPLING_INLINE static void memcpy(float* dst, const float* src); }; SAMPLING_INLINE void vec_t::fill(float val) { data = make_float2(val, val); } -SAMPLING_INLINE void vec_t::load(const float *ptr) { - data = *((float2 *)ptr); +SAMPLING_INLINE void vec_t::load(const float* ptr) { + data = *((float2*)ptr); } -SAMPLING_INLINE void vec_t::store(float *ptr) const { - *((float2 *)ptr) = data; +SAMPLING_INLINE void vec_t::store(float* ptr) const { + *((float2*)ptr) = data; } -SAMPLING_INLINE void vec_t::memcpy(float *dst, const float *src) { - *((float2 *)dst) = *((float2 *)src); +SAMPLING_INLINE void vec_t::memcpy(float* dst, const float* src) { + *((float2*)dst) = *((float2*)src); } // float x 4 or more -template struct vec_t { +template +struct vec_t { float4 data[vec_size / 4]; - SAMPLING_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } - SAMPLING_INLINE const float &operator[](size_t i) const { - return ((const float *)(data))[i]; + SAMPLING_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; } + SAMPLING_INLINE const float& operator[](size_t i) const { + return ((const float*)(data))[i]; } - SAMPLING_INLINE float *ptr() { return reinterpret_cast(&data); } + SAMPLING_INLINE float* ptr() { return reinterpret_cast(&data); } SAMPLING_INLINE void fill(float val) { #pragma unroll for (size_t i = 0; i < vec_size / 4; ++i) { data[i] = make_float4(val, val, val, val); } } - SAMPLING_INLINE void load(const float *ptr) { + SAMPLING_INLINE void load(const float* ptr) { #pragma unroll for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = ((float4 *)ptr)[i]; + data[i] = ((float4*)ptr)[i]; } } - SAMPLING_INLINE void store(float *ptr) const { + SAMPLING_INLINE void store(float* ptr) const { #pragma unroll for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)ptr)[i] = data[i]; + ((float4*)ptr)[i] = data[i]; } } template - SAMPLING_INLINE void cast_from(const vec_t &src) { + SAMPLING_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); } - template SAMPLING_INLINE void cast_load(const T *ptr) { + template + SAMPLING_INLINE void cast_load(const T* ptr) { cast_load_impl(*this, ptr); } - template SAMPLING_INLINE void cast_store(T *ptr) const { + template + SAMPLING_INLINE void cast_store(T* ptr) const { cast_store_impl(ptr, *this); } - SAMPLING_INLINE static void memcpy(float *dst, const float *src) { + SAMPLING_INLINE static void memcpy(float* dst, const float* src) { #pragma unroll for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)dst)[i] = ((float4 *)src)[i]; + ((float4*)dst)[i] = ((float4*)src)[i]; } } }; template SAMPLING_INLINE void cast_load_impl(vec_t& dst, - const src_float_t* src_ptr) { + const src_float_t* src_ptr) { if constexpr (std::is_same_v) { dst.load(src_ptr); } else { @@ -260,11 +274,16 @@ inline std::pair GetCudaComputeCapability() { __forceinline__ __device__ float ptx_rcp(float x) { #ifdef PADDLE_WITH_COREX return __ivcorex_rcpf(x); +#else +#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + return __frcp_rn(x); #else float y; asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); return y; #endif + +#endif } template diff --git a/custom_ops/metax_ops/apply_rope.cu b/custom_ops/metax_ops/apply_rope.cu deleted file mode 100644 index 4e820e425..000000000 --- a/custom_ops/metax_ops/apply_rope.cu +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include "helper.h" - -#define THREADS_PER_BLOCK 128 - -template -struct Converter; - -template <> -struct Converter<__half> { - // __half -> float - __device__ static float to_float(__half val) { return __half2float(val); } - // float -> __half - __device__ static __half from_float(float val) { - return __float2half_rn(val); - } - // int -> __half - __device__ static __half from_int(float val) { return __int2half_rn(val); } -}; - -template <> -struct Converter<__nv_bfloat16> { - // __nv_bfloat16 -> float - __device__ static float to_float(__nv_bfloat16 val) { - return __bfloat162float(val); - } - // float -> __nv_bfloat16 - __device__ static __nv_bfloat16 from_float(float val) { - return __float2bfloat16_rn(val); - } - // int -> __nv_bfloat16 - __device__ static __nv_bfloat16 from_int(int val) { - return __int2bfloat16_rn(val); - } -}; - -template -__device__ void RotateQKVec4(const T* qk_ptr, - const T* rot_cos_ptr, - const T* rot_sin_ptr, - const int head_num, - const int base_idx, - const int rot_base_idx, - T* out) { - using VecT = AlignedVector; - - VecT qk_vec; - Load(qk_ptr + base_idx, &qk_vec); - VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]}; - VecT cos_vec, sin_vec; - Load(rot_cos_ptr + rot_base_idx, &cos_vec); - Load(rot_sin_ptr + rot_base_idx, &sin_vec); -#pragma unroll - for (int i = 0; i < 4; ++i) { - *(out + base_idx + i) = - qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i]; - } -} - -template -__device__ void RotateQKVec4(const T* qk_ptr, - const float* rot_cos_ptr, - const float* rot_sin_ptr, - const int head_num, - const int base_idx, - const int rot_base_idx, - T* out) { - using VecT = AlignedVector; - using VecF = AlignedVector; - auto to_float = [] __device__(T val) -> float { - return Converter::to_float(val); - }; - auto from_float = [] __device__(float val) -> T { - return Converter::from_float(val); - }; - - VecT qk_vec; - Load(qk_ptr + base_idx, &qk_vec); - VecF rot_half_vec = {-to_float(qk_vec[1]), - to_float(qk_vec[0]), - -to_float(qk_vec[3]), - to_float(qk_vec[2])}; - VecF cos_vec, sin_vec; - Load(rot_cos_ptr + rot_base_idx, &cos_vec); - Load(rot_sin_ptr + rot_base_idx, &sin_vec); -#pragma unroll - for (int i = 0; i < 4; ++i) { - *(out + base_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] + - rot_half_vec[i] * sin_vec[i]); - } -} - -// qk and rope have a same type -template -__global__ void DispatchApplyRopeVec4Kernel(const T* q, - const T* k, - const T* rot_cos, - const T* rot_sin, - const int q_num_elements, - const int k_num_elements, - const int q_head_num, - const int k_head_num, - const int head_dim, - T* q_out, - T* k_out) { - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - int head_dim_idx = idx % head_dim; - - if (idx < q_num_elements) { - int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx; - RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out); - } - - if (idx < k_num_elements) { - int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx; - RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out); - } -} - -// rope dtype is float32 -template -__global__ void DispatchApplyRopeVec4Kernel(const T* q, - const T* k, - const float* rot_cos, - const float* rot_sin, - const int q_num_elements, - const int k_num_elements, - const int q_head_num, - const int k_head_num, - const int head_dim, - T* q_out, - T* k_out) { - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4; - int head_dim_idx = idx % head_dim; - - if (idx < q_num_elements) { - int rot_idx = idx / (q_head_num * head_dim) * head_dim + head_dim_idx; - RotateQKVec4(q, rot_cos, rot_sin, q_head_num, idx, rot_idx, q_out); - } - - if (idx < k_num_elements) { - int rot_idx = idx / (k_head_num * head_dim) * head_dim + head_dim_idx; - RotateQKVec4(k, rot_cos, rot_sin, k_head_num, idx, rot_idx, k_out); - } -} - -template -void ApplyRopeKernel(const paddle::Tensor& q, - const paddle::Tensor& k, - const paddle::Tensor& rot_cos, - const paddle::Tensor& rot_sin, - paddle::Tensor& q_out, - paddle::Tensor& k_out) { - typedef PDTraits traits_; - typedef typename traits_::DataType DataType_; - typedef typename traits_::data_t data_t; - - const auto q_num_elements = q.numel(); - const auto k_num_elements = k.numel(); - const auto q_shape = q.shape(); - const auto k_shape = k.shape(); - const auto dims = q_shape.size(); - const auto q_head_num = q_shape[dims - 2]; - const auto k_head_num = k_shape[dims - 2]; - const auto head_dim = q_shape.back(); - int block_num = - (std::max(q_num_elements, k_num_elements) + (THREADS_PER_BLOCK * 4) - 1) / - (THREADS_PER_BLOCK * 4); - auto stream = q.stream(); - - if (q.dtype() == rot_cos.dtype()) { - DispatchApplyRopeVec4Kernel - <<>>( - reinterpret_cast(q.data()), - reinterpret_cast(k.data()), - reinterpret_cast(rot_cos.data()), - reinterpret_cast(rot_sin.data()), - q_num_elements, - k_num_elements, - q_head_num, - k_head_num, - head_dim, - reinterpret_cast(q_out.data()), - reinterpret_cast(k_out.data())); - } else if (rot_cos.dtype() == paddle::DataType::FLOAT32) { - DispatchApplyRopeVec4Kernel - <<>>( - reinterpret_cast(q.data()), - reinterpret_cast(k.data()), - reinterpret_cast(rot_cos.data()), - reinterpret_cast(rot_sin.data()), - q_num_elements, - k_num_elements, - q_head_num, - k_head_num, - head_dim, - reinterpret_cast(q_out.data()), - reinterpret_cast(k_out.data())); - } else { - PD_THROW("Unsupported qk dtype and rope dtype."); - } -} - -std::vector ApplyRope(const paddle::Tensor& q, - const paddle::Tensor& k, - const paddle::Tensor& rot_cos, - const paddle::Tensor& rot_sin) { - auto q_shape = q.shape(); - auto cos_shape = rot_cos.shape(); - - auto q_out = paddle::empty_like(q); - auto k_out = paddle::empty_like(k); - - if (q.numel() == 0 || k.numel() == 0) { - return {q_out, k_out}; - } - - PADDLE_ENFORCE_EQ( - q_shape.back() % 2, - 0, - "The last dimension (head_dim) of qk must be an even number " - "for RoPE, but got %d", - q_shape.back()); - PADDLE_ENFORCE_EQ(q_shape.size(), - cos_shape.size(), - "The shape size of cos mismatches the shape size of q, " - "expect %d but got %d", - q_shape.size(), - cos_shape.size()); - PADDLE_ENFORCE_EQ(q_shape.back(), - cos_shape.back(), - "The shape.back() of cos mismatches the shape.back() of q, " - "expect %d but got %d", - q_shape.back(), - cos_shape.back()); - - auto input_type = q.dtype(); - switch (input_type) { - case paddle::DataType::BFLOAT16: - ApplyRopeKernel( - q, k, rot_cos, rot_sin, q_out, k_out); - break; - case paddle::DataType::FLOAT16: - ApplyRopeKernel( - q, k, rot_cos, rot_sin, q_out, k_out); - break; - default: - PD_THROW("Only support qk dtype of BF16 and F16"); - } - - return {q_out, k_out}; -} - -std::vector> ApplyRopeInferShape( - const std::vector& q_shape, - const std::vector& k_shape, - const std::vector& cos_shape, - const std::vector& sin_shape) { - return {q_shape, k_shape, cos_shape, sin_shape}; -} - -std::vector ApplyRopeInferDtype( - const paddle::DataType& q_dtype, - const paddle::DataType& k_dtype, - const paddle::DataType& cos_dtype, - const paddle::DataType& sin_dtype) { - return {q_dtype, k_dtype, cos_dtype, sin_dtype}; -} - -PD_BUILD_OP(apply_rope) - .Inputs({"q", "k", "rot_cos", "rot_sin"}) - .Outputs({"q_out", "k_out"}) - .SetKernelFn(PD_KERNEL(ApplyRope)) - .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeInferDtype)); diff --git a/custom_ops/metax_ops/apply_rope_qkv.cu b/custom_ops/metax_ops/apply_rope_qkv.cu new file mode 100644 index 000000000..3c7679e07 --- /dev/null +++ b/custom_ops/metax_ops/apply_rope_qkv.cu @@ -0,0 +1,329 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "helper.h" + +template +struct Converter; + +template <> +struct Converter<__half> { + // __half -> float + __device__ static float to_float(__half val) { return __half2float(val); } + // float -> __half + __device__ static __half from_float(float val) { + return __float2half_rn(val); + } + // int -> __half + __device__ static __half from_int(float val) { return __int2half_rn(val); } +}; + +template <> +struct Converter<__nv_bfloat16> { + // __nv_bfloat16 -> float + __device__ static float to_float(__nv_bfloat16 val) { + return __bfloat162float(val); + } + // float -> __nv_bfloat16 + __device__ static __nv_bfloat16 from_float(float val) { + return __float2bfloat16_rn(val); + } + // int -> __nv_bfloat16 + __device__ static __nv_bfloat16 from_int(int val) { + return __int2bfloat16_rn(val); + } +}; + +struct ApplyRopeQKVParams { + int head_dim; + int token_stride; + int head_stride; + int q_stride; + int kv_stride; + int q_head_offset; + int k_head_offset; + int v_head_offset; + int q_head_num; + int kv_head_num; +}; + +template +__device__ __forceinline__ void RotateQKVec4(const T* qkv_ptr, + const T* rot_cos_ptr, + const T* rot_sin_ptr, + const int load_idx, + const int store_idx, + const int rot_base_idx, + T* out) { + using VecT = AlignedVector; + + VecT qk_vec; + Load(qkv_ptr + load_idx, &qk_vec); + VecT rot_half_vec = {-qk_vec[1], qk_vec[0], -qk_vec[3], qk_vec[2]}; + VecT cos_vec, sin_vec; + Load(rot_cos_ptr + rot_base_idx, &cos_vec); + Load(rot_sin_ptr + rot_base_idx, &sin_vec); +#pragma unroll + for (int i = 0; i < 4; ++i) { + *(out + store_idx + i) = + qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i]; + } +} + +template +__device__ __forceinline__ void RotateQKVec4(const T* qkv_ptr, + const float* rot_cos_ptr, + const float* rot_sin_ptr, + const int load_idx, + const int store_idx, + const int rot_base_idx, + T* out) { + using VecT = AlignedVector; + using VecF = AlignedVector; + auto to_float = [] __device__(T val) -> float { + return Converter::to_float(val); + }; + auto from_float = [] __device__(float val) -> T { + return Converter::from_float(val); + }; + + VecT qk_vec; + Load(qkv_ptr + load_idx, &qk_vec); + VecF rot_half_vec = {-to_float(qk_vec[1]), + to_float(qk_vec[0]), + -to_float(qk_vec[3]), + to_float(qk_vec[2])}; + VecF cos_vec, sin_vec; + Load(rot_cos_ptr + rot_base_idx, &cos_vec); + Load(rot_sin_ptr + rot_base_idx, &sin_vec); +#pragma unroll + for (int i = 0; i < 4; ++i) { + *(out + store_idx + i) = from_float(to_float(qk_vec[i]) * cos_vec[i] + + rot_half_vec[i] * sin_vec[i]); + } +} + +template +__device__ __forceinline__ void StoreValue(const T* qkv_ptr, + const int load_idx, + const int store_idx, + T* out) { + using VecT = AlignedVector; + VecT v_vec; + Load(qkv_ptr + load_idx, &v_vec); + Store(v_vec, out + store_idx); +} + +template +__global__ void DispatchApplyRopeQKVVec4Kernel(const T* qkv, + const WeightType* rot_cos, + const WeightType* rot_sin, + ApplyRopeQKVParams param, + T* q_out, + T* k_out, + T* v_out) { + const int token_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int head_idx = blockIdx.y * blockDim.y + threadIdx.y; + const int head_dim_idx = (blockIdx.z * blockDim.z + threadIdx.z) * 4; + int rot_idx = token_idx * param.head_dim + head_dim_idx; + int load_idx, store_idx; + + if (head_idx < param.q_head_num && head_dim_idx < param.head_dim) { // q + load_idx = token_idx * param.token_stride + + (head_idx + param.q_head_offset) * param.head_stride + + head_dim_idx; + store_idx = + token_idx * param.q_stride + head_idx * param.head_dim + head_dim_idx; + RotateQKVec4(qkv, rot_cos, rot_sin, load_idx, store_idx, rot_idx, q_out); + } + + if (head_idx < param.kv_head_num && head_dim_idx < param.head_dim) { // kv + load_idx = token_idx * param.token_stride + + (head_idx + param.k_head_offset) * param.head_stride + + head_dim_idx; + store_idx = + token_idx * param.kv_stride + head_idx * param.head_dim + head_dim_idx; + RotateQKVec4(qkv, rot_cos, rot_sin, load_idx, store_idx, rot_idx, k_out); + load_idx = token_idx * param.token_stride + + (head_idx + param.v_head_offset) * param.head_stride + + head_dim_idx; + StoreValue(qkv, load_idx, store_idx, v_out); + } +} + +template +void ApplyRopeQKVKernel(const paddle::Tensor& qkv, + const paddle::Tensor& rot_cos, + const paddle::Tensor& rot_sin, + const int q_head_num, + const int kv_head_num, + const int head_dim, + paddle::Tensor& q_out, + paddle::Tensor& k_out, + paddle::Tensor& v_out) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + const int all_num_elements = qkv.numel(); + const int all_num_head = q_head_num + 2 * kv_head_num; + auto stream = qkv.stream(); + + dim3 block_dims(1, 4, 32); + dim3 grid_dims(all_num_elements / (all_num_head * head_dim), // token + (std::max(q_head_num, kv_head_num) + block_dims.y - 1) / + block_dims.y, // head + (head_dim + (block_dims.z * 4) - 1) / + (block_dims.z * 4) // dim: load vec4 at a time + ); + + // printf("grid: (%d, %d, %d)\n", grid_dims.x, grid_dims.y, grid_dims.z); + // printf("block: (%d, %d, %d)\n", block_dims.x, block_dims.y, block_dims.z); + + ApplyRopeQKVParams param; + param.head_dim = head_dim; + param.token_stride = all_num_head * head_dim; + param.head_stride = head_dim; + param.q_stride = q_head_num * head_dim; + param.kv_stride = kv_head_num * head_dim; + param.q_head_offset = 0; + param.k_head_offset = q_head_num; + param.v_head_offset = q_head_num + kv_head_num; + param.q_head_num = q_head_num; + param.kv_head_num = kv_head_num; + + if (qkv.dtype() == rot_cos.dtype()) { + DispatchApplyRopeQKVVec4Kernel + <<>>( + reinterpret_cast(qkv.data()), + reinterpret_cast(rot_cos.data()), + reinterpret_cast(rot_sin.data()), + param, + reinterpret_cast(q_out.data()), + reinterpret_cast(k_out.data()), + reinterpret_cast(v_out.data())); + } else if (rot_cos.dtype() == paddle::DataType::FLOAT32) { + DispatchApplyRopeQKVVec4Kernel + <<>>( + reinterpret_cast(qkv.data()), + reinterpret_cast(rot_cos.data()), + reinterpret_cast(rot_sin.data()), + param, + reinterpret_cast(q_out.data()), + reinterpret_cast(k_out.data()), + reinterpret_cast(v_out.data())); + } else { + PD_THROW("Unsupported qk dtype and rope dtype."); + } +} + +std::vector ApplyRopeQKV(const paddle::Tensor& qkv, + const paddle::Tensor& rot_cos, + const paddle::Tensor& rot_sin, + const int q_head_num, + const int kv_head_num, + const int head_dim) { + auto qkv_shape = qkv.shape(); + auto token_num = qkv_shape[0]; + auto place = qkv.place(); + auto dtype = qkv.dtype(); + common::DDim q_out_shape, kv_out_shape; + if (rot_cos.shape().size() == 3) { + q_out_shape = {token_num, q_head_num, head_dim}; + kv_out_shape = {token_num, kv_head_num, head_dim}; + } else { + q_out_shape = {token_num, 1, q_head_num, head_dim}; + kv_out_shape = {token_num, 1, kv_head_num, head_dim}; + } + auto q_out = GetEmptyTensor(q_out_shape, dtype, place); + auto k_out = GetEmptyTensor(kv_out_shape, dtype, place); + auto v_out = GetEmptyTensor(kv_out_shape, dtype, place); + + if (token_num == 0) { + return {q_out, k_out, v_out}; + } + + PADDLE_ENFORCE_EQ(qkv_shape.back(), + ((q_head_num + 2 * kv_head_num) * head_dim), + "The last dimension of qkv [%d] must equal to {(q_head_num " + "+ 2 * kv_head_num) * head_dim [%d].", + qkv_shape.back(), + ((q_head_num + 2 * kv_head_num) * head_dim)); + PADDLE_ENFORCE_EQ( + head_dim % 2, + 0, + "The last dimension (head_dim) of qkv must be an even number " + "for RoPE, but got %d", + head_dim); + PADDLE_ENFORCE_EQ(q_out.shape().back(), + rot_cos.shape().back(), + "The last dimension of cos mismatches that of q, " + "expect %d but got %d", + q_out.shape().back(), + rot_cos.shape().back()); + + switch (dtype) { + case paddle::DataType::BFLOAT16: + ApplyRopeQKVKernel(qkv, + rot_cos, + rot_sin, + q_head_num, + kv_head_num, + head_dim, + q_out, + k_out, + v_out); + break; + case paddle::DataType::FLOAT16: + ApplyRopeQKVKernel(qkv, + rot_cos, + rot_sin, + q_head_num, + kv_head_num, + head_dim, + q_out, + k_out, + v_out); + break; + default: + PD_THROW("Only support qk dtype of BF16 and F16"); + } + + return {q_out, k_out, v_out}; +} + +std::vector> ApplyRopeQKVInferShape( + const std::vector& qkv_shape, + const std::vector& cos_shape, + const std::vector& sin_shape) { + return {qkv_shape, cos_shape, sin_shape}; +} + +std::vector ApplyRopeQKVInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& cos_dtype, + const paddle::DataType& sin_dtype) { + return {qkv_dtype, cos_dtype, sin_dtype}; +} + +PD_BUILD_OP(apply_rope_qkv) + .Inputs({"qkv", "rot_cos", "rot_sin"}) + .Outputs({"q_out", "k_out", "v_out"}) + .Attrs({"q_head_num:int", "kv_head_num:int", "head_dim:int"}) + .SetKernelFn(PD_KERNEL(ApplyRopeQKV)) + .SetInferShapeFn(PD_INFER_SHAPE(ApplyRopeQKVInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ApplyRopeQKVInferDtype)); diff --git a/custom_ops/metax_ops/cache_kv_with_rope.cu b/custom_ops/metax_ops/cache_kv_with_rope.cu new file mode 100644 index 000000000..0f3e9a54e --- /dev/null +++ b/custom_ops/metax_ops/cache_kv_with_rope.cu @@ -0,0 +1,477 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "helper.h" + +template +struct Converter; + +template <> +struct Converter<__half> { + // __half -> float + __device__ static float to_float(__half val) { return __half2float(val); } + // float -> __half + __device__ static __half from_float(float val) { + return __float2half_rn(val); + } + // int -> __half + __device__ static __half from_int(float val) { return __int2half_rn(val); } +}; + +template <> +struct Converter<__nv_bfloat16> { + // __nv_bfloat16 -> float + __device__ static float to_float(__nv_bfloat16 val) { + return __bfloat162float(val); + } + // float -> __nv_bfloat16 + __device__ static __nv_bfloat16 from_float(float val) { + return __float2bfloat16_rn(val); + } + // int -> __nv_bfloat16 + __device__ static __nv_bfloat16 from_int(int val) { + return __int2bfloat16_rn(val); + } +}; + +struct CacheKVWithRopeParams { + int head_dim; + int block_size; + int block_num; + int cache_stride; + int token_stride; + int head_stride; + int q_stride; + int kv_stride; + int q_head_offset; + int k_head_offset; + int v_head_offset; + int q_head_num; + int kv_head_num; +}; + +template +__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr, + const T* rotary_cos_ptr, + const T* rotary_sin_ptr, + const int load_idx, + const int store_idx, + const int cache_store_idx, + const int rot_base_idx, + T* caches, + T* out) { + using VecT = AlignedVector; + + VecT qk_vec; + Load(qkv_ptr + load_idx, &qk_vec); + VecT rot_half_vec; + int flag; +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + flag = 1 - 2 * (i % 2); + rot_half_vec[i] = -qk_vec[i + flag] * Converter::from_int(flag); + } + VecT cos_vec, sin_vec; + Load(rotary_cos_ptr + rot_base_idx, &cos_vec); + Load(rotary_sin_ptr + rot_base_idx, &sin_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + T result = qk_vec[i] * cos_vec[i] + rot_half_vec[i] * sin_vec[i]; + *(out + store_idx + i) = result; + + if (WriteCache) { + *(caches + cache_store_idx + i) = result; + } + } +} + +template +__device__ __forceinline__ void RotateQKVec(const T* qkv_ptr, + const float* rotary_cos_ptr, + const float* rotary_sin_ptr, + const int load_idx, + const int store_idx, + const int cache_store_idx, + const int rot_base_idx, + T* caches, + T* out) { + using VecT = AlignedVector; + using VecF = AlignedVector; + auto to_float = [] __device__(T val) -> float { + return Converter::to_float(val); + }; + auto from_float = [] __device__(float val) -> T { + return Converter::from_float(val); + }; + + VecT qk_vec; + Load(qkv_ptr + load_idx, &qk_vec); + VecF rot_half_vec; + int flag; +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + flag = 1 - 2 * (i % 2); + rot_half_vec[i] = -to_float(qk_vec[i + flag]) * static_cast(flag); + } + VecF cos_vec, sin_vec; + Load(rotary_cos_ptr + rot_base_idx, &cos_vec); + Load(rotary_sin_ptr + rot_base_idx, &sin_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + T result = from_float(to_float(qk_vec[i]) * cos_vec[i] + + rot_half_vec[i] * sin_vec[i]); + *(out + store_idx + i) = result; + if (WriteCache) { + *(caches + cache_store_idx + i) = result; + } + } +} + +template +__device__ __forceinline__ void StoreValue(const T* qkv_ptr, + const int load_idx, + const int store_idx, + const int cache_store_idx, + T* caches, + T* out) { + using VecT = AlignedVector; + VecT v_vec; + Load(qkv_ptr + load_idx, &v_vec); + Store(v_vec, out + store_idx); + Store(v_vec, caches + cache_store_idx); +} + +template +__global__ void DispatchCacheKVWithRopeVecKernel(const T* qkv, + T* caches_k, + T* caches_v, + const int* block_tables, + const WeightType* rotary_cos, + const WeightType* rotary_sin, + const int* cu_seqlens_q, + const int* batch_ids_q, + CacheKVWithRopeParams param, + T* q_out, + T* k_out, + T* v_out) { + const int token_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int head_idx = blockIdx.y * blockDim.y + threadIdx.y; + const int head_dim_idx = (blockIdx.z * blockDim.z + threadIdx.z) * VecSize; + + int load_idx, store_idx, cache_store_idx; + int rot_idx = token_idx * param.head_dim + head_dim_idx; + + const int batch_idx = *(batch_ids_q + token_idx); + const int inter_batch_token_offset = token_idx - *(cu_seqlens_q + batch_idx); + const int inter_batch_block_idx = inter_batch_token_offset / param.block_size; + const int inter_block_offset = inter_batch_token_offset % param.block_size; + const int block_idx = + *(block_tables + batch_idx * param.block_num + inter_batch_block_idx); + + assert(block_idx != -1); + + if (head_dim_idx < param.head_dim) { + if (head_idx < param.q_head_num) { // q + load_idx = token_idx * param.token_stride + + (head_idx + param.q_head_offset) * param.head_stride + + head_dim_idx; + store_idx = + token_idx * param.q_stride + head_idx * param.head_dim + head_dim_idx; + RotateQKVec(qkv, + rotary_cos, + rotary_sin, + load_idx, + store_idx, + -1, + rot_idx, + static_cast(nullptr), + q_out); + } + + if (head_idx < param.kv_head_num) { // kv + load_idx = token_idx * param.token_stride + + (head_idx + param.k_head_offset) * param.head_stride + + head_dim_idx; + store_idx = token_idx * param.kv_stride + head_idx * param.head_dim + + head_dim_idx; + cache_store_idx = block_idx * param.cache_stride + + inter_block_offset * param.kv_stride + + head_idx * param.head_dim + head_dim_idx; + // printf("block_idx: %d inter_block_offset: %d cache_store_idx: %d + // param.cache_stride: %d\n", block_idx, inter_block_offset, + // cache_store_idx, param.cache_stride); + RotateQKVec(qkv, + rotary_cos, + rotary_sin, + load_idx, + store_idx, + cache_store_idx, + rot_idx, + caches_k, + k_out); + + load_idx = token_idx * param.token_stride + + (head_idx + param.v_head_offset) * param.head_stride + + head_dim_idx; + StoreValue( + qkv, load_idx, store_idx, cache_store_idx, caches_v, v_out); + } + } +} + +template +void CacheKVWithRopeKernel( + const paddle::Tensor& qkv, // token_num, head_num * head_dim + paddle::Tensor& + caches_k, // max_block_num, block_size, kv_head_num, head_dim + paddle::Tensor& + caches_v, // max_block_num, block_size, kv_head_num, head_dim + const paddle::Tensor& block_tables, // bs, block_num + const paddle::Tensor& rotary_cos, + const paddle::Tensor& rotary_sin, + const paddle::Tensor& cu_seqlens_q, // bs + 1 + const paddle::Tensor& batch_ids_q, // token_num + const int q_head_num, + const int kv_head_num, + const int head_dim, + const int block_size, + paddle::Tensor& q_out, + paddle::Tensor& k_out, + paddle::Tensor& v_out) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + const int all_num_elements = qkv.numel(); + const int all_num_heads = q_head_num + 2 * kv_head_num; + auto stream = qkv.stream(); + + dim3 block_dims(1, 4, (head_dim + VecSize - 1) / VecSize); + dim3 grid_dims(all_num_elements / (all_num_heads * head_dim), // token + (std::max(q_head_num, kv_head_num) + block_dims.y - 1) / + block_dims.y, // head + (head_dim + (block_dims.z * VecSize) - 1) / + (block_dims.z * VecSize) // dim: load Vec at a time + ); + + // printf("grid: (%d, %d, %d)\n", grid_dims.x, grid_dims.y, grid_dims.z); + // printf("block: (%d, %d, %d)\n", block_dims.x, block_dims.y, block_dims.z); + + CacheKVWithRopeParams param; + param.head_dim = head_dim; + param.block_size = block_size; + param.block_num = static_cast(block_tables.shape().back()); + param.cache_stride = block_size * kv_head_num * head_dim; + param.token_stride = all_num_heads * head_dim; + param.head_stride = head_dim; + param.q_stride = q_head_num * head_dim; + param.kv_stride = kv_head_num * head_dim; + param.q_head_offset = 0; + param.k_head_offset = q_head_num; + param.v_head_offset = q_head_num + kv_head_num; + param.q_head_num = q_head_num; + param.kv_head_num = kv_head_num; + + if (qkv.dtype() == rotary_cos.dtype()) { + DispatchCacheKVWithRopeVecKernel + <<>>( + reinterpret_cast(qkv.data()), + reinterpret_cast(caches_k.data()), + reinterpret_cast(caches_v.data()), + reinterpret_cast(block_tables.data()), + reinterpret_cast(rotary_cos.data()), + reinterpret_cast(rotary_sin.data()), + reinterpret_cast(cu_seqlens_q.data()), + reinterpret_cast(batch_ids_q.data()), + param, + reinterpret_cast(q_out.data()), + reinterpret_cast(k_out.data()), + reinterpret_cast(v_out.data())); + } else if (rotary_cos.dtype() == paddle::DataType::FLOAT32) { + DispatchCacheKVWithRopeVecKernel + <<>>( + reinterpret_cast(qkv.data()), + reinterpret_cast(caches_k.data()), + reinterpret_cast(caches_v.data()), + reinterpret_cast(block_tables.data()), + reinterpret_cast(rotary_cos.data()), + reinterpret_cast(rotary_sin.data()), + reinterpret_cast(cu_seqlens_q.data()), + reinterpret_cast(batch_ids_q.data()), + param, + reinterpret_cast(q_out.data()), + reinterpret_cast(k_out.data()), + reinterpret_cast(v_out.data())); + } else { + PD_THROW("Unsupported qk dtype and rope dtype."); + } + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + } +} + +std::vector CacheKVWithRope( + const paddle::Tensor& qkv, // token_num, head_num * head_dim + paddle::Tensor& + caches_k, // max_block_num, block_size, kv_head_num, head_dim + paddle::Tensor& + caches_v, // max_block_num, block_size, kv_head_num, head_dim + const paddle::Tensor& block_tables, // bs, block_num + const paddle::Tensor& rotary_cos, + const paddle::Tensor& rotary_sin, + const paddle::Tensor& cu_seqlens_q, // bs + 1 + const paddle::Tensor& batch_ids_q, // token_num + const int q_head_num, + const int kv_head_num, + const int head_dim, + const int block_size) { + auto qkv_shape = qkv.shape(); + auto token_num = qkv_shape[0]; + auto place = qkv.place(); + auto dtype = qkv.dtype(); + common::DDim q_out_shape, kv_out_shape; + if (rotary_cos.shape().size() == 3) { + q_out_shape = {token_num, q_head_num, head_dim}; + kv_out_shape = {token_num, kv_head_num, head_dim}; + } else { + q_out_shape = {token_num, 1, q_head_num, head_dim}; + kv_out_shape = {token_num, 1, kv_head_num, head_dim}; + } + auto q_out = GetEmptyTensor(q_out_shape, dtype, place); + auto k_out = GetEmptyTensor(kv_out_shape, dtype, place); + auto v_out = GetEmptyTensor(kv_out_shape, dtype, place); + + if (token_num == 0) { + return {q_out, k_out, v_out}; + } + + PADDLE_ENFORCE_EQ(qkv_shape.back(), + ((q_head_num + 2 * kv_head_num) * head_dim), + "The last dimension of qkv [%d] must equal to {(q_head_num " + "+ 2 * kv_head_num) * head_dim [%d].", + qkv_shape.back(), + ((q_head_num + 2 * kv_head_num) * head_dim)); + PADDLE_ENFORCE_EQ( + head_dim % 2, + 0, + "The last dimension (head_dim) of qkv must be an even number " + "for RoPE, but got %d", + head_dim); + PADDLE_ENFORCE_EQ(q_out.shape().back(), + rotary_cos.shape().back(), + "The last dimension of cos mismatches that of q, " + "expect %d but got %d", + q_out.shape().back(), + rotary_cos.shape().back()); + + switch (dtype) { + case paddle::DataType::BFLOAT16: + CacheKVWithRopeKernel(qkv, + caches_k, + caches_v, + block_tables, + rotary_cos, + rotary_sin, + cu_seqlens_q, + batch_ids_q, + q_head_num, + kv_head_num, + head_dim, + block_size, + q_out, + k_out, + v_out); + break; + case paddle::DataType::FLOAT16: + CacheKVWithRopeKernel(qkv, + caches_k, + caches_v, + block_tables, + rotary_cos, + rotary_sin, + cu_seqlens_q, + batch_ids_q, + q_head_num, + kv_head_num, + head_dim, + block_size, + q_out, + k_out, + v_out); + break; + default: + PD_THROW("Only support qk dtype of BF16 and F16"); + } + + return {q_out, k_out, v_out}; +} + +std::vector> CacheKVWithRopeInferShape( + const std::vector& qkv_shape, + const std::vector& caches_k_shape, + const std::vector& caches_v_shape, + const std::vector& block_tables_shape, + const std::vector& cos_shape, + const std::vector& sin_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& batch_ids_q_shape) { + return {qkv_shape, + caches_k_shape, + caches_v_shape, + block_tables_shape, + cos_shape, + sin_shape, + cu_seqlens_q_shape, + batch_ids_q_shape}; +} + +std::vector CacheKVWithRopeInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& caches_k_dtype, + const paddle::DataType& caches_v_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& cos_dtype, + const paddle::DataType& sin_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& batch_ids_q_dtype) { + return {qkv_dtype, + caches_k_dtype, + caches_v_dtype, + block_tables_dtype, + cos_dtype, + sin_dtype, + cu_seqlens_q_dtype, + batch_ids_q_dtype}; +} + +PD_BUILD_OP(cache_kv_with_rope) + .Inputs({"qkv", + "caches_k", + "caches_v", + "block_tables", + "rotary_cos", + "rotary_sin", + "cu_seqlen_q", + "batch_ids_q"}) + .Outputs({"q_out", "k_out", "v_out"}) + .Attrs( + {"q_head_num:int", "kv_head_num:int", "head_dim:int", "block_size:int"}) + .SetKernelFn(PD_KERNEL(CacheKVWithRope)) + .SetInferShapeFn(PD_INFER_SHAPE(CacheKVWithRopeInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(CacheKVWithRopeInferDtype)); diff --git a/custom_ops/metax_ops/fused_moe.cu b/custom_ops/metax_ops/fused_moe.cu index c1cdf14e7..30a134e0c 100644 --- a/custom_ops/metax_ops/fused_moe.cu +++ b/custom_ops/metax_ops/fused_moe.cu @@ -14,9 +14,10 @@ #pragma once -#include "fused_moe_op.h" +#include "fused_moe_helper.h" #include "helper.h" -#include "mc_fused_moe_helper.h" + +namespace phi { __global__ void compute_total_rows_before_expert_kernel( int* sorted_experts, @@ -42,58 +43,61 @@ void compute_total_rows_before_expert(int* sorted_indices, sorted_indices, total_indices, num_experts, total_rows_before_expert); } -template +} // namespace phi + +template void FusedMoeKernel(const paddle::Tensor& input, const paddle::Tensor& gate_weight, - const paddle::Tensor& ffn1_weight, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn1_bias, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn2_scale, - const paddle::optional& ffn2_bias, + const paddle::Tensor& up_gate_proj_weight, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& up_gate_proj_bias, + const paddle::Tensor& down_proj_weight, + const paddle::optional& down_proj_scale, + const paddle::optional& down_proj_bias, const std::string& quant_method, const int moe_topk, const bool group_moe, const bool norm_topk_prob, paddle::Tensor* output) { + using namespace phi; typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; auto* output_data = output->data(); - auto moe_compute = - McMoeHelper(quant_method); + auto int8_moe_gemm_runner = McMoeGemmRunner(); - moe_compute.computeFFN(&input, - &gate_weight, - &ffn1_weight, - ffn1_scale ? ffn1_scale.get_ptr() : nullptr, - ffn1_bias ? ffn1_bias.get_ptr() : nullptr, - &ffn2_weight, - ffn2_scale ? ffn2_scale.get_ptr() : nullptr, - ffn2_bias ? ffn2_bias.get_ptr() : nullptr, - nullptr, - moe_topk, - group_moe, - norm_topk_prob, - 1.0, // ComputeFFN - "ffn", - output); + auto moe_compute = + McMoeHelper(quant_method, &int8_moe_gemm_runner); + + moe_compute.computeFFN( + &input, + &gate_weight, + &up_gate_proj_weight, + up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr, + up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr, + &down_proj_weight, + down_proj_scale ? down_proj_scale.get_ptr() : nullptr, + down_proj_bias ? down_proj_bias.get_ptr() : nullptr, + nullptr, + moe_topk, + group_moe, + norm_topk_prob, + 1.0, // ComputeFFN + "ffn", + output); } std::vector FusedExpertMoe( const paddle::Tensor& input, const paddle::Tensor& gate_weight, - const paddle::Tensor& ffn1_weight, - const paddle::Tensor& ffn2_weight, - const paddle::optional& ffn1_bias, - const paddle::optional& ffn1_scale, - const paddle::optional& ffn2_bias, - const paddle::optional& ffn2_scale, + const paddle::Tensor& up_gate_proj_weight, + const paddle::Tensor& down_proj_weight, + const paddle::optional& up_gate_proj_bias, + const paddle::optional& up_gate_proj_scale, + const paddle::optional& down_proj_bias, + const paddle::optional& down_proj_scale, const std::string& quant_method, const int moe_topk, const bool norm_topk_prob, @@ -107,40 +111,22 @@ std::vector FusedExpertMoe( switch (input_type) { case paddle::DataType::BFLOAT16: - FusedMoeKernel(input, - gate_weight, - ffn1_weight, - ffn1_scale, - ffn1_bias, - ffn2_weight, - ffn2_scale, - ffn2_bias, - quant_method, - moe_topk, - group_moe, - norm_topk_prob, - &output); + FusedMoeKernel(input, + gate_weight, + up_gate_proj_weight, + up_gate_proj_scale, + up_gate_proj_bias, + down_proj_weight, + down_proj_scale, + down_proj_bias, + quant_method, + moe_topk, + group_moe, + norm_topk_prob, + &output); break; - // case paddle::DataType::FLOAT16: - // FusedMoeKernel(input, - // gate_weight, - // ffn1_weight, - // ffn1_scale, - // ffn1_bias, - // ffn2_weight, - // ffn2_scale, - // ffn2_bias, - // quant_method, - // moe_topk, - // group_moe, - // norm_topk_prob, - // &output); - // break; default: - PD_THROW("Only support bf16 for FusedMoeKernel"); + PD_THROW("Unsupported data type for FusedMoeKernel"); } return {output}; } @@ -148,36 +134,36 @@ std::vector FusedExpertMoe( std::vector> FusedExpertMoeInferShape( const std::vector& input_shape, const std::vector& gate_weight_shape, - const std::vector& ffn1_weight_shape, - const std::vector& ffn2_weight_shape, - const paddle::optional>& ffn1_bias_shape, - const paddle::optional>& ffn1_scale_shape, - const paddle::optional>& ffn2_bias_shape, - const paddle::optional>& ffn2_scale_shape) { + const std::vector& up_gate_proj_weight_shape, + const std::vector& down_proj_weight_shape, + const paddle::optional>& up_gate_proj_bias_shape, + const paddle::optional>& up_gate_proj_scale_shape, + const paddle::optional>& down_proj_bias_shape, + const paddle::optional>& down_proj_scale_shape) { return {input_shape}; } std::vector FusedExpertMoeInferDtype( const paddle::DataType& input_dtype, const paddle::DataType& gate_weight_dtype, - const paddle::DataType& ffn1_weight_dtype, - const paddle::DataType& ffn2_weight_dtype, - const paddle::optional& ffn1_bias_dtype, - const paddle::optional& ffn1_scale_dtype, - const paddle::optional& ffn2_bias_dtype, - const paddle::optional& ffn2_scale_dtype) { + const paddle::DataType& up_gate_proj_weight_dtype, + const paddle::DataType& down_proj_weight_dtype, + const paddle::optional& up_gate_proj_bias_dtype, + const paddle::optional& up_gate_proj_scale_dtype, + const paddle::optional& down_proj_bias_dtype, + const paddle::optional& down_proj_scale_dtype) { return {input_dtype}; } -PD_BUILD_OP(fused_expert_moe) +PD_BUILD_STATIC_OP(fused_expert_moe) .Inputs({"input", "gate_weight", - "ffn1_weight", - "ffn2_weight", - paddle::Optional("ffn1_bias"), - paddle::Optional("ffn1_scale"), - paddle::Optional("ffn2_bias"), - paddle::Optional("ffn2_scale")}) + "up_gate_proj_weight", + "down_proj_weight", + paddle::Optional("up_gate_proj_bias"), + paddle::Optional("up_gate_proj_scale"), + paddle::Optional("down_proj_bias"), + paddle::Optional("down_proj_scale")}) .Outputs({"output"}) .Attrs({"quant_method:std::string", "moe_topk:int", diff --git a/custom_ops/metax_ops/fused_moe_gemm_kernels.h b/custom_ops/metax_ops/fused_moe_gemm_kernels.h new file mode 100644 index 000000000..177f81307 --- /dev/null +++ b/custom_ops/metax_ops/fused_moe_gemm_kernels.h @@ -0,0 +1,199 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mctlass/numeric_conversion.h" +#include "mctlassEx/mctlassEx.h" + +namespace phi { + +template +struct mctlassExDataTraits; + +template <> +struct mctlassExDataTraits { + static constexpr mctlassExDataType type = + mctlassExDataType::MCTLASS_EX_DATATYPE_BF16; +}; + +template <> +struct mctlassExDataTraits { + static constexpr mctlassExDataType type = + mctlassExDataType::MCTLASS_EX_DATATYPE_INT8; +}; + +template +class McMoeGemmRunner { + public: + McMoeGemmRunner() {} + + void mc_grouped_gemm_basic_kernel(const T* ptrA, + mctlassExOrder_t majorA, + const WeightType* ptrB, + mctlassExOrder_t majorB, + const T* ptrScale, + const T* ptrBias, + T* ptrC, + mctlassExOrder_t majorC, + const int* ptrSegInd, + int* ptrMNumTilesInd, + int numExperts, + int m, // expanded_active_expert_rows + int n, // inter_dim + int k, // hidden_size + mcStream_t stream) { + mctlassExHandle_t handle; + mctlassExHandleCreate(&handle); + + mctlassExDataType DataType_ = mctlassExDataTraits::type; + mctlassExDataType WeightType_ = mctlassExDataTraits::type; + + mctlassExMatrixLayout_t matLayoutA; + mctlassExMatrixLayout_t matLayoutB; + mctlassExMatrixLayout_t matLayoutC; + + // mat A: (m, k) + mctlassExMatrixLayoutCreate(&matLayoutA, DataType_, m, k, k); + mctlassExMatrixLayoutSetAttribute( + matLayoutA, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, + &majorA, + sizeof(mctlassExOrder_t)); + mctlassExMatrixLayoutSetAttribute( + matLayoutA, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, + &numExperts, + sizeof(int)); + // mat B: (num_experts, n, k) + mctlassExMatrixLayoutCreate(&matLayoutB, WeightType_, k, n, k); + mctlassExMatrixLayoutSetAttribute( + matLayoutB, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, + &majorB, + sizeof(mctlassExOrder_t)); + mctlassExMatrixLayoutSetAttribute( + matLayoutB, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, + &numExperts, + sizeof(int)); + // mat C: (m, n) + mctlassExMatrixLayoutCreate(&matLayoutC, DataType_, m, n, n); + mctlassExMatrixLayoutSetAttribute( + matLayoutC, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_ORDER, + &majorC, + sizeof(mctlassExOrder_t)); + mctlassExMatrixLayoutSetAttribute( + matLayoutC, + mctlassExMatrixLayoutAttribute_t::MCTLASS_EX_MATRIX_LAYOUT_BATCH_COUNT, + &numExperts, + sizeof(int)); + // bias: (num_experts, n) + // scale: (num, n) + + mctlassExDesc_t mctlass_desc; + mctlassExCreateDesc(&mctlass_desc); + mctlassExDataType input_type = DataType_; + mctlassExDataType scale_type = WeightType_; + mctlassExDataType compute_type = + mctlassExDataType::MCTLASS_EX_DATATYPE_FP32; + mctlassExEpilogueType epilogue_type = + mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_DEFAULT; + if (ptrBias) { + epilogue_type = mctlassExEpilogueType::MCTLASS_EX_EPILOGUE_TYPE_BIAS; + } + // set scale + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_POINTER, + &ptrScale, + sizeof(ptrScale)); + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_B_SCALE_TYPE, + &input_type, + sizeof(mctlassExDataType)); + // set bias + if (ptrBias) { + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_BIAS_POINTER, + &ptrBias, + sizeof(ptrBias)); + } + // set coumpute type + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_COMPUTE_TYPE, + &compute_type, + sizeof(mctlassExDataType)); + // set epilogue type + mctlassExDescSetAttribute( + mctlass_desc, + mctlassExDescAttributes_t::MCTLASS_EX_DESC_EPILOGUE_TYPE, + &epilogue_type, + sizeof(mctlassExEpilogueType)); + + const mctlassExContiguousGroupedGemmAlgo_t algo = + mctlassExContiguousGroupedGemmAlgo_t:: + MCTLASS_EX_CONTIGUOUS_GROUPED_ALGO_DEFAULT; + mctlassExContiguousGroupedDesc_t contiguous_group_desc; + mctlassExContiguousGroupedDescCreate( + &contiguous_group_desc, ptrSegInd, nullptr, ptrMNumTilesInd, 1); + int blocksizeM; + mctlassExContiguousGroupedGemmGetBlocksizeM(handle, + mctlass_desc, + matLayoutA, + matLayoutB, + matLayoutC, + &algo, + &blocksizeM); + mctlassExContiguousGroupedGemmComputeMNumTilesIndptr(handle, + mctlass_desc, + matLayoutA, + matLayoutB, + matLayoutC, + &algo, + contiguous_group_desc, + numExperts, + blocksizeM, + stream); + + mctlassExContiguousGroupedGemmBasic(handle, + mctlass_desc, + ptrA, + matLayoutA, + ptrB, + matLayoutB, + ptrC, + matLayoutC, + contiguous_group_desc, + &algo, + nullptr, + 0, + stream); + + mctlassExHandleDestroy(handle); + mctlassExMatrixLayoutDestroy(matLayoutA); + mctlassExMatrixLayoutDestroy(matLayoutB); + mctlassExMatrixLayoutDestroy(matLayoutC); + mctlassExContiguousGroupedDescDestroy(contiguous_group_desc); + mctlassExDestroyDesc(mctlass_desc); + } +}; + +template class McMoeGemmRunner; + +} // namespace phi diff --git a/custom_ops/metax_ops/fused_moe_helper.h b/custom_ops/metax_ops/fused_moe_helper.h index 67c616ce4..9d7842ce2 100644 --- a/custom_ops/metax_ops/fused_moe_helper.h +++ b/custom_ops/metax_ops/fused_moe_helper.h @@ -14,14 +14,17 @@ #pragma once -#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h" +#include "fused_moe_gemm_kernels.h" +#include "fused_moe_imp_op.h" #include "fused_moe_op.h" +#include "mctlass/numeric_conversion.h" +#include "mctlassEx/mctlassEx.h" -using namespace phi; +namespace phi { template -__global__ void moe_token_type_ids_kernel(T *gating_output, - const int *moe_token_type_ids_out, +__global__ void moe_token_type_ids_kernel(T* gating_output, + const int* moe_token_type_ids_out, const int num_rows, const int num_experts, const int k) { @@ -40,8 +43,8 @@ __global__ void moe_token_type_ids_kernel(T *gating_output, } template -void moe_token_type_ids_kernelLauncher(T *gating_output, - const int *moe_token_type_ids_out, +void moe_token_type_ids_kernelLauncher(T* gating_output, + const int* moe_token_type_ids_out, const int num_rows, const int num_experts, const int k, @@ -51,3 +54,338 @@ void moe_token_type_ids_kernelLauncher(T *gating_output, moe_token_type_ids_kernel<<>>( gating_output, moe_token_type_ids_out, num_rows, num_experts, k); } + +template +class McMoeHelper { + public: + McMoeHelper(const std::string gemm_method, + McMoeGemmRunner* int8_moe_gemm_runner) + : gemm_method_(gemm_method), + int8_moe_gemm_runner_(int8_moe_gemm_runner) {} + + // -------- getWorkspaceSize -------- // + template + size_t getWorkspaceSize(const int64_t num_rows, + const int64_t hidden_size, + const int64_t inter_size, + const int64_t num_experts, + const int64_t k) { + const size_t buf_size = AlignTo16(k * num_rows * hidden_size); + const size_t interbuf_size = AlignTo16(k * num_rows * inter_size); + const size_t padded_experts = AlignTo16(num_experts); + const size_t num_moe_inputs = AlignTo16(k * num_rows); + // softmax output, permuted_rows and permuted_experts have moved to outside + // of moe kernel, allocate them in Encoder or Decoder before invoking + // FfnLayer forward. + size_t total_ws_bytes = + 5 * num_moe_inputs * + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + total_ws_bytes += + padded_experts * sizeof(int32_t); // Hold total_rows_before_expert_ + + const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT); + const size_t sorter_ws_size_bytes = + AlignTo16(sorter_.getWorkspaceSize(num_rows)); + sorter_.update_num_experts(num_experts); + + int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; + if (sorter_ws_size_bytes > bytes_for_fc1_result) { + int64_t remaining_bytes = + AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result); + bytes_for_intermediate_and_sorting += remaining_bytes; + } + + total_ws_bytes += + bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub + // sorting workspace + + int64_t num_softmax_outs = 0; + const bool is_pow_2 = + (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + num_softmax_outs = AlignTo16(num_rows * num_experts); + } + + total_ws_bytes += num_softmax_outs * sizeof(float); + + return total_ws_bytes; + } + + void computeFFN(const paddle::Tensor* input, + const paddle::Tensor* gate_weight, + const paddle::Tensor* up_gate_proj_weight, + const paddle::Tensor* up_gate_proj_scale, + const paddle::Tensor* up_gate_proj_bias, + const paddle::Tensor* down_proj_weight, + const paddle::Tensor* down_proj_scale, + const paddle::Tensor* down_proj_bias, + const paddle::Tensor* moe_token_type_ids, + const int moe_topk, + const bool group_moe, + const bool norm_topk_prob, + const float routed_scaling_factor, + const std::string moe_type, + paddle::Tensor* output) { + auto* input_activations = input->data(); + auto* gating_weights = gate_weight->data(); + const T* fc1_expert_biases = + up_gate_proj_bias ? up_gate_proj_bias->data() : nullptr; + const T* fc2_expert_biases = + down_proj_bias ? down_proj_bias->data() : nullptr; + + auto* output_ = output->data(); + auto stream = input->stream(); + auto place = input->place(); + auto input_type = input->dtype(); + + auto input_dims = input->dims(); + auto up_gate_proj_dims = up_gate_proj_weight->dims(); + int64_t token_num = 0; + if (input_dims.size() == 3) { + token_num = input_dims[0] * input_dims[1]; + } else { + token_num = input_dims[0]; + } + const int64_t num_rows = token_num; + + const int64_t hidden_size = up_gate_proj_dims[2]; + int64_t inter_dim = 0; + if (moe_type == "qkv") { + inter_dim = + up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4]; + } else { + inter_dim = up_gate_proj_dims[1]; + } + + // if (gemm_method_ == "weight_only_int4") { + // inter_dim = inter_dim * 2; + // } + + const int64_t inter_size = inter_dim; + const int64_t num_experts = up_gate_proj_dims[0]; + const int64_t k = moe_topk; + + int64_t bytes = + getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts, k); + + // Pointers + int* expert_for_source_row; + int* source_rows_; + int* permuted_rows_; + int* permuted_experts_; + int* expanded_source_row_to_expanded_dest_row; + + T* permuted_data_; + int32_t* total_rows_before_expert_; + T* fc1_result_; + float* softmax_out_; + + paddle::Tensor ws_ptr_tensor = + GetEmptyTensor({bytes}, paddle::DataType::INT8, place); + int8_t* ws_ptr = ws_ptr_tensor.data(); + + const int64_t buf_size = AlignTo16(k * num_rows * hidden_size); + const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size); + const int64_t padded_experts = AlignTo16(num_experts); + const int64_t num_moe_inputs = AlignTo16(k * num_rows); + + expert_for_source_row = reinterpret_cast(ws_ptr); + source_rows_ = expert_for_source_row + num_moe_inputs; + permuted_rows_ = source_rows_ + num_moe_inputs; + permuted_experts_ = permuted_rows_ + num_moe_inputs; + expanded_source_row_to_expanded_dest_row = + permuted_experts_ + num_moe_inputs; + permuted_data_ = reinterpret_cast( + expanded_source_row_to_expanded_dest_row + num_moe_inputs); + total_rows_before_expert_ = + reinterpret_cast(permuted_data_ + buf_size); + fc1_result_ = + reinterpret_cast(total_rows_before_expert_ + padded_experts); + + const bool is_pow_2 = + (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if (!is_pow_2 || num_experts > 256) { + softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + } else { + softmax_out_ = nullptr; + } + + paddle::Tensor expert_scales_float_tensor = + GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + float* expert_scales_float = expert_scales_float_tensor.data(); + + float* softmax_max_prob = nullptr; + if (group_moe) { + paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor( + {num_rows, moe_topk}, paddle::DataType::FLOAT32, place); + // (TODO: check fill success ?) + paddle::experimental::fill(softmax_max_prob_tensor, 0.f); + softmax_max_prob = softmax_max_prob_tensor.data(); + } + + paddle::Tensor fc1_out_tensor = + GetEmptyTensor({num_rows * k, inter_size}, input_type, place); + T* fc1_out = fc1_out_tensor.data(); + + auto input_cast_tensor = + paddle::experimental::cast(*input, paddle::DataType::FLOAT32); + auto gate_tensor = + paddle::experimental::matmul(input_cast_tensor, *gate_weight); + float* gating_output = gate_tensor.data(); + + if (moe_token_type_ids) { + auto* moe_token_type_ids_out = moe_token_type_ids->data(); + moe_token_type_ids_kernelLauncher(gating_output, + moe_token_type_ids_out, + num_rows, + num_experts, + k, + stream); + } + + topk_gating_softmax_kernelLauncher(gating_output, + nullptr, + expert_scales_float, + softmax_out_, + expert_for_source_row, + source_rows_, + softmax_max_prob, + num_rows, + num_experts, + k, + group_moe, + stream); + + const int64_t sorter_ws_size_bytes = + AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows))); + + sorter_.run(fc1_result_, + sorter_ws_size_bytes, + expert_for_source_row, + permuted_experts_, + source_rows_, + permuted_rows_, + k * num_rows, + false, + stream); + + initialize_moe_routing_kernelLauncher( + input_activations, + permuted_data_, + permuted_rows_, + nullptr, + nullptr, + expanded_source_row_to_expanded_dest_row, + num_rows, + num_rows, + hidden_size, + k, + stream); + + const int64_t expanded_active_expert_rows = k * num_rows; + + compute_total_rows_before_expert(permuted_experts_, + expanded_active_expert_rows, + num_experts, + total_rows_before_expert_, + stream); + + mctlassExOrder_t row_major = mctlassExOrder_t::MCTLASS_EX_ORDER_ROW_MAJOR; + mctlassExOrder_t column_major = + mctlassExOrder_t::MCTLASS_EX_ORDER_COLUMN_MAJOR; + auto m_num_tile = + GetEmptyTensor({num_experts}, paddle::DataType::INT32, place); + int* m_num_tile_ptr = reinterpret_cast(m_num_tile.data()); + + if (gemm_method_ == "weight_only_int8") { + int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel( + reinterpret_cast(permuted_data_), + row_major, + reinterpret_cast(up_gate_proj_weight->data()), + column_major, + reinterpret_cast(up_gate_proj_scale->data()), + reinterpret_cast(fc1_expert_biases), + reinterpret_cast(fc1_out), + row_major, + total_rows_before_expert_, + m_num_tile_ptr, + num_experts, + expanded_active_expert_rows, + inter_size, + hidden_size, + stream); + } else { + throw std::runtime_error("Unsupported gemm method: " + gemm_method_); + } + + if (moe_type == "ffn") { + auto act_out_tensor = + paddle::experimental::swiglu(fc1_out_tensor, nullptr); + auto act_out = act_out_tensor.data(); + + paddle::Tensor fc2_output_tensor = + GetEmptyTensor({k * num_rows, hidden_size}, input_type, place); + T* fc2_result = fc2_output_tensor.data(); + + if (gemm_method_ == "weight_only_int8") { + int8_moe_gemm_runner_->mc_grouped_gemm_basic_kernel( + reinterpret_cast(act_out), + row_major, + reinterpret_cast(down_proj_weight->data()), + column_major, + reinterpret_cast(down_proj_scale->data()), + nullptr, + reinterpret_cast(fc2_result), + row_major, + total_rows_before_expert_, + m_num_tile_ptr, + num_experts, + expanded_active_expert_rows, + hidden_size, + inter_size / 2, + stream); + } else { + throw std::runtime_error("Unsupported gemm method: " + gemm_method_); + } + + finalize_moe_routing_kernelLauncher( + fc2_result, + output_, + fc2_expert_biases, + reinterpret_cast(expert_scales_float), + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + num_rows, + hidden_size, + k, + static_cast(1), + norm_topk_prob, + routed_scaling_factor, + stream); + } else { + finalize_moe_routing_kernelLauncher( + // fc2_result, + fc1_out, + output_, + fc1_expert_biases, // fc2_expert_biases, + reinterpret_cast(expert_scales_float), + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + num_rows, + inter_size, + k, + static_cast(0), + norm_topk_prob, + routed_scaling_factor, + stream); + } + } + + private: + McMoeGemmRunner* int8_moe_gemm_runner_; + std::string gemm_method_; + CubKeyValueSorter sorter_; +}; + +} // namespace phi diff --git a/custom_ops/metax_ops/fused_moe_imp_op.h b/custom_ops/metax_ops/fused_moe_imp_op.h index 99aabaf8a..3108df789 100644 --- a/custom_ops/metax_ops/fused_moe_imp_op.h +++ b/custom_ops/metax_ops/fused_moe_imp_op.h @@ -20,6 +20,8 @@ #include #include "cub/cub.cuh" +namespace phi { + static const float HALF_FLT_MAX = 65504.F; static const float HALF_FLT_MIN = -65504.F; static inline size_t AlignTo16(const size_t& input) { @@ -121,3 +123,5 @@ class CubKeyValueSorter { int num_experts_; int num_bits_; }; + +} // namespace phi diff --git a/custom_ops/metax_ops/fused_moe_op.h b/custom_ops/metax_ops/fused_moe_op.h index 00ed38115..1a7d32cdb 100644 --- a/custom_ops/metax_ops/fused_moe_op.h +++ b/custom_ops/metax_ops/fused_moe_op.h @@ -1,28 +1,27 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & - * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// /* +// * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & +// * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ #pragma once #include #include -#include "fused_moe_helper.h" -#include "fused_moe_imp_op.h" -#include "mctlass/numeric_conversion.h" // BUILD_MARK -// Ignore mctlass warnings about type punning +#include "mctlass/functional.h" +#include "mctlass/numeric_conversion.h" +// Ignore CUTLASS warnings about type punning #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" #pragma GCC diagnostic ignored "-Wunused-function" @@ -34,6 +33,8 @@ #define WARP_SIZE 32 +namespace phi { + struct GpuLaunchConfig { dim3 block_per_grid; dim3 thread_per_block; @@ -55,6 +56,324 @@ inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) { return config; } +constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; +template +__host__ __device__ constexpr static U arrayConvert(T const& input) { + using Type = typename U::Element; + static_assert(T::kElements == U::kElements); + U u; +#pragma unroll + for (int i = 0; i < U::kElements; i++) { + u[i] = static_cast(input[i]); + } + return u; +} + +struct uint8 { + uint4 u; + uint4 v; +}; + +template +struct BytesToType {}; + +template <> +struct BytesToType<32> { + using Type = uint8; + static_assert(sizeof(Type) == 32); +}; + +template <> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template <> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template <> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template <> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template <> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +template