/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
// NOTE(review): in this copy of the header the template parameter lists and
// most template argument lists are missing (e.g. `template struct ...`,
// `Shape, Int, Int>`, `Copy_Atom;`). The text below is therefore not valid
// C++ as it stands; restore the angle-bracket contents from the upstream
// header before building. Comments below describe only what is still visible
// and mark every guess about the missing pieces as such.
#pragma once
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
// NOTE(review): a using-directive at header scope leaks every cute:: name into
// all includers. The unqualified names used below (Shape, Int, Layout, _1,
// select, tile_to_shape, make_shape, shape, Copy_Atom, ceil_div,
// sizeof_bits_v, LayoutRight, make_tiled_copy, GMMA) rely on it.
using namespace cute;

// Plain aggregate of raw pointers and scalar settings. Nothing in this header
// reads it; field meanings below are inferred from the names only and should
// be confirmed against the code that fills and consumes the struct.
struct Flash_mask_params {
    // Untyped buffers; presumably query / key / value inputs and the output
    // tensor - TODO confirm element type and layout with the launcher.
    void *__restrict__ q_ptr;
    void *__restrict__ k_ptr;
    void *__restrict__ v_ptr;
    void * __restrict__ o_ptr;
    // presumably cumulative per-batch sequence offsets for Q and K - TODO confirm
    int * __restrict__ cu_seq_q;
    int * __restrict__ cu_seq_k;
    // Integer mask data; its encoding is not visible from this header.
    int * __restrict__ mask;
    // Not marked __restrict__, unlike the pointers above.
    int * seq_len_encoder;
    int head_num;
    int kv_head_num;
    int max_seq_len_q;
    int max_seq_len_k;
    int batch_size;
    // presumably head_num / kv_head_num - verify against the caller
    int gqa_group_size;
    // presumably the softmax scale already expressed for a base-2 exponential
    // - TODO confirm
    float scale_softmax_log2;
};

// Storage aggregate holding the Q/K/V/O tiles plus the synchronization state
// declared in the nested struct. It is used below as
// Flash_mask_kernel_traits::SharedStorage.
// NOTE(review): the template parameter list and the array_aligned arguments
// are missing here; element types and sizes cannot be read from this copy.
// Presumably instantiated in CUDA shared memory by the kernel - not shown here.
template struct SharedStorageQKVO {
    cute::array_aligned> smem_q;
    cute::array_aligned> smem_k;
    // smem_v and smem_o are members of one anonymous union, so they occupy the
    // same storage; users of this type cannot keep both live at once.
    union {
        cute::array_aligned> smem_v;
        cute::array_aligned> smem_o;
    };
    // Anonymous struct grouping one transaction barrier (named for Q) and the
    // SharedStorage of two cutlass::PipelineTmaAsync pipelines (named for K
    // and V).
    struct {
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k;
        typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v;
    };
};

// Compile-time configuration bundle: re-exports its template arguments as
// static constants and derives the MMA, layout, copy and pipeline types from
// them.
// NOTE(review): the template parameter list is missing. The body references
// kHeadDim_, kBlockM_, kBlockN_, kNWarps_, kStages_, NeedMask_ and elem_type,
// so those are the parameters, but their order and defaults cannot be
// recovered from this copy.
template struct Flash_mask_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using index_t = int32_t;

    // Thread count is the warp count times cutlass::NumThreadsPerWarp.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    // Head dimension must be a multiple of 32.
    static_assert(kHeadDim % 32 == 0);

    // NOTE(review): the Int arguments are missing; presumably kBlockM, kBlockN
    // and kHeadDim in that order, given the _MNK suffix - TODO confirm.
    using TileShape_MNK = Shape, Int, Int>;
    // Last two modes are Int<1>; the first argument is missing in this copy.
    using ClusterShape_MNK = Shape, Int<1>, Int<1>>;

    static constexpr int kStages = kStages_;
    // Re-exported only; not referenced again in this header.
    static constexpr int NeedMask = NeedMask_;

    // Atom layout with trailing modes _1, _1; the leading extent is missing.
    using AtomLayoutMNK = Layout, _1, _1>>;

    // Two tiled MMAs built over AtomLayoutMNK: the first from
    // GMMA::ss_op_selector, the second from GMMA::rs_op_selector with majors
    // K and MN. Presumably the score GEMM and the output GEMM respectively -
    // TODO confirm with the mainloop. The selector arguments are missing in
    // this copy.
    using TiledMma0 = decltype(cute::make_tiled_mma(
        cute::GMMA::ss_op_selector(),
        AtomLayoutMNK{}));
    using TiledMma1 = 
        decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector(TileShape_MNK{})),
        GMMA::Major::K, GMMA::Major::MN>(),
        AtomLayoutMNK{}));

    // Layouts for Q, K, V and O. Each atom comes from ss_smem_selector and is
    // tiled with tile_to_shape. Q and O tile to modes 0 and 2 of
    // TileShape_MNK; K and V tile to modes 1 and 2 plus a third mode whose
    // Int argument is missing (presumably the stage count - TODO confirm).
    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})),
        decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{},
        select<0, 2>(TileShape_MNK{})));

    using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})),
        decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{},
        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{})));

    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})),
        decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
        make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{})));

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})),
        decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{},
        select<0, 2>(TileShape_MNK{})));

    // NOTE(review): both Copy_Atom argument lists are missing, so the copy
    // instructions chosen for Q and O cannot be read from this copy.
    using SmemCopyAtomQ = Copy_Atom;
    using SmemCopyAtomO = Copy_Atom;

    // NOTE(review): SharedStorageQKVO arguments are missing; presumably the
    // stage count, element types and the four layouts above - TODO confirm.
    using SharedStorage = SharedStorageQKVO;

    // Thread budget: cutlass::NumThreadsPerWarpGroup threads are counted as
    // producers and the remainder of kNThreads as MMA threads.
    static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
    static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;

    // Partitioning for the tiled copy of O: kNumVecElem is 128 divided
    // (rounding up) by a bit width whose type argument is missing (presumably
    // Element's - TODO confirm). Each row of kHeadDim elements is covered by
    // kNumThreadsPerRow threads, and the MMA threads must split evenly into
    // rows.
    static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v);
    static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
    static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
    static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;

    // Tiled copy for O assembled from a copy atom, a thread layout and a value
    // layout, both LayoutRight. The value layout's first extent is _1. The
    // atom's first argument and the Int extents are missing; presumably
    // (kNumRows, kNumThreadsPerRow) threads and (1, kNumVecElem) values -
    // TODO confirm.
    using TiledCopyOAtom = cute::Copy_Atom, Element>;
    using TiledCopyOThrLayout = decltype(cute::make_layout(
        cute::make_shape(Int{}, Int{}),
        LayoutRight{}));
    using TiledCopyOValLayout = decltype(cute::make_layout(
        cute::make_shape(_1{}, Int{}),
        LayoutRight{}));
    using GmemTiledCopyO = 
        decltype(make_tiled_copy(
        TiledCopyOAtom{},
        TiledCopyOThrLayout{},
        TiledCopyOValLayout{}
        ));

    // Pipeline aliases; the template arguments are missing (presumably the
    // stage count - TODO confirm).
    using MainloopPipeline = typename cutlass::PipelineTmaAsync;
    using PipelineState = typename cutlass::PipelineState;
};