// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 * Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */
#ifndef ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
#define ATTENTION_HOPPER_KERNEL_TRAITS_CUH_

// NOTE(review): the directive below has no header name in this copy of the
// file. std::conditional_t is used further down, so the lost name is
// presumably <type_traits> -- confirm against the upstream file.
#include
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

namespace mla_attn {

using namespace cute;

// Storage aggregate for the attention kernel: three standalone buffers
// (smem_q, smem_p, smem_scale), one region shared between smem_kv and smem_o,
// and the bookkeeping storage of the two pipelines. The struct and every
// member are aligned to 16 bytes.
//
// NOTE(review): the template parameter list after `template` is missing in
// this copy, as are the leading template arguments of each
// cute::array_aligned member (element type and extent). Restore them from
// the upstream file before compiling. The "smem_" prefix suggests these
// buffers live in shared memory -- confirm at the point of instantiation.
template struct alignas(16) SharedStorageQKVO {
  alignas(16) cute::array_aligned> smem_q;
  alignas(16) cute::array_aligned> smem_p;
  alignas(16) cute::array_aligned> smem_scale;
  // smem_kv and smem_o are union members, so they occupy the same bytes and
  // only one of them holds valid data at any time.
  union {
    alignas(16) cute::array_aligned> smem_kv;
    alignas(16) cute::array_aligned> smem_o;
  };
  // Per-pipeline storage: one for the Q pipeline, one for the KV pipeline.
  struct {
    alignas(16) typename MainloopPipelineQ::SharedStorage pipeline_q;
    alignas(16) typename MainloopPipeline::SharedStorage pipeline_kv;
  };
};

// Compile-time configuration bundle for the attention kernel. It re-exports
// its template arguments (the names ending in '_') as member types and
// constants, and derives from them the tile shapes, tiled MMA types, buffer
// layouts, pipeline types and the SharedStorage type.
//
// NOTE(review): the template parameter list is missing in this copy. The
// body references USE_TMA_LOAD_KV_, GROUP_SIZE_, BLOCK_SHAPE_Q_,
// BLOCK_SHAPE_KV_, HEAD_DIM_QK_, HEAD_DIM_VO_, NUM_STAGES_, DTypeQ_,
// DTypeKV_, DTypeO_, IdType_ and NV_TYPE_; their declaration order and kinds
// must be taken from the upstream file.
template struct AttentionKernelTraits {
  // Element and index types forwarded from the template arguments.
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;
  // Both accumulator types are fixed to float regardless of the input types.
  using DTypeQKAccum = float;
  using DTypePVAccum = float;
  using NV_TYPE = NV_TYPE_;

  static constexpr bool USE_TMA_LOAD_KV = USE_TMA_LOAD_KV_;
  static constexpr int GROUP_SIZE = GROUP_SIZE_;
  static constexpr int BLOCK_SHAPE_Q = BLOCK_SHAPE_Q_;
  // The Q block size must be a multiple of 64.
  static_assert(BLOCK_SHAPE_Q % 64 == 0);
  static constexpr int BLOCK_SHAPE_KV = BLOCK_SHAPE_KV_;
  static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_;
  static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_;
  // Product of the KV block size and the QK head dimension; presumably the
  // element count of one KV pipeline stage -- confirm against its users.
  static constexpr int NUM_PER_STAGE = BLOCK_SHAPE_KV * HEAD_DIM_QK;
  // Both head dimensions must be multiples of 32.
  static_assert(HEAD_DIM_QK % 32 == 0);
  static_assert(HEAD_DIM_VO % 32 == 0);

  // Fixed thread configuration. 384 == 12 * 32, i.e. presumably 12 warps of
  // 32 threads, 128 of which are producer threads -- confirm against the
  // kernel launch code.
  static constexpr int NUM_WARPS = 12;
  static constexpr int NUM_THREADS = 384;
  static constexpr int NUM_PRODUCER_THREADS = 128;

  // NOTE(review): the arguments of Shape/Int are stripped on the next two
  // lines. Judging by the names these are the tile shapes of the Q*K and P*V
  // products -- restore the extents from the upstream file.
  using TileShape_QKD = Shape, Int, Int>;
  using TileShape_PDV = Shape, Int, Int>;

  static constexpr int NUM_STAGES = NUM_STAGES_;

  // NOTE(review): the leading Shape<...> arguments of both atom layouts and
  // the template arguments of the *_op_selector calls are stripped below.
  using AtomLayoutQKD = Layout, _1, _1>>;
  using AtomLayoutPV = Layout, _2, _1>>;
  using TiledMmaQK = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector(), AtomLayoutQKD{}));
  // TiledMmaPV and TiledMmaPVSS share AtomLayoutPV and differ only in the
  // selector used (rs_op_selector vs. ss_op_selector); presumably a
  // register-sourced vs. shared-memory-sourced first operand -- confirm
  // against the CuTe documentation.
  using TiledMmaPV = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector(), AtomLayoutPV{}));
  using TiledMmaPVSS = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector(), AtomLayoutPV{}));

  // The MMA thread count is taken from the size of the P*V tiled MMA.
  static constexpr int NUM_MMA_THREADS = size(TiledMmaPV{});

  // Q buffer layout: atom selected from modes 0 and 2 of TileShape_QKD, then
  // tiled to exactly those two modes.
  using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
      decltype(cute::get<2>(TileShape_QKD{}))>());
  using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{})));

  // K buffer layout: atom selected from modes 1 and 2 of TileShape_QKD, tiled
  // to those two modes plus a third mode.
  // NOTE(review): the argument of the trailing Int{} is stripped here and in
  // SmemLayoutVt; it is presumably the stage count -- confirm upstream.
  using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})),
      decltype(cute::get<2>(TileShape_QKD{}))>());
  using SmemLayoutK = decltype(tile_to_shape(
      SmemLayoutAtomK{},
      make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int{})));
  // A second view built on SmemLayoutK by composition: the first two modes
  // are listed in swapped order relative to SmemLayoutK's shape.
  using SmemLayoutVt = decltype(composition(
      SmemLayoutK{},
      make_ordered_layout(make_shape(get<2>(TileShape_QKD{}), get<1>(TileShape_QKD{}), Int{}),
                          Step<_2, _1, _3>{})));

  // V buffer layout for a single stage (third mode is Int<1>), built from
  // modes 2 and 1 of TileShape_PDV.
  using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})),
      decltype(cute::get<1>(TileShape_PDV{}))>());
  using SmemLayoutV = decltype(tile_to_shape(
      SmemLayoutAtomV{},
      make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<1>{})));
  // Note this is the transpose in terms of the view, not in terms of memory.
  using SmemLayoutVtOneStage = decltype(composition(
      SmemLayoutV{},
      make_ordered_layout(make_shape(get<1>(TileShape_PDV{}), get<2>(TileShape_PDV{}), Int<1>{}),
                          Step<_2, _1, _3>{})));

  // O buffer layout: atom selected from modes 0 and 1 of TileShape_PDV, then
  // tiled to exactly those two modes.
  using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
      decltype(cute::get<1>(TileShape_PDV{}))>());
  using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));

  // NOTE(review): the template arguments of Copy_Atom are stripped.
  using SmemCopyAtom = Copy_Atom;

  // True exactly when the KV block size is 32.
  static constexpr bool IS_CTA_32 = (BLOCK_SHAPE_KV == 32);
  // NOTE(review): the Shape arguments of both row layouts and the arguments
  // of std::conditional_t are stripped. The One/Two-stage naming together
  // with IS_CTA_32 suggests the choice depends on IS_CTA_32 -- confirm which
  // branch maps to which layout upstream.
  using SmemLayoutRowOneStage = Layout>, Stride<_1, _2>>;
  using SmemLayoutRowTwoStage = Layout, _2>, Stride<_1, _2, _256>>;
  using SmemLayoutRow = std::conditional_t;

  // P buffer layouts: atom selected from modes 0 and 1 of TileShape_QKD; the
  // one-stage variant is tiled to those two modes, the two-stage variant adds
  // a third mode of extent 2.
  // NOTE(review): the Int{} arguments in the two-stage shape and the
  // arguments of std::conditional_t are stripped.
  using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
      GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
      decltype(cute::get<1>(TileShape_QKD{}))>());
  using SmemLayoutPSSOneStage = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_QKD{})));
  using SmemLayoutPSSTwoStage = decltype(tile_to_shape(SmemLayoutAtomP{}, make_shape(Int{}, Int{}, Int<2>{})));
  using SmemLayoutP = std::conditional_t;

  // The Q pipeline and its state are instantiated with a stage count of 1.
  using MainloopPipelineQ = typename cutlass::PipelineAsync<1>;
  using PipelineStateQ = typename cutlass::PipelineState<1>;
  // NOTE(review): the condition and first alternative of std::conditional_t
  // and the stage arguments of PipelineAsync/PipelineState are stripped. The
  // selection presumably depends on USE_TMA_LOAD_KV and the stage count on
  // NUM_STAGES -- confirm upstream.
  using MainloopPipeline = std::conditional_t, typename cutlass::PipelineAsync>;
  using PipelineState = typename cutlass::PipelineState;

  // NOTE(review): the template arguments of SharedStorageQKVO are stripped;
  // they must match the parameter order of the struct defined above.
  using SharedStorage = SharedStorageQKVO;
};

}  // namespace mla_attn

#endif