// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 * Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */

#ifndef ATTENTION_HOPPER_PREFILL_SM90_CUH_
#define ATTENTION_HOPPER_PREFILL_SM90_CUH_

// NOTE: the system header names were lost in extraction; the set below is a
// best-effort reconstruction of the usual includes for a CUTLASS/CuTe Hopper kernel.
#include <cooperative_groups.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <cutlass/pipeline/pipeline.hpp>
#include <type_traits>
#include <vector>

#include "attention_updater.cuh"
#include "cute/tensor.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "epilogue.cuh"
#include "helper.h"
#include "kernel_traits.cuh"
#include "mainloop_mma.cuh"
#include "mainloop_load.cuh"
#include "utils.cuh"

#ifdef DEBUG_MLA
#undef DEBUG_MLA
#endif
// #define DEBUG_MLA

namespace mla_attn {

using namespace cute;

template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_>
struct Params {
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;

  alignas(16) DTypeQ* Q;      // [token_num, head_num, dim_head]
  alignas(16) DTypeKV* KV;    // [max_block_num, block_size, dim_head]
  alignas(16) DTypeO* O;      // [token_num, head_num, dim_head]
  alignas(16) DTypeO* O_tmp;  // [max_num_chunks, bsz, head_num, dim_head]
  alignas(16) float* m;       // [max_num_chunks, bsz * max_draft_token_num * head_num]
  alignas(16) float* d;       // [max_num_chunks, bsz * max_draft_token_num * head_num]
  alignas(16) IdType* block_tables;
  alignas(16) IdType* seq_lens_this_time;
  alignas(16) IdType* seq_lens_decoder;
  alignas(16) IdType* cumsum_q_seqlens;
  alignas(16) IdType* batch_id_per_token;
  alignas(16) IdType* batch_ids;
  alignas(16) IdType* tile_ids_per_batch;
  alignas(16) IdType* num_blocks_x;
  alignas(16) IdType* chunk_size_device;

  uint32_t q_stride_bsz;
  uint32_t q_stride_head_num;
  uint32_t kv_stride_block_num;
  uint32_t kv_stride_block_size;
  uint32_t o_stride_bsz;
  uint32_t o_stride_head_num;

  int bsz;
  int token_num;
  int max_block_num;
  int max_block_num_per_seq;
  int q_num_head;
  int qk_head_dim;
  int vo_head_dim;
  int block_size;
  int max_draft_token_num;
  int chunk_num;

  float sm_scale;
};

#define DISPATCH_GROUP_SIZE(group_size, GROUP_SIZE, ...)     \
  if (group_size == 8) {                                     \
    constexpr size_t GROUP_SIZE = 8;                         \
    __VA_ARGS__                                              \
  } else if (group_size == 16) {                             \
    constexpr size_t GROUP_SIZE = 16;                        \
    __VA_ARGS__                                              \
  } else if (group_size == 64) {                             \
    constexpr size_t GROUP_SIZE = 64;                        \
    __VA_ARGS__                                              \
  } else {                                                   \
    PD_THROW("unsupported group_size: ", group_size);        \
    return cudaErrorNotSupported;                            \
  }
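// A minimal usage sketch of the dispatch macro (hypothetical caller; `LaunchImpl`
// is illustrative only): the macro expands to one branch per supported group size
// and binds GROUP_SIZE as a constexpr name visible to the body, so the body can
// pass it on as a template argument. The enclosing function must return
// cudaError_t for the fallback branch to compile.
//
//   template <size_t GROUP_SIZE>
//   cudaError_t LaunchImpl(cudaStream_t stream);
//
//   cudaError_t Dispatch(int q_num_head, cudaStream_t stream) {
//     DISPATCH_GROUP_SIZE(q_num_head, GROUP_SIZE,
//                         return LaunchImpl<GROUP_SIZE>(stream);)
//   }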
// NOTE: the kernel's template-parameter list was lost in extraction; it is
// reconstructed here from how the names are used in the body. SM_COUNT = 132
// assumes an H100/H800-class GPU and must match the launch grid size below.
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits, bool CAUSAL,
          bool USE_REG_EALLOC, int SM_COUNT = 132>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
    MLAWithKVCacheKernel(CUTE_GRID_CONSTANT typename CollectiveMainloop::Params const mainloop_params,
                         CUTE_GRID_CONSTANT typename CollectiveEpilogue::Params const epilogue_params) {
  using DTypeQ = typename Ktraits::DTypeQ;
  using DTypeKV = typename Ktraits::DTypeKV;
  using DTypeO = typename Ktraits::DTypeO;
  using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
  using TileShape_QKD = typename Ktraits::TileShape_QKD;
  using TileShape_PDV = typename Ktraits::TileShape_PDV;

  static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
  static constexpr int NUM_COPY_THREADS = Ktraits::NUM_PRODUCER_THREADS;
  static constexpr int BLOCK_SHAPE_Q = Ktraits::BLOCK_SHAPE_Q;
  static constexpr int BLOCK_SHAPE_KV = Ktraits::BLOCK_SHAPE_KV;
  const int num_blocks_x = mainloop_params.num_blocks_x[0];
  const int chunk_size = mainloop_params.chunk_size_device[0];
  static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;

  using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;

  using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ;
  using PipelineParamsQ = typename MainloopPipelineQ::Params;
  using PipelineStateQ = typename MainloopPipelineQ::PipelineState;

  extern __shared__ char shared_memory[];
  auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);

  int const lane_predicate = cute::elect_one_sync();
  int const warp_idx = cutlass::canonical_warp_idx_sync();
  if (warp_idx == 0 && lane_predicate) {
    CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
    CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
  }

  // Obtain warp index
  int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;

  PipelineParams pipeline_params;
  int warp_group_idx = cutlass::canonical_warp_group_idx();
  pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
                                             : MainloopPipeline::ThreadCategory::Consumer;
  if constexpr (use_tma_load_kv) {
    pipeline_params.is_leader = warp_group_thread_idx == 0;
    pipeline_params.num_consumers = NUM_MMA_THREADS;
  } else {
    pipeline_params.producer_arv_count = NUM_COPY_THREADS;
    pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
  }
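  // Warp-group 0 is the producer (async loads of Q and paged KV); the remaining
  // warp groups are consumers running the MMA mainloop. The Q pipeline below is
  // sized for a single consumer warp group, while the KV pipeline feeds all MMA
  // threads: via TMA transactions when USE_TMA_LOAD_KV is set, otherwise via
  // cp.async with explicit producer/consumer arrive counts.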
  PipelineParamsQ pipeline_params_q;
  pipeline_params_q.role = warp_group_idx == 0 ? MainloopPipelineQ::ThreadCategory::Producer
                                               : MainloopPipelineQ::ThreadCategory::Consumer;
  pipeline_params_q.producer_arv_count = NUM_COPY_THREADS;
  pipeline_params_q.consumer_arv_count = cutlass::NumThreadsPerWarpGroup;  // just one QK warp group

  MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q);
  MainloopPipeline pipeline_kv = [&] {
    if constexpr (use_tma_load_kv) {
      pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesKV;
      return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params,
                              /*cluster_shape=*/Shape<_1, _1, _1>{});
    } else {
      return MainloopPipeline(shared_storage.pipeline_kv, pipeline_params);
    }
  }();

  __syncthreads();

  CollectiveMainloop collective_mainloop;
  CollectiveEpilogue collective_epilogue;

  if (warp_group_idx == 0) {  // producer
    if constexpr (USE_REG_EALLOC) {
      cutlass::arch::warpgroup_reg_dealloc<72>();
    }
    const uint32_t warp_idx_in_warpgroup = __shfl_sync(0xffffffff, warp_idx % 4, 0);

    PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state<MainloopPipelineQ>();
    PipelineState smem_pipe_write_kv = cutlass::make_producer_start_state<MainloopPipeline>();

    for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
      const int bid = mainloop_params.batch_ids[i];
      const int tile_id = mainloop_params.tile_ids_per_batch[i];
      const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
      const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
      const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];

      cutlass::arch::NamedBarrier::sync(
          Ktraits::NUM_THREADS,
          /*id=*/static_cast<uint32_t>(NamedBarriers::kWG0WG1WG2Sync));
      // load Q
      collective_mainloop.load_q(mainloop_params, pipeline_q, smem_pipe_write_q, shared_storage,
                                 threadIdx.x, bid);

      if constexpr (!use_tma_load_kv) {
        // load kv
        collective_mainloop.load_kv(mainloop_params, pipeline_kv, smem_pipe_write_kv,
                                    shared_storage, bid, seq_len_decoder_now, tile_id);
      } else {
        if (warp_idx_in_warpgroup == 0) {
          // load kv tma
          collective_mainloop.load_kv_tma(mainloop_params, pipeline_kv, smem_pipe_write_kv,
                                          shared_storage, bid, seq_len_decoder_now, tile_id);
        }
      }
    }
  } else {  // consumer
    if constexpr (USE_REG_EALLOC) {
      cutlass::arch::warpgroup_reg_alloc<216>();
    }
    PipelineStateQ smem_pipe_read_q;
    PipelineState smem_pipe_read_kv;

    typename Ktraits::TiledMmaPVSS tiled_mma_pv;
    Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
    auto attention_updater =
        OnlineSoftmax<2 * size<1>(tOrO), /*WITH_SCALE=*/true>(mainloop_params.sm_scale);

    for (int i = blockIdx.x; i < num_blocks_x; i += SM_COUNT) {
      clear(tOrO);
      clear(attention_updater.scores_scale);
      const int bid = mainloop_params.batch_ids[i];
      const int tile_id = mainloop_params.tile_ids_per_batch[i];
      const int seq_len_now = mainloop_params.seq_lens_this_time[bid];
      const int seq_len_decoder_now = mainloop_params.seq_lens_decoder[bid] + seq_len_now;
      const int start_token_idx = mainloop_params.cumsum_q_seqlens[bid];
      cutlass::arch::NamedBarrier::sync(
          Ktraits::NUM_THREADS,
          /*id=*/static_cast<uint32_t>(NamedBarriers::kWG0WG1WG2Sync));

      // NOTE: the explicit template arguments of the mma helpers were lost in
      // extraction; <Ktraits, CAUSAL> is reconstructed since CAUSAL is not deducible.
      if constexpr (BLOCK_SHAPE_KV == 64) {
        mma_f16<Ktraits, CAUSAL>(mainloop_params, pipeline_q, smem_pipe_read_q, pipeline_kv,
                                 smem_pipe_read_kv, tOrO, attention_updater,
                                 threadIdx.x - NUM_COPY_THREADS, bid, seq_len_decoder_now,
                                 seq_len_now, tile_id, shared_storage);
      } else if constexpr (BLOCK_SHAPE_KV == 32) {
        mma_f16_two_stages<Ktraits, CAUSAL>(mainloop_params, pipeline_q, smem_pipe_read_q,
                                            pipeline_kv, smem_pipe_read_kv, tOrO,
                                            attention_updater, threadIdx.x - NUM_COPY_THREADS,
                                            bid, seq_len_decoder_now, seq_len_now, tile_id,
                                            shared_storage);
      }
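      // Write back this tile: the epilogue either stores O directly or, when the
      // KV sequence is split into more than one chunk, stores partial results to
      // O_tmp together with the softmax statistics (m, d) consumed by the
      // merge pass launched after the main kernel.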
      collective_epilogue.store(epilogue_params, tOrO, attention_updater.get_lse(),
                                shared_storage, tiled_mma_pv, threadIdx.x - NUM_COPY_THREADS,
                                bid, mainloop_params.bsz, seq_len_now, start_token_idx, tile_id,
                                seq_len_decoder_now, chunk_size,
                                mainloop_params.max_draft_token_num, mainloop_params.o_stride_bsz);
    }
  }
}

template <typename KernelTraits, bool CAUSAL, typename Params, bool USE_REG_EALLOC,
          bool USE_FIXED_BLOCK>
cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params, cudaStream_t stream) {
  using DTypeQ = typename KernelTraits::DTypeQ;
  using DTypeKV = typename KernelTraits::DTypeKV;
  using DTypeO = typename KernelTraits::DTypeO;
  using IdType = typename KernelTraits::IdType;
  using NV_TYPE = typename KernelTraits::NV_TYPE;

  // Aliases for the mainloop/epilogue collectives (template argument lists
  // reconstructed; lost in extraction).
  using CollectiveMainloop = CollectiveMainloop<KernelTraits, CAUSAL>;
  using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;

  typename CollectiveMainloop::Params mainloop_params =
      CollectiveMainloop::to_underlying_arguments({
          make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.qk_head_dim),
                      make_stride(params.qk_head_dim, _1{})),  // layout q
          make_layout(make_shape(params.block_size, params.qk_head_dim, params.max_block_num),
                      make_stride(params.qk_head_dim, _1{},
                                  params.block_size * params.qk_head_dim)),  // layout kv
          make_layout(make_shape(params.chunk_num,
                                 params.bsz * params.max_draft_token_num * params.q_num_head),
                      make_stride(params.bsz * params.max_draft_token_num * params.q_num_head,
                                  _1{})),  // layout of the m/d softmax statistics
          params.Q,
          params.KV,
          params.m,
          params.d,
          params.block_tables,
          params.seq_lens_this_time,
          params.seq_lens_decoder,
          params.cumsum_q_seqlens,
          params.batch_ids,
          params.tile_ids_per_batch,
          params.num_blocks_x,
          params.chunk_size_device,
          params.sm_scale,
          params.bsz,
          params.max_block_num,
          params.max_block_num_per_seq,
          params.q_stride_bsz,
          params.q_stride_head_num,
          params.kv_stride_block_num,
          params.kv_stride_block_size,
          params.o_stride_bsz,
          params.o_stride_head_num,
          params.chunk_num,
          params.max_draft_token_num});
  typename CollectiveEpilogue::Params epilogue_params =
      CollectiveEpilogue::to_underlying_arguments_ntma({
          params.O,
          make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
                      make_stride(params.vo_head_dim, _1{})),  // layout O
          params.O_tmp,
          make_layout(make_shape(KernelTraits::BLOCK_SHAPE_Q, params.vo_head_dim),
                      make_stride(params.vo_head_dim, _1{}))  // layout O_tmp
      });

  // Get the ptr to kernel function.
  auto kernel = MLAWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits, CAUSAL,
                                     USE_REG_EALLOC>;
  int smem_size = sizeof(typename KernelTraits::SharedStorage);
  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

  int device;
  cudaGetDevice(&device);
  int multiprocessor_count;
  cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
  int act_blocks_per_sm;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&act_blocks_per_sm, kernel,
                                                KernelTraits::NUM_WARPS * 32, smem_size);

  // NOTE: (changwenbin) Here the grid size is fixed so that MLA can be captured
  // by the graph.
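  // The kernel is persistent: each block strides through the flattened
  // (batch, tile) work list by SM_COUNT, so the launch uses one block per SM and
  // the grid shape stays stable across steps, which is what makes the launch
  // CUDA-graph capturable. SM_COUNT must match the grid size launched here.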
  dim3 grid_dims(multiprocessor_count, 1, 1);
  static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
  dim3 block_dims(ctaSize, 1, 1);
  kernel<<<grid_dims, block_dims, smem_size, stream>>>(mainloop_params, epilogue_params);

  if (params.chunk_num > 1) {
    constexpr int vec_size = 16 / sizeof(DTypeO);
    constexpr int merge_block_size = 256;
    constexpr int blockx = KernelTraits::HEAD_DIM_VO / vec_size;
    constexpr int blocky = (merge_block_size + blockx - 1) / blockx;
    dim3 grids_merge(multiprocessor_count, params.q_num_head);  // 128k is too large
    dim3 blocks_merge(blockx, blocky);
    // NOTE: the template-argument list of the merge kernel was lost in
    // extraction; it is reconstructed to match the launch shape above.
    merge_multi_chunks_kernel<NV_TYPE, vec_size, blocky, KernelTraits::HEAD_DIM_VO>
        <<<grids_merge, blocks_merge, 0, stream>>>(
            reinterpret_cast<NV_TYPE*>(params.O_tmp), params.m, params.d,
            params.seq_lens_this_time, params.seq_lens_decoder, params.cumsum_q_seqlens,
            params.batch_id_per_token, params.chunk_size_device,
            reinterpret_cast<NV_TYPE*>(params.O), params.q_num_head, params.vo_head_dim,
            params.token_num, params.bsz, params.max_draft_token_num);
  }
  return cudaSuccess;
}

template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename Params, bool USE_REG_EALLOC,
          bool USE_FIXED_BLOCK>
cudaError_t BatchMLAWithPagedKVCacheDispatched(Params& params, cudaStream_t stream) {
  constexpr bool CAUSAL = true;
  if constexpr (HEAD_DIM_QK == 576) {
    // NOTE: the AttentionKernelTraits argument list was lost in extraction; the
    // list below follows the trait names used in this file (TMA usage, head dims,
    // group size, tile shapes, stage count, dtypes) and should be checked against
    // kernel_traits.cuh.
    DISPATCH_GROUP_SIZE(
        params.q_num_head, GROUP_SIZE,
        BatchMLAWithPagedKVCacheKernelTraitsDispatched<
            AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false, HEAD_DIM_QK, HEAD_DIM_VO, GROUP_SIZE,
                                  /*BLOCK_SHAPE_Q=*/64, /*BLOCK_SHAPE_KV=*/64, /*NUM_STAGES=*/2,
                                  typename Params::DTypeQ, typename Params::DTypeKV,
                                  typename Params::DTypeO, typename Params::IdType>,
            CAUSAL, Params, USE_REG_EALLOC, USE_FIXED_BLOCK>(params, stream);)
  } else {
    return cudaErrorNotSupported;
  }
  return cudaSuccess;
}

}  // namespace mla_attn

#endif  // ATTENTION_HOPPER_PREFILL_SM90_CUH_
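// Usage sketch (hypothetical caller, not part of this header): instantiate Params
// for bf16 tensors with int32 indices and dispatch on the MLA head dims
// (576 for QK; 512 for VO is assumed here), reusing the template-parameter order
// reconstructed above:
//
//   using MLAParams = mla_attn::Params<cutlass::bfloat16_t, cutlass::bfloat16_t,
//                                      cutlass::bfloat16_t, int32_t>;
//   MLAParams params{};  // fill pointers, strides, and sizes first
//   cudaError_t status =
//       mla_attn::BatchMLAWithPagedKVCacheDispatched<576, 512, MLAParams,
//           /*USE_REG_EALLOC=*/false, /*USE_FIXED_BLOCK=*/false>(params, stream);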