/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <type_traits>

#include <cuda_fp16.h>

#include <cub/block/block_reduce.cuh>
#include <cub/block/block_scan.cuh>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp>  // For cute::elect_one_sync()

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

using namespace cute;

// Maps a 16-bit scalar type to its packed two-element CUDA vector type.
template <typename T>
struct PackedHalf;

template <>
struct PackedHalf<__half> {
    using Type = __half2;
};

template <>
struct PackedHalf<nv_bfloat16> {
    using Type = nv_bfloat162;
};

template <typename T>
__forceinline__ __device__ auto float_2_half2(const float x) {
    if constexpr (std::is_same<T, __half>::value) {
        return __float2half2_rn(x);
    } else {
        return __float2bfloat162_rn(x);
    }
}

struct uint16 {
    uint4 u;
    uint4 v;
    uint4 s;
    uint4 t;
};

struct uint8 {
    uint4 u;
    uint4 v;
};

template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<64> {
    using Type = uint16;
    static_assert(sizeof(Type) == 64);
};

template <>
struct BytesToType<32> {
    using Type = uint8;
    static_assert(sizeof(Type) == 32);
};

template <>
struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

// Fixed-size register vector of NUM_ELT elements, aliased onto a single
// vectorized type so loads and stores go through one wide memory access.
template <typename Elt_type, int NUM_ELT>
struct Vec {

    enum { BYTES = NUM_ELT * sizeof(Elt_type) };

    using Vec_type = typename BytesToType<BYTES>::Type;

    using Alias_type = union {
        Vec_type vec;
        Elt_type elt[NUM_ELT];
    };

    Alias_type data;

    inline __device__ Vec() {}

    template <typename S>
    inline __device__ void to(Vec<S, NUM_ELT> &other) {
        #pragma unroll
        for (int it = 0; it < NUM_ELT; it++) {
            other.data.elt[it] = S(this->data.elt[it]);
        }
    }

    template <typename Op>
    inline __device__ void assign(const Op &op) {
        #pragma unroll
        for (int it = 0; it < NUM_ELT; it++) {
            this->data.elt[it] = op(it);
        }
    }

    inline __device__ void load_from(const void *base_ptr) {
        this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
    }

    inline __device__ void store_to(void *base_ptr) {
        *reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
    }

    // Element-wise add using packed half2 / bfloat162 arithmetic.
    inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
        static_assert(NUM_ELT % 2 == 0);
        using type = typename PackedHalf<Elt_type>::Type;
        #pragma unroll
        for (int it = 0; it < NUM_ELT / 2; it++) {
            type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
            *reinterpret_cast<type *>(this->data.elt + it * 2) += b;
        }
    }

    // Fused multiply-add: this += scale * bias, two elements at a time.
    inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
        static_assert(NUM_ELT % 2 == 0);
        using type = typename PackedHalf<Elt_type>::Type;
        #pragma unroll
        for (int it = 0; it < NUM_ELT / 2; it++) {
            type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
            type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
            *reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
        }
    }

    inline __device__ void set_zero() {
        constexpr int size = sizeof(Vec_type) / sizeof(int);
        #pragma unroll
        for (int i = 0; i < size; ++i) {
            (reinterpret_cast<int *>(this->data.elt))[i] = 0;
        }
    }
};
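// ----------------------------------------------------------------------------
// Illustrative usage sketch (added for documentation; not part of the original
// API). Loads one 8-element __half pack, accumulates a bias pack via the
// packed-__half2 path in Vec::add, and stores the result back. The pointer
// names are hypothetical and must reference 16-byte-aligned device memory
// holding at least 8 contiguous elements.
__device__ inline void vec_add_bias_example(__half *dst_ptr, const __half *bias_ptr) {
    Vec<__half, 8> acc;
    Vec<__half, 8> bias;
    acc.load_from(dst_ptr);   // single 16-byte vectorized load (uint4 under the hood)
    bias.load_from(bias_ptr);
    acc.add(bias);            // four packed __half2 additions
    acc.store_to(dst_ptr);    // single 16-byte vectorized store
}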
template <typename T, int PackSize>
inline __device__ void apply_rotary_embedding(Vec<T, PackSize> &vec,
                                              Vec<float, PackSize / 2> &cos,
                                              Vec<float, PackSize / 2> &sin) {
    static_assert(PackSize % 2 == 0);
    #pragma unroll
    for (int i = 0; i < PackSize / 2; i++) {
        const float cos_inv_freq = cos.data.elt[i];
        const float sin_inv_freq = sin.data.elt[i];
        const float v1 = static_cast<float>(vec.data.elt[2 * i]);
        const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
        vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
        vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
    }
}

// Applies a per-row attention mask to a register fragment of scores: columns
// at or beyond the row's mask boundary are set to a large negative value so
// they vanish after softmax.
template <typename Engine, typename Layout>
__forceinline__ __device__ void app_mask(
        Tensor<Engine, Layout> &tSrS,
        const int *mask,
        const int &mask_row_id,
        const int &col_base) {
    const float mask_value = -1000000.0f;
    for (int i = 0; i < size(tSrS); i += 8) {
        const int col = i * 2 + col_base;
        if (col >= mask[mask_row_id]) { tSrS(i) = mask_value; }
        if (col + 1 >= mask[mask_row_id]) { tSrS(i + 1) = mask_value; }
        if (col >= mask[mask_row_id + 8]) { tSrS(i + 2) = mask_value; }
        if (col + 1 >= mask[mask_row_id + 8]) { tSrS(i + 3) = mask_value; }
        if (col + 8 >= mask[mask_row_id]) { tSrS(i + 4) = mask_value; }
        if (col + 9 >= mask[mask_row_id]) { tSrS(i + 5) = mask_value; }
        if (col + 8 >= mask[mask_row_id + 8]) { tSrS(i + 6) = mask_value; }
        if (col + 9 >= mask[mask_row_id + 8]) { tSrS(i + 7) = mask_value; }
    }
}

template <typename T>
struct HalfMax;

template <>
struct HalfMax<__half2> {
    inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
        __half2 res;
        asm volatile("max.f16x2 %0, %1, %2;\n"
                     : "=r"(*reinterpret_cast<uint32_t *>(&res))
                     : "r"(*reinterpret_cast<const uint32_t *>(&x)),
                       "r"(*reinterpret_cast<const uint32_t *>(&y)));
        return res;
    }
};

template <>
struct HalfMax<nv_bfloat162> {
    inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
        nv_bfloat162 res;
        asm volatile("max.bf16x2 %0, %1, %2;\n"
                     : "=r"(*reinterpret_cast<uint32_t *>(&res))
                     : "r"(*reinterpret_cast<const uint32_t *>(&x)),
                       "r"(*reinterpret_cast<const uint32_t *>(&y)));
        return res;
    }
};

template <typename T>
struct HalfMin;

template <>
struct HalfMin<__half2> {
    inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
        __half2 res;
        asm volatile("min.f16x2 %0, %1, %2;\n"
                     : "=r"(*reinterpret_cast<uint32_t *>(&res))
                     : "r"(*reinterpret_cast<const uint32_t *>(&x)),
                       "r"(*reinterpret_cast<const uint32_t *>(&y)));
        return res;
    }
};

template <>
struct HalfMin<nv_bfloat162> {
    inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
        nv_bfloat162 res;
        asm volatile("min.bf16x2 %0, %1, %2;\n"
                     : "=r"(*reinterpret_cast<uint32_t *>(&res))
                     : "r"(*reinterpret_cast<const uint32_t *>(&x)),
                       "r"(*reinterpret_cast<const uint32_t *>(&y)));
        return res;
    }
};

template <bool Is_even_MN = true, typename TiledCopy, typename Engine0, typename Layout0,
          typename Engine1, typename Layout1, typename Engine2, typename Layout2>
__forceinline__ __device__ void copy(
        TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S, Tensor<Engine1, Layout1> &D,
        Tensor<Engine2, Layout2> const &identity_MN, const int max_MN = 0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
            }
        }
    }
}

template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // Requires the register tensor to be contiguous so it can be viewed as one array.
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

// Block-wide reduction; the result is broadcast to every thread in the block.
template <typename ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
    typedef cub::BlockReduce<T, block_size> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ T result_broadcast;
    T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
    if (threadIdx.x == 0) { result_broadcast = result; }
    __syncthreads();
    return result_broadcast;
}

// Block-wide exclusive prefix sum over one value per thread.
template <typename T, int block_size>
__inline__ __device__ T BlockScanSum(T val) {
    typedef cub::BlockScan<T, block_size> BlockScanT;
    __shared__ typename BlockScanT::TempStorage temp_storage;

    int aggregate;
    BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate);
    __syncthreads();
    return val;
}
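// ----------------------------------------------------------------------------
// Illustrative usage sketch (hypothetical helper; not part of the original
// API). Applies rotary position embedding to one 8-element pack of a query or
// key row: `q_ptr` holds 8 contiguous elements, while `cos_ptr` / `sin_ptr`
// hold the 4 precomputed cos/sin values, one per element pair of this pack.
template <typename T>
__device__ inline void rotary_pack8_example(T *q_ptr, const float *cos_ptr, const float *sin_ptr) {
    Vec<T, 8> q;
    Vec<float, 4> cos_v, sin_v;
    q.load_from(q_ptr);
    cos_v.load_from(cos_ptr);
    sin_v.load_from(sin_ptr);
    // Rotates each (even, odd) element pair by its angle, as defined by
    // apply_rotary_embedding above.
    apply_rotary_embedding<T, 8>(q, cos_v, sin_v);
    q.store_to(q_ptr);
}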
template <typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const &x, T const &y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

template <typename T>
struct MinOp {
    __device__ __forceinline__ T operator()(T const &x, T const &y) { return x < y ? x : y; }
};

template <>
struct MinOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
};

template <typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const &x, T const &y) { return x + y; }
};

template <typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
    using X = Underscore;
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{});  // (2, 2, (2, N / 16)))
        return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
        if constexpr (mma_shape_K == 8) {
            return acc_layout;
        } else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
        }
    }
};

template <bool zero_init = false, int wg_wait = 0, bool arrive = true, bool commit = true,
          typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
    warpgroup_fence_operand(tCrC);
    if constexpr (arrive) {
        warpgroup_arrive();
    }
    if constexpr (zero_init) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    } else {
        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
        // Unroll the K mode manually to set scale D to 1
        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
            cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
        }
    }
    if constexpr (commit) {
        warpgroup_commit_batch();
    }
    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
    warpgroup_fence_operand(tCrC);
    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
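// ----------------------------------------------------------------------------
// Illustrative usage sketch (hypothetical kernel; not part of the original
// API). Each thread contributes one partial value and every thread of the
// 128-thread block receives the block-wide maximum, e.g. the running row max
// used for a numerically stable softmax. Marked static to keep internal
// linkage when this header is included in several translation units.
static __global__ void block_max_example_kernel(const float *in, float *out) {
    constexpr int kBlockSize = 128;
    const float partial = in[blockIdx.x * kBlockSize + threadIdx.x];
    // BlockAllReduce wraps cub::BlockReduce and broadcasts the result to all threads.
    const float block_max = BlockAllReduce<MaxOp<float>, float, kBlockSize>(partial);
    if (threadIdx.x == 0) { out[blockIdx.x] = block_max; }
}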
template <typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = acc_layout;
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
    } else {  // SM80
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
    }
};

// Butterfly all-reduce across a thread group (a full warp by default); every
// lane receives the reduced value.
template <typename ReductionOp, typename T, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
    ReductionOp op;
    #pragma unroll
    for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
        val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}
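// ----------------------------------------------------------------------------
// Illustrative usage sketch (hypothetical helper; not part of the original
// API). Sums one value per lane across a full warp, e.g. the per-row softmax
// denominator after each lane has accumulated its local exp() partial sums;
// every lane ends up holding the same total.
__device__ inline float warp_row_sum_example(float lane_partial) {
    return WarpAllReduce<SumOp<float>, float, 32>(lane_partial);
}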