diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index 3e40794f4..16f1b223f 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -1,7 +1,3 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ - // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" -#include "kernel_traits.h" #include "flash_mask_attn_kernel.hpp" +#include "kernel_traits.h" +#include "paddle/extension.h" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) @@ -29,139 +25,165 @@ struct cuteType; template <> struct cuteType { - using type = cutlass::half_t; + using type = cutlass::half_t; }; template <> struct cuteType { - using type = cutlass::bfloat16_t; + using type = cutlass::bfloat16_t; }; template -std::vector DispatchFlashAttentionMask( - const paddle::Tensor& q_input, - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& seq_len_encoder, - const paddle::optional& mask, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_seq_len, - const int max_enc_len_this_time, - const int max_dec_len_this_time) { +void DispatchFlashAttentionMask(const paddle::Tensor& q_input, + const paddle::Tensor& k_input, + const paddle::Tensor& v_input, + const paddle::Tensor& cu_seq_q, + const paddle::Tensor& cu_seq_k, + const paddle::Tensor& seq_len_encoder, + const paddle::Tensor& attn_out, + const paddle::optional& mask, + const int head_num, + const int kv_head_num, + const int head_dim, + const int max_seq_len, + const int max_enc_len_this_time, + const int max_dec_len_this_time) { + constexpr int kBlockM = 128; + constexpr int kBlockN = 128; + const int batch_size = seq_len_encoder.dims()[0]; - constexpr int kBlockM = 128; - constexpr int kBlockN = 128; - const int batch_size = cu_seq_q.dims()[0]; + Flash_mask_params params; + memset(¶ms, 0, sizeof(Flash_mask_params)); - paddle::Tensor out = paddle::empty( - {q_input.dims()[0], head_num * head_dim}, q_input.dtype(), q_input.place()); + params.q_ptr = const_cast(q_input.data()); + params.k_ptr = const_cast(k_input.data()); + params.v_ptr = const_cast(v_input.data()); + params.cu_seq_q = const_cast(cu_seq_q.data()); + params.cu_seq_k = const_cast(cu_seq_k.data()); + params.seq_len_encoder = const_cast(seq_len_encoder.data()); + params.head_num = head_num; + params.kv_head_num = kv_head_num; + params.max_seq_len_q = max_enc_len_this_time; + params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time; + params.batch_size = batch_size; + params.gqa_group_size = head_num / kv_head_num; + constexpr float kLog2e = 1.4426950408889634074; + params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e; - Flash_mask_params params; - memset(¶ms, 0, sizeof(Flash_mask_params)); + using cute_type = typename cuteType::type; - params.q_ptr = const_cast(q_input.data()); - params.k_ptr = const_cast(k_input.data()); - params.v_ptr = 
const_cast(v_input.data()); - params.o_ptr = const_cast(out.data()); - params.cu_seq_q = const_cast(cu_seq_q.data()); - params.cu_seq_k = const_cast(cu_seq_k.data()); - params.seq_len_encoder = const_cast(seq_len_encoder.data()); - params.head_num = head_num; - params.kv_head_num = kv_head_num; - params.max_seq_len_q = max_enc_len_this_time; - params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time; - params.batch_size = batch_size; - params.gqa_group_size = head_num / kv_head_num; - constexpr float kLog2e = 1.4426950408889634074; - params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e; - - using cute_type = typename cuteType::type; - - if (mask) { - params.mask = const_cast(mask.get().data()); - flash_attn_headdim128(params, 0); - } else { - flash_attn_headdim128(params, 0); + if (mask) { + params.mask = const_cast(mask.get().data()); + if (attn_out.dtype() == paddle::DataType::FLOAT16) { + using out_type = phi::dtype::float16; + params.o_ptr = const_cast(attn_out.data()); + flash_attn_headdim128::type>( + params, q_input.stream()); + } else if (attn_out.dtype() == paddle::DataType::BFLOAT16) { + using out_type = phi::dtype::bfloat16; + params.o_ptr = const_cast(attn_out.data()); + flash_attn_headdim128::type>( + params, q_input.stream()); } + } else { + if (attn_out.dtype() == paddle::DataType::FLOAT16) { + using out_type = phi::dtype::float16; + params.o_ptr = const_cast(attn_out.data()); + flash_attn_headdim128::type>( + params, q_input.stream()); + } else if (attn_out.dtype() == paddle::DataType::BFLOAT16) { + using out_type = phi::dtype::bfloat16; + params.o_ptr = const_cast(attn_out.data()); + flash_attn_headdim128::type>( + params, q_input.stream()); + } + } - return {out}; + // cudaDeviceSynchronize(); + // auto err = cudaGetLastError(); + // printf("mask attn err = %d, str = %s\n", err, cudaGetErrorString(err)); } - -std::vector FlashAttentionMask( - const paddle::Tensor& q_input, - const paddle::Tensor& k_input, - const paddle::Tensor& v_input, - const paddle::Tensor& cu_seq_q, - const paddle::Tensor& cu_seq_k, - const paddle::Tensor& seq_len_encoder, - const paddle::optional &mask, - const int head_num, - const int kv_head_num, - const int head_dim, - const int max_seq_len, - const int max_enc_len_this_time, - const int max_dec_len_this_time) { - - if (q_input.dtype() == paddle::DataType::FLOAT16) { - using T = phi::dtype::float16; - return std::move( - DispatchFlashAttentionMask( - q_input, - k_input, - v_input, - cu_seq_q, - cu_seq_k, - seq_len_encoder, - mask, - head_num, - kv_head_num, - head_dim, - max_seq_len, - max_enc_len_this_time, - max_dec_len_this_time)); - } else if (q_input.dtype() == paddle::DataType::BFLOAT16) { - using T = phi::dtype::bfloat16; - return std::move( - DispatchFlashAttentionMask( - q_input, - k_input, - v_input, - cu_seq_q, - cu_seq_k, - seq_len_encoder, - mask, - head_num, - kv_head_num, - head_dim, - max_seq_len, - max_enc_len_this_time, - max_dec_len_this_time)); - } - +void FlashAttentionMask(const paddle::Tensor& q_input, + const paddle::Tensor& k_input, + const paddle::Tensor& v_input, + const paddle::Tensor& cu_seq_q, + const paddle::Tensor& cu_seq_k, + const paddle::Tensor& seq_len_encoder, + const paddle::Tensor& attn_out, + const paddle::optional& mask, + const int head_num, + const int kv_head_num, + const int head_dim, + const int max_seq_len, + const int max_enc_len_this_time, + const int max_dec_len_this_time) { + if (q_input.dtype() == paddle::DataType::FLOAT16) { + using T = phi::dtype::float16; + 
DispatchFlashAttentionMask(q_input, + k_input, + v_input, + cu_seq_q, + cu_seq_k, + seq_len_encoder, + attn_out, + mask, + head_num, + kv_head_num, + head_dim, + max_seq_len, + max_enc_len_this_time, + max_dec_len_this_time); + } else if (q_input.dtype() == paddle::DataType::BFLOAT16) { + using T = phi::dtype::bfloat16; + DispatchFlashAttentionMask(q_input, + k_input, + v_input, + cu_seq_q, + cu_seq_k, + seq_len_encoder, + attn_out, + mask, + head_num, + kv_head_num, + head_dim, + max_seq_len, + max_enc_len_this_time, + max_dec_len_this_time); + } } - PD_BUILD_STATIC_OP(flash_attention_mask) - .Inputs({ - "q_input", - "k_input", - "v_input", - "cu_seq_q", - "cu_seq_k", - "seq_len_encoder", - paddle::Optional("mask")}) - .Attrs({ - "head_num: int", - "kv_head_num: int", - "head_dim: int", - "max_seq_len: int", - "max_enc_len_this_time: int", - "max_dec_len_this_time: int"}) - .Outputs({ - "out"}) + .Inputs({"q_input", + "k_input", + "v_input", + "cu_seq_q", + "cu_seq_k", + "seq_len_encoder", + "attn_out", + paddle::Optional("mask")}) + .Attrs({"head_num: int", + "kv_head_num: int", + "head_dim: int", + "max_seq_len: int", + "max_enc_len_this_time: int", + "max_dec_len_this_time: int"}) + .Outputs({"out"}) + .SetInplaceMap({{"attn_out", "out"}}) .SetKernelFn(PD_KERNEL(FlashAttentionMask)); diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn_kernel.hpp b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn_kernel.hpp index 0d7a00db9..d07e780fb 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn_kernel.hpp +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn_kernel.hpp @@ -1,18 +1,20 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
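One detail from the parameter setup in flash_mask_attn.cu above that is easy to miss: scale_softmax_log2 folds the usual 1/sqrt(head_dim) attention scale together with log2(e) (the kLog2e constant), so the softmax inside the kernel can be evaluated with exp2 instead of exp, since exp(x * s) == exp2(x * s * log2(e)). A minimal standalone sketch of that identity; the variable names here are illustrative and not part of this diff:

#include <cassert>
#include <cmath>
#include <cstdio>

// Illustration only: mirrors how scale_softmax_log2 is derived in
// DispatchFlashAttentionMask above and then consumed via exp2 in the kernel.
int main() {
  const int head_dim = 128;
  const float kLog2e = 1.4426950408889634074f;             // log2(e)
  const float scale = 1.0f / std::sqrt((float)head_dim);   // 1/sqrt(d) softmax scale
  const float scale_log2 = scale * kLog2e;                 // what the op stores

  const float qk = 3.7f;                                   // an arbitrary attention logit
  const float via_exp  = std::exp(qk * scale);             // reference form
  const float via_exp2 = std::exp2(qk * scale_log2);       // kernel form

  assert(std::fabs(via_exp - via_exp2) < 1e-5f);
  std::printf("exp: %f  exp2: %f\n", via_exp, via_exp2);
  return 0;
}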
#pragma once -#include "cute/algorithm/copy.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/layout/layout.h" -#include "cutlass/numeric_types.h" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/cluster_launch.hpp" #include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cluster_launch.hpp" #include "kernel_traits.h" #include "mainloop_attn.hpp" @@ -22,210 +24,247 @@ using namespace cute; template auto get_gmem_layout(int token_num, int head_num) { - return make_layout( - make_shape(token_num, kHeadDim, head_num), - make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim)); + return make_layout(make_shape(token_num, kHeadDim, head_num), + make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim)); } template -__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) +__global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, + 1) compute_attn_ws( - CUTE_GRID_CONSTANT typename CollectiveMainloopAttn::Params const mainloop_params, + CUTE_GRID_CONSTANT + typename CollectiveMainloopAttn::Params const mainloop_params, CUTE_GRID_CONSTANT Flash_mask_params const data_params) { + using Element = typename Ktraits::Element; + using ElementAccum = typename Ktraits::ElementAccum; + using output_type = typename Ktraits::output_type; + using SoftType = ElementAccum; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; - using Element = typename Ktraits::Element; - using ElementAccum = typename Ktraits::ElementAccum; - using SoftType = ElementAccum; - using TileShape_MNK = typename Ktraits::TileShape_MNK; - using ClusterShape = typename Ktraits::ClusterShape_MNK; + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); + static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockN = Ktraits::kBlockN; + constexpr int kHeadDim = Ktraits::kHeadDim; + constexpr bool NeedMask = Ktraits::NeedMask; - static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); - static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; - static constexpr int kBlockM = Ktraits::kBlockM; - static constexpr int kBlockN = Ktraits::kBlockN; - constexpr int kHeadDim = Ktraits::kHeadDim; - constexpr bool NeedMask = Ktraits::NeedMask; + using CollectiveMainloop = CollectiveMainloopAttn; - using CollectiveMainloop = CollectiveMainloopAttn; + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; - using MainloopPipeline = typename Ktraits::MainloopPipeline; - using PipelineParams = typename MainloopPipeline::Params; - using PipelineState = typename MainloopPipeline::PipelineState; + extern __shared__ char shared_memory[]; + auto &shared_storage = + *reinterpret_cast(shared_memory); - extern __shared__ char shared_memory[]; - auto &shared_storage = *reinterpret_cast(shared_memory); + __align__(16) __shared__ int mask[kBlockM]; - __align__(16) __shared__ int mask[kBlockM]; + const int m_block = blockIdx.x; + const int bidh = blockIdx.y; + const int bidb = blockIdx.z; - const int m_block = blockIdx.x; - const int bidh = blockIdx.y; - const int bidb = blockIdx.z; + if constexpr (NeedMask) { + const int *mask_this_batch = + data_params.mask + 
data_params.cu_seq_q[bidb] + m_block * kBlockM; - if constexpr (NeedMask) { - const int *mask_this_batch = data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM; - - for (int i = threadIdx.x; i < kBlockM; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) { - mask[i] = mask_this_batch[i]; - } + for (int i = threadIdx.x; i < kBlockM; + i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) { + mask[i] = mask_this_batch[i]; } + } - const int seq_len_q = data_params.seq_len_encoder[bidb]; - const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb]; + const int seq_len_q = data_params.seq_len_encoder[bidb]; + const int seq_len_k = + data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb]; - if (m_block * kBlockM >= seq_len_q) { - return; + if (m_block * kBlockM >= seq_len_q) { + return; + } + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); + } + + int const warp_group_thread_idx = + threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 + ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumMmaThreads; + + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_Q.init(1); + } + + MainloopPipeline pipeline_k( + shared_storage.pipeline_k, pipeline_params, ClusterShape{}); + MainloopPipeline pipeline_v( + shared_storage.pipeline_v, pipeline_params, ClusterShape{}); + + __syncthreads(); + + CollectiveMainloop collective_mainloop; + + const int real_seq = seq_len_q - m_block * kBlockM; + + const int n_block_max = + NeedMask + ? 
cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN) + : min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, + kBlockN), + cute::ceil_div(seq_len_k, kBlockN)); + ; + + if (warp_group_idx == 0) { // Producer + cutlass::arch::warpgroup_reg_dealloc(); + + int warp_idx_in_warpgroup = + __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (warp_idx_in_warpgroup == 0) { // Load Q, K, V + PipelineState smem_pipe_write_k = + cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_v = + cutlass::make_producer_start_state(); + + collective_mainloop.load(mainloop_params, + pipeline_k, + pipeline_v, + smem_pipe_write_k, + smem_pipe_write_v, + shared_storage, + n_block_max, + m_block, + bidh, + bidb, + data_params.cu_seq_q, + data_params.cu_seq_k, + seq_len_q, + seq_len_k); } + } else { // Consumer + cutlass::arch::warpgroup_reg_alloc(); + typename Ktraits::TiledMma1 tiled_mma1; - int const lane_predicate = cute::elect_one_sync(); - int const warp_idx = cutlass::canonical_warp_idx_sync(); + collective_mainloop.mma_init(); - if (warp_idx == 0 && lane_predicate) { - CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); - } + PipelineState smem_pipe_read_k, smem_pipe_read_v; - int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + Tensor tOrO = + partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); + Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax; - PipelineParams pipeline_params; - pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; - int warp_group_idx = cutlass::canonical_warp_group_idx(); - pipeline_params.role = warp_group_idx == 0 - ? MainloopPipeline::ThreadCategory::Producer - : MainloopPipeline::ThreadCategory::Consumer; - pipeline_params.is_leader = warp_group_thread_idx == 0; - pipeline_params.num_consumers = NumMmaThreads; + collective_mainloop.mma(mainloop_params, + pipeline_k, + pipeline_v, + smem_pipe_read_k, + smem_pipe_read_v, + tOrO, + softmax, + mask, + n_block_max, + threadIdx.x - NumCopyThreads, + m_block, + seq_len_q, + seq_len_k, + shared_storage); - if (warp_idx == 0 && lane_predicate) { - shared_storage.barrier_Q.init(1); - } - - MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); - MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{}); - - __syncthreads(); - - CollectiveMainloop collective_mainloop; - - const int real_seq = seq_len_q - m_block * kBlockM; - - const int n_block_max = NeedMask ? 
cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN) : cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN); - - if (warp_group_idx == 0) { // Producer - cutlass::arch::warpgroup_reg_dealloc(); - - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - if (warp_idx_in_warpgroup == 0) { // Load Q, K, V - PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); - PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); - - collective_mainloop.load( - mainloop_params, - pipeline_k, - pipeline_v, - smem_pipe_write_k, - smem_pipe_write_v, - shared_storage, - n_block_max, - m_block, - bidh, - bidb, - data_params.cu_seq_q, - data_params.cu_seq_k, - seq_len_q, - seq_len_k); - } - } else { // Consumer - cutlass::arch::warpgroup_reg_alloc(); - typename Ktraits::TiledMma1 tiled_mma1; - - PipelineState smem_pipe_read_k, smem_pipe_read_v; - - Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); - Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax; - - collective_mainloop.mma( - mainloop_params, - pipeline_k, - pipeline_v, - smem_pipe_read_k, - smem_pipe_read_v, - tOrO, - softmax, - mask, - n_block_max, - threadIdx.x - NumCopyThreads, - m_block, - seq_len_q, - seq_len_k, - shared_storage); - - const int o_head_stride = data_params.head_num * kHeadDim; - const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim; - - collective_mainloop.store( - mainloop_params, - tOrO, - shared_storage, - tiled_mma1, - threadIdx.x - NumCopyThreads, - o_head_stride, - real_seq, - reinterpret_cast(data_params.o_ptr) + store_offset); - } + const int o_head_stride = data_params.head_num * kHeadDim; + const int store_offset = + (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + + bidh * kHeadDim; + collective_mainloop.store( + mainloop_params, + tOrO, + shared_storage, + tiled_mma1, + threadIdx.x - NumCopyThreads, + o_head_stride, + real_seq, + reinterpret_cast(data_params.o_ptr) + store_offset); + } } - -template +template void run_flash_mask(Flash_mask_params ¶ms, cudaStream_t stream) { - using Element = typename Kernel_traits::Element; - using TileShape_MNK = typename Kernel_traits::TileShape_MNK; - using ClusterShape = typename Kernel_traits::ClusterShape_MNK; + using Element = typename Kernel_traits::Element; + using TileShape_MNK = typename Kernel_traits::TileShape_MNK; + using ClusterShape = typename Kernel_traits::ClusterShape_MNK; - using CollectiveMainloop = CollectiveMainloopAttn; - constexpr int kHeadDim = Kernel_traits::kHeadDim; + using CollectiveMainloop = CollectiveMainloopAttn; + constexpr int kHeadDim = Kernel_traits::kHeadDim; - typename CollectiveMainloop::Params mainloop_params = - CollectiveMainloop::to_underlying_arguments({ - static_cast(params.q_ptr), - get_gmem_layout(params.max_seq_len_q, params.head_num), - static_cast(params.k_ptr), - get_gmem_layout(params.max_seq_len_k, params.kv_head_num), - static_cast(params.v_ptr), - get_gmem_layout(params.max_seq_len_k, params.kv_head_num), - params.scale_softmax_log2 - }); + typename CollectiveMainloop::Params mainloop_params = + CollectiveMainloop::to_underlying_arguments( + {static_cast(params.q_ptr), + get_gmem_layout(params.max_seq_len_q * params.batch_size, + params.head_num), + static_cast(params.k_ptr), + get_gmem_layout(params.max_seq_len_k * params.batch_size, + params.kv_head_num), + static_cast(params.v_ptr), + get_gmem_layout(params.max_seq_len_k * params.batch_size, + 
params.kv_head_num), + params.scale_softmax_log2}); - int num_blocks_m = cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM); + int num_blocks_m = + cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM); - num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); + num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * + size<0>(ClusterShape{}); - void *kernel; - kernel = (void *)compute_attn_ws; - int smem_size = sizeof(typename Kernel_traits::SharedStorage); + void *kernel; + kernel = (void *)compute_attn_ws; + int smem_size = sizeof(typename Kernel_traits::SharedStorage); - if (smem_size >= 48 * 1024) { - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } - dim3 grid_dims; - grid_dims.x = num_blocks_m; - grid_dims.y = params.head_num; - grid_dims.z = params.batch_size; + dim3 grid_dims; + grid_dims.x = num_blocks_m; + grid_dims.y = params.head_num; + grid_dims.z = params.batch_size; - static constexpr int ctaSize = Kernel_traits::kNWarps * 32; - dim3 block_dims(ctaSize); - dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); - cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; - cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params); + static constexpr int ctaSize = Kernel_traits::kNWarps * 32; + dim3 block_dims(ctaSize); + dim3 cluster_dims(size<0>(ClusterShape{}), + size<1>(ClusterShape{}), + size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{ + grid_dims, block_dims, cluster_dims, smem_size, stream}; + cutlass::launch_kernel_on_cluster( + launch_params, kernel, mainloop_params, params); } -template +template void flash_attn_headdim128(Flash_mask_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + constexpr static int kNWarps = kBlockM / 16 + 4; + constexpr static int kStages = 2; - constexpr static int Headdim = 128; - constexpr static int kNWarps = kBlockM / 16 + 4; - constexpr static int kStages = 2; - - using Ktraits = Flash_mask_kernel_traits; - run_flash_mask(params, stream); + using Ktraits = Flash_mask_kernel_traits; + run_flash_mask(params, stream); } diff --git a/custom_ops/gpu_ops/flash_mask_attn/kernel_traits.h b/custom_ops/gpu_ops/flash_mask_attn/kernel_traits.h index c1ba9ff47..1a0e2cabd 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/kernel_traits.h +++ b/custom_ops/gpu_ops/flash_mask_attn/kernel_traits.h @@ -1,6 +1,16 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -15,110 +25,155 @@ using namespace cute; struct Flash_mask_params { - void *__restrict__ q_ptr; - void *__restrict__ k_ptr; - void *__restrict__ v_ptr; - void * __restrict__ o_ptr; - int * __restrict__ cu_seq_q; - int * __restrict__ cu_seq_k; - int * __restrict__ mask; - int * seq_len_encoder; - int head_num; - int kv_head_num; - int max_seq_len_q; - int max_seq_len_k; - int batch_size; - int gqa_group_size; - float scale_softmax_log2; + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + void *__restrict__ o_ptr; + int *__restrict__ cu_seq_q; + int *__restrict__ cu_seq_k; + int *__restrict__ mask; + int *seq_len_encoder; + int head_num; + int kv_head_num; + int max_seq_len_q; + int max_seq_len_k; + int batch_size; + int gqa_group_size; + float scale_softmax_log2; }; -template +template struct SharedStorageQKVO { - cute::array_aligned> smem_q; - cute::array_aligned> smem_k; - union { - cute::array_aligned> smem_v; - cute::array_aligned> smem_o; - }; - struct { - cutlass::arch::ClusterTransactionBarrier barrier_Q; - typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; - typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; - }; + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + union { + cute::array_aligned> smem_v; + cute::array_aligned> smem_o; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + }; }; -template +template struct Flash_mask_kernel_traits { - using Element = elem_type; - using ElementAccum = float; - using index_t = int32_t; + using Element = elem_type; + using output_type = out_type; + using ElementAccum = float; + using index_t = int32_t; - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - using TileShape_MNK = Shape, Int, Int>; - using ClusterShape_MNK = Shape, Int<1>, Int<1>>; - static constexpr int kStages = kStages_; - static constexpr int NeedMask = NeedMask_; + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + using TileShape_MNK = Shape, Int, Int>; + using ClusterShape_MNK = Shape, Int<1>, Int<1>>; + static constexpr int kStages = kStages_; + static constexpr int NeedMask = NeedMask_; - using AtomLayoutMNK = Layout, _1, _1>>; - using TiledMma0 = decltype(cute::make_tiled_mma( - cute::GMMA::ss_op_selector(), - AtomLayoutMNK{})); - using TiledMma1 = decltype(cute::make_tiled_mma( - cute::GMMA::rs_op_selector(TileShape_MNK{})), - GMMA::Major::K, GMMA::Major::MN>(), - AtomLayoutMNK{})); + using AtomLayoutMNK = Layout, _1, _1>>; + using TiledMma0 = decltype(cute::make_tiled_mma( + cute::GMMA:: + ss_op_selector(), + AtomLayoutMNK{})); + using TiledMma1 = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(TileShape_MNK{})), + GMMA::Major::K, + GMMA::Major::MN>(), + AtomLayoutMNK{})); - using SmemLayoutAtomQ = 
decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutAtomQ = + decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, + Element, + decltype(cute::get<0>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); - using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutK = - decltype(tile_to_shape(SmemLayoutAtomK{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutAtomK = + decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, + Element, + decltype(cute::get<1>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = + decltype(tile_to_shape(SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), + shape<2>(TileShape_MNK{}), + Int{}))); - using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutV = - decltype(tile_to_shape(SmemLayoutAtomV{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutAtomV = + decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, + Element, + decltype(cute::get<1>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, + make_shape(shape<1>(TileShape_MNK{}), + shape<2>(TileShape_MNK{}), + Int{}))); - using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutAtomO = + decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, + output_type, + decltype(cute::get<0>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutO = + decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); - using SmemCopyAtomQ = Copy_Atom; - using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomQ = Copy_Atom; + using SmemCopyAtomO = Copy_Atom; - using SharedStorage = SharedStorageQKVO; + using SharedStorage = SharedStorageQKVO; - static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup; - static constexpr int NumMmaThreads = kNThreads - NumProducerThreads; - static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v); - static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem; - static_assert(NumMmaThreads % kNumThreadsPerRow == 0); - static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow; - using TiledCopyOAtom = cute::Copy_Atom, Element>; - using TiledCopyOThrLayout = decltype(cute::make_layout( - cute::make_shape(Int{}, Int{}), - LayoutRight{})); - using TiledCopyOValLayout = decltype(cute::make_layout( - cute::make_shape(_1{}, Int{}), - LayoutRight{})); - using GmemTiledCopyO = decltype(make_tiled_copy( - TiledCopyOAtom{}, - TiledCopyOThrLayout{}, - TiledCopyOValLayout{} - )); + static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup; + static constexpr int NumMmaThreads = kNThreads - 
NumProducerThreads; + static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v); + static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem; + static_assert(NumMmaThreads % kNumThreadsPerRow == 0); + static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow; + using TiledCopyOAtom = + cute::Copy_Atom, output_type>; + using TiledCopyOThrLayout = decltype(cute::make_layout( + cute::make_shape(Int{}, Int{}), + LayoutRight{})); + using TiledCopyOValLayout = decltype(cute::make_layout( + cute::make_shape(_1{}, Int{}), LayoutRight{})); + using GmemTiledCopyO = decltype(make_tiled_copy( + TiledCopyOAtom{}, TiledCopyOThrLayout{}, TiledCopyOValLayout{})); - using MainloopPipeline = typename cutlass::PipelineTmaAsync; - using PipelineState = typename cutlass::PipelineState; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using PipelineState = typename cutlass::PipelineState; }; diff --git a/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp b/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp index 5592cb2f0..070290383 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp +++ b/custom_ops/gpu_ops/flash_mask_attn/mainloop_attn.hpp @@ -1,13 +1,23 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
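To make the thread accounting in kernel_traits.h above concrete: with the kBlockM = 128 tile chosen by flash_attn_headdim128, kNWarps = 128/16 + 4 = 12, so each CTA launches 384 threads, split into one 128-thread producer warp group (TMA loads of Q/K/V) and two 128-thread consumer warp groups (the MMAs), which is the NumMmaThreads value the scheduler-barrier static_asserts rely on. A small standalone sketch of that arithmetic; the local constant names are this sketch's own, with values taken from the diff:

#include <cstdio>

// Warp/thread accounting of Flash_mask_kernel_traits for kBlockM = 128.
constexpr int kThreadsPerWarp = 32;
constexpr int kThreadsPerWarpGroup = 128;   // cutlass::NumThreadsPerWarpGroup

constexpr int kBlockM = 128;
constexpr int kNWarps = kBlockM / 16 + 4;                        // 12 warps
constexpr int kNThreads = kNWarps * kThreadsPerWarp;             // 384 threads per CTA
constexpr int kNumProducerThreads = kThreadsPerWarpGroup;        // 1 warp group loads Q/K/V
constexpr int kNumMmaThreads = kNThreads - kNumProducerThreads;  // 256 = 2 warp groups

static_assert(kNThreads == 384, "12 warps per CTA");
static_assert(kNumMmaThreads == 2 * kThreadsPerWarpGroup,
              "two consumer warp groups run the MMAs");

int main() {
  std::printf("warps=%d threads=%d producer=%d mma=%d\n",
              kNWarps, kNThreads, kNumProducerThreads, kNumMmaThreads);
  return 0;
}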
#pragma once -#include #include -#include +#include #include +#include #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" @@ -16,416 +26,559 @@ #include "utils.hpp" - using namespace cute; +enum class AttnNamedBarriers { + QueryEmpty = 0, + ValueEmpty = 1, + TileCountSmemEmpty = 2, + TileCountSmemFull = 3, + WarpSchedulerWG1 = 4, + WarpSchedulerWG2 = 5, + WarpSchedulerWG3 = 6, +}; + template struct CollectiveMainloopAttn { + using Element = typename Ktraits::Element; + using output_type = typename Ktraits::output_type; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; - using Element = typename Ktraits::Element; - using TileShape_MNK = typename Ktraits::TileShape_MNK; - using ClusterShape = typename Ktraits::ClusterShape_MNK; + static constexpr int kStages = Ktraits::kStages; + static constexpr int kHeadDim = Ktraits::kHeadDim; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockN = Ktraits::kBlockN; + static constexpr bool NeedMask = Ktraits::NeedMask; - static constexpr int kStages = Ktraits::kStages; - static constexpr int kHeadDim = Ktraits::kHeadDim; - static constexpr int kBlockM = Ktraits::kBlockM; - static constexpr int kBlockN = Ktraits::kBlockN; - static constexpr bool NeedMask = Ktraits::NeedMask; + using ShapeT = cute::Shape; + using StrideT = cute::Shape; + using LayoutT = cute::Layout; - using ShapeT = cute::Shape; - using StrideT = cute::Shape; - using LayoutT = cute::Layout; + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + using GmemTiledCopyKV = + decltype(cutlass::gemm::collective::detail:: + sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); + using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO; + using SmemLayoutAtomQ = + decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, + Element, + decltype(cute::get<0>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); - using GmemTiledCopyQ = cute::SM90_TMA_LOAD; - using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); - using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO; + using SmemLayoutAtomK = + decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GMMA::Major::K, + Element, + decltype(cute::get<1>(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = + decltype(tile_to_shape(SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), + shape<2>(TileShape_MNK{}), + Int{}))); + using SmemLayoutV = SmemLayoutK; + // Note this is the transpose in terms of the view, not in terms of memory. 
+ using SmemLayoutVt = decltype(cute::composition( + SmemLayoutV{}, + make_layout( + make_shape( + get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), + make_stride(get<1>(TileShape_MNK{}), + _1{}, + Int{})))); + using SmemLayoutO = typename Ktraits::SmemLayoutO; + using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO; - using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), + StrideT{}), + SmemLayoutQ{}, + select<0, 2>(TileShape_MNK{}), + _1{})); // no mcast for Q - using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutK = - decltype(tile_to_shape(SmemLayoutAtomK{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - using SmemLayoutV = SmemLayoutK; - // Note this is the transpose in terms of the view, not in terms of memory. - using SmemLayoutVt = - decltype(cute::composition(SmemLayoutV{}, - make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), - make_stride(get<1>(TileShape_MNK{}), _1{}, Int{})))); - using SmemLayoutO = typename Ktraits::SmemLayoutO; - using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO; + using TMA_KV = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(StrideT{}, int32_t(0)), + StrideT{}), + take<0, 2>(SmemLayoutK{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - using TMA_Q = decltype(make_tma_copy( - GmemTiledCopyQ{}, - make_tensor( - make_gmem_ptr(static_cast(nullptr)), - repeat_like(StrideT{}, int32_t(0)), - StrideT{} - ), - SmemLayoutQ{}, - select<0, 2>(TileShape_MNK{}), - _1{})); // no mcast for Q + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; - using TMA_KV = decltype(make_tma_copy( + // Set the bytes transferred in this TMA transaction (may involve multiple + // issues) + static constexpr uint32_t TmaTransactionBytesQ = static_cast( + size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = static_cast( + size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + + static constexpr bool UseSchedulerBarrier = kHeadDim <= 128; + + // Host side kernel arguments + struct Arguments { + Element const* ptr_Q; + LayoutT layout_Q; + Element const* ptr_K; + LayoutT layout_K; + Element const* ptr_V; + LayoutT layout_V; + float const softmax_scale_log2; + }; + + // Device side kernel params + struct Params { + LayoutT layout_Q; + LayoutT layout_K; + LayoutT layout_V; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_Q tma_load_Q; + TMA_KV tma_load_K, tma_load_V; + float const softmax_scale_log2; + }; + + static Params to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q); + TMA_Q tma_load_Q = make_tma_copy(GmemTiledCopyQ{}, + mQ, + SmemLayoutQ{}, + select<0, 2>(TileShape_MNK{}), + _1{}); + Tensor mK = 
make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K); + TMA_KV tma_load_K = make_tma_copy( GmemTiledCopyKV{}, - make_tensor( - make_gmem_ptr(static_cast(nullptr)), - repeat_like(StrideT{}, int32_t(0)), - StrideT{} - ), - take<0, 2>(SmemLayoutK{}), + mK, + SmemLayoutK{}(_, _, _0{}), select<1, 2>(TileShape_MNK{}), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V); + TMA_KV tma_load_V = make_tma_copy( + GmemTiledCopyKV{}, + mV, + SmemLayoutV{}(_, _, _0{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {args.layout_Q, + args.layout_K, + args.layout_V, + cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), + get<2>(args.layout_K.shape()))), + tma_load_Q, + tma_load_K, + tma_load_V, + args.softmax_scale_log2}; + } - static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); - using MainloopPipeline = typename Ktraits::MainloopPipeline; - using PipelineParams = typename MainloopPipeline::Params; - using PipelineState = typename MainloopPipeline::PipelineState; + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best + /// performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor( + mainloop_params.tma_load_V.get_tma_descriptor()); + } - // Set the bytes transferred in this TMA transaction (may involve multiple issues) - static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); - static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + template + CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor, + const Shape& tile_shape, + const int* cu_seq_len, + const int bidh, + const int bidb, + const int actual_seq_len) const { + auto g_offset = local_tile(m_tensor(_, _, bidh), + cute::make_shape(1, get<1>(tile_shape)), + make_coord(cu_seq_len[bidb], _0{})); + auto g_sequence = make_tensor( + g_offset.data(), + make_layout(cute::make_shape(actual_seq_len, get<1>(tile_shape)), + g_offset.stride())); + auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); + return g_tensor; + } - static constexpr bool UseSchedulerBarrier = kHeadDim <= 128; + template + CUTLASS_DEVICE void load(Params const& mainloop_params, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v, + SharedStorage& shared_storage, + const int n_block_max, + const int m_block, + const int bidh, + const int bidb, + const int* cu_seq_q, + const int* cu_seq_k, + const int seq_len_q, + const int seq_len_k) { + Tensor sQ = + make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = + make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); - // Host side kernel arguments - struct Arguments { - Element const* ptr_Q; - LayoutT layout_Q; - Element const* ptr_K; - LayoutT layout_K; - Element const* ptr_V; - LayoutT layout_V; - float const softmax_scale_log2; + Tensor mQ = 
mainloop_params.tma_load_Q.get_tma_tensor( + mainloop_params.layout_Q.shape()); + Tensor mK = mainloop_params.tma_load_K.get_tma_tensor( + mainloop_params.layout_K.shape()); + Tensor mV = mainloop_params.tma_load_V.get_tma_tensor( + mainloop_params.layout_V.shape()); + int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); + + Tensor gQ = get_local_tile_tensor( + mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)( + _, _, m_block); + Tensor gK = get_local_tile_tensor( + mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k); + Tensor gV = get_local_tile_tensor( + mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k); + + Tensor sQ_x = + make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); + Tensor gQ_x = + make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); + auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, + _0{}, + Layout<_1>{}, + group_modes<0, 2>(sQ_x), + group_modes<0, 2>(gQ_x)); + auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, + _0{}, + Layout<_1>{}, + group_modes<0, 2>(sK), + group_modes<0, 2>(gK)); + auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, + _0{}, + Layout<_1>{}, + group_modes<0, 2>(sV), + group_modes<0, 2>(gV)); + + uint16_t mcast_mask_kv = 0; + + int n_block = n_block_max - 1; + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with( + reinterpret_cast< + cutlass::arch::ClusterTransactionBarrier::ValueType&>( + shared_storage.barrier_Q), + 0 /*mcast_mask*/), + tQgQ, + tQsQ); + } + + if (lane_predicate) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with( + *pipeline_k.producer_get_barrier(smem_pipe_write_k), + mcast_mask_kv), + tKgK(_, n_block), + tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + } + + if (lane_predicate) { +#pragma unroll 2 + for (; n_block > 0; --n_block) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with( + *pipeline_k.producer_get_barrier(smem_pipe_write_k), + mcast_mask_kv), + tKgK(_, n_block - 1), + tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with( + *pipeline_v.producer_get_barrier(smem_pipe_write_v), + mcast_mask_kv), + tVgV(_, n_block), + tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + } + if (lane_predicate) { + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with( + *pipeline_v.producer_get_barrier(smem_pipe_write_v), + mcast_mask_kv), + tVgV(_, n_block), + tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + } + + CUTLASS_DEVICE void warp_scheduler_barrier_sync() { + if constexpr (UseSchedulerBarrier) { + cutlass::arch::NamedBarrier::sync( + NumMmaThreads, + static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + + cutlass::canonical_warp_group_idx() /*id*/); + } + } + + CUTLASS_DEVICE void mma_init() { + if constexpr (!UseSchedulerBarrier) { + return; + } + static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || + NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); + if (cutlass::canonical_warp_group_idx() > 1) { + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, + static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/); + } + if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) { + if 
(cutlass::canonical_warp_group_idx() > 2) { + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, + static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + + 2 /*id*/); + } + } + } + + CUTLASS_DEVICE void warp_scheduler_barrier_arrive() { + if constexpr (!UseSchedulerBarrier) { + return; + } + static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || + NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); + if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) { + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, + static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + + (3 - cutlass::canonical_warp_group_idx()) /*id*/); + } else { + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, + static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + + (cutlass::canonical_warp_group_idx() <= 2 + ? cutlass::canonical_warp_group_idx() + 1 + : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads, + static_cast(AttnNamedBarriers::WarpSchedulerWG1) - 1 + + (cutlass::canonical_warp_group_idx() <= 1 + ? cutlass::canonical_warp_group_idx() + 2 + : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); + } + } + + template + CUTLASS_DEVICE void mma(Params const& mainloop_params, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + PipelineState& smem_pipe_read_k, + PipelineState& smem_pipe_read_v, + FrgTensorO& tOrO, + Softmax& softmax, + const int* mask, + const int n_block_max, + const int thread_idx, + const int m_block, + const int seq_len_q, + const int seq_len_k, + SharedStorage& shared_storage) { + Tensor sQ = + make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), + SmemLayoutVt{}); + + typename Ktraits::TiledMma0 tiled_mma0; + typename Ktraits::TiledMma1 tiled_mma1; + auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx); + auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx); + + Tensor tSrQ = threadMma0.partition_fragment_A(sQ); + Tensor tSrK = threadMma0.partition_fragment_B(sK); + Tensor tOrV = threadMma1.partition_fragment_B(sVt); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); }; - // Device side kernel params - struct Params { - LayoutT layout_Q; - LayoutT layout_K; - LayoutT layout_V; - cutlass::FastDivmod qhead_per_khead_divmod; - TMA_Q tma_load_Q; - TMA_KV tma_load_K, tma_load_V; - float const softmax_scale_log2; - }; + tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + int n_block = n_block_max - 1; - static Params - to_underlying_arguments(Arguments const& args) { - Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q); - TMA_Q tma_load_Q = make_tma_copy( - GmemTiledCopyQ{}, - mQ, - SmemLayoutQ{}, - select<0, 2>(TileShape_MNK{}), - _1{}); - Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K); - TMA_KV tma_load_K = make_tma_copy( - GmemTiledCopyKV{}, - mK, - SmemLayoutK{}(_, _, _0{}), - select<1, 2>(TileShape_MNK{}), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V); - TMA_KV tma_load_V = make_tma_copy( - GmemTiledCopyKV{}, - mV, - SmemLayoutV{}(_, _, _0{}), - select<1, 2>(TileShape_MNK{}), - size<0>(ClusterShape{})); // mcast along M mode for this N 
load, if any - return {args.layout_Q, args.layout_K, args.layout_V, - cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))), - tma_load_Q, tma_load_K, tma_load_V, - args.softmax_scale_log2}; + cutlass::ConsumerToken barrier_token = static_cast( + shared_storage.barrier_Q.try_wait(0)); + if (barrier_token == cutlass::BarrierStatus::WaitAgain) { + shared_storage.barrier_Q.wait(0); } - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor()); + Tensor tSrS = + partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + warp_scheduler_barrier_sync(); + gemm( + tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + warp_scheduler_barrier_arrive(); + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read_k); + ++smem_pipe_read_k; + + int mask_start_idx; + int mask_row_id; + int col_base; + + if constexpr (NeedMask) { + const int lane_id = thread_idx % 32; + mask_start_idx = mask[0] / kBlockN - 1; + + mask_row_id = thread_idx / 32 * 16 + lane_id / 4; + + col_base = thread_idx % 4 * 2; + + app_mask(tSrS, mask, mask_row_id, col_base + n_block * kBlockN); + } else { + auto col_limit_causal = [&](int row, int n_block) { + return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + + m_block * kBlockM; + }; + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if (int(get<1>(tScS(i))) >= + std::min(seq_len_k - n_block * kBlockN, + col_limit_causal(int(get<0>(tScS(i))), n_block))) { + tSrS(i) = -INFINITY; + } + } } - template - CUTLASS_DEVICE auto get_local_tile_tensor( - const MTensor &m_tensor, - const Shape &tile_shape, - const int *cu_seq_len, - const int bidh, - const int bidb, - const int actual_seq_len) const { - auto g_offset = local_tile( - m_tensor(_, _, bidh), - cute::make_shape(1, get<1>(tile_shape)), - make_coord(cu_seq_len[bidb], _0{})); - auto g_sequence = make_tensor( - g_offset.data(), - make_layout( - cute::make_shape(actual_seq_len, get<1>(tile_shape)), - g_offset.stride() - )); - auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{})); - return g_tensor; + softmax.template online_softmax( + tSrS, mainloop_params.softmax_scale_log2); + + Tensor tOrP = make_tensor( + convert_type(tSrS).data(), + convert_layout_acc_Aregs(tSrS.layout())); + Tensor scores_scale = make_fragment_like(softmax.row_max); + clear(scores_scale); + +#pragma unroll 2 + for (; n_block > 0; --n_block) { + Tensor tSrS = + partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + warp_scheduler_barrier_sync(); + + if constexpr (NeedMask) { + if (n_block >= mask_start_idx) { + app_mask(tSrS, mask, mask_row_id, col_base + n_block * kBlockN); + } + } + + gemm( + tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + softmax.rescale_o(tOrO, scores_scale); + consumer_wait(pipeline_v, smem_pipe_read_v); + gemm( + tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + 
pipeline_k.consumer_release(smem_pipe_read_k); // release K + cute::copy(softmax.template max( + tSrS, mainloop_params.softmax_scale_log2), + scores_scale); + softmax.template online_softmax( + tSrS, mainloop_params.softmax_scale_log2); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + cute::copy( + make_tensor(convert_type(tSrS).data(), + convert_layout_acc_Aregs( + tSrS.layout())), + tOrP); } + softmax.rescale_o(tOrO, scores_scale); + consumer_wait(pipeline_v, smem_pipe_read_v); - template - CUTLASS_DEVICE void - load(Params const& mainloop_params, - MainloopPipeline pipeline_k, - MainloopPipeline pipeline_v, - PipelineState& smem_pipe_write_k, - PipelineState& smem_pipe_write_v, - SharedStorage &shared_storage, - const int n_block_max, - const int m_block, - const int bidh, - const int bidb, - const int *cu_seq_q, - const int *cu_seq_k, - const int seq_len_q, - const int seq_len_k) { + gemm( + tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), + scores_scale); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); + ++smem_pipe_read_v; - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); + softmax.rescale_o(tOrO, scores_scale); + } - Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape()); - Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape()); - Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape()); - int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); + template + CUTLASS_DEVICE void store(Params const& mainloop_params, + FrgTensorO const& tOrO, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + const int o_head_stride, + const int real_seq, + T* out_ptr) { + Tensor sO = + make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor gQ = get_local_tile_tensor( - mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block); - Tensor gK = get_local_tile_tensor( - mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k); - Tensor gV = get_local_tile_tensor( - mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k); + Tensor tOrO_out = convert_type(tOrO); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); - Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); - Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); - auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); - auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{},group_modes<0, 2>(sK), group_modes<0, 2>(gK)); - auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{},group_modes<0, 2>(sV), group_modes<0, 2>(gV)); + cutlass::arch::NamedBarrier::sync( + NumMmaThreads, static_cast(AttnNamedBarriers::ValueEmpty) /*id*/); + cute::copy(smem_tiled_copy_O, taccOrO, 
taccOsO); + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible + // to TMA + cutlass::arch::NamedBarrier::arrive( + NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); - uint16_t mcast_mask_kv = 0; + Tensor gO = make_tensor(make_gmem_ptr(out_ptr), + Shape, Int>{}, + make_stride(o_head_stride, _1{})); - int n_block = n_block_max - 1; + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - int lane_predicate = cute::elect_one_sync(); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - if (lane_predicate) { - shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); - copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); - } + Tensor cO = make_identity_tensor(Shape, Int>{}); + Tensor tOcO = gmem_thr_copy_O.partition_S(cO); - if (lane_predicate) { - pipeline_k.producer_acquire(smem_pipe_write_k); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), - tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index())); - ++smem_pipe_write_k; - } - - if (lane_predicate) { - #pragma unroll 2 - for (; n_block > 0; --n_block) { - pipeline_k.producer_acquire(smem_pipe_write_k); - copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), - tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index())); - ++smem_pipe_write_k; - pipeline_v.producer_acquire(smem_pipe_write_v); - copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), - tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); - ++smem_pipe_write_v; - } - } - if (lane_predicate) { - pipeline_v.producer_acquire(smem_pipe_write_v); - copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), - tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); - ++smem_pipe_write_v; - } + if (real_seq >= kBlockM) { + copy(gmem_tiled_copy_O, tOsO, tOgO, tOcO); + } else { + copy(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq); } - - template - CUTLASS_DEVICE void - mma(Params const& mainloop_params, - MainloopPipeline pipeline_k, - MainloopPipeline pipeline_v, - PipelineState& smem_pipe_read_k, - PipelineState& smem_pipe_read_v, - FrgTensorO& tOrO, - Softmax& softmax, - const int *mask, - const int n_block_max, - const int thread_idx, - const int m_block, - const int seq_len_q, - const int seq_len_k, - SharedStorage& shared_storage) { - - Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); - Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{}); - - typename Ktraits::TiledMma0 tiled_mma0; - typename Ktraits::TiledMma1 tiled_mma1; - auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx); - auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx); - - Tensor tSrQ = threadMma0.partition_fragment_A(sQ); - Tensor tSrK = threadMma0.partition_fragment_B(sK); - Tensor tOrV = threadMma1.partition_fragment_B(sVt); - - auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - }; - - tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; - - int n_block = n_block_max - 1; - 
- cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(0)); - if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); } - - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - consumer_wait(pipeline_k, smem_pipe_read_k); - gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - warpgroup_wait<0>(); - pipeline_k.consumer_release(smem_pipe_read_k); - ++smem_pipe_read_k; - - int mask_start_idx; - int mask_row_id; - int col_base; - - if constexpr (NeedMask) { - const int lane_id = thread_idx % 32; - mask_start_idx = mask[0] / kBlockN - 1; - - mask_row_id = thread_idx / 32 * 16 + lane_id / 4; - - col_base = thread_idx % 4 * 2; - - app_mask( - tSrS, - mask, - mask_row_id, - col_base + n_block * kBlockN); - } else { - auto col_limit_causal = [&](int row, int n_block) { - return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM; - }; - Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); - Tensor tScS = threadMma0.partition_C(cS); - #pragma unroll - for (int i = 0; i < size(tSrS); ++i) { - if (int(get<1>(tScS(i))) >= - std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) { - tSrS(i) = -INFINITY; - } - } - } - - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - - Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())); - Tensor scores_scale = make_fragment_like(softmax.row_max); - clear(scores_scale); - - #pragma unroll 1 - for (; n_block > 0; --n_block) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); - consumer_wait(pipeline_k, smem_pipe_read_k); - - if constexpr (NeedMask) { - if (n_block >= mask_start_idx) { - app_mask( - tSrS, - mask, - mask_row_id, - col_base + n_block * kBlockN); - } - } - - gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); - softmax.rescale_o(tOrO, scores_scale); - consumer_wait(pipeline_v, smem_pipe_read_v); - gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - warpgroup_wait<1>(); - pipeline_k.consumer_release(smem_pipe_read_k); // release K - cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); - softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V - ++smem_pipe_read_k; - ++smem_pipe_read_v; - cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); - } - - softmax.rescale_o(tOrO, scores_scale); - consumer_wait(pipeline_v, smem_pipe_read_v); - - gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); - cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale); - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); - ++smem_pipe_read_v; - - softmax.rescale_o(tOrO, scores_scale); - return; - } - - template - CUTLASS_DEVICE void - store(Params const& mainloop_params, - FrgTensorO const& tOrO, - SharedStorage& shared_storage, - TiledMma tiled_mma, - int thread_idx, - const int o_head_stride, - const int real_seq, - T * out_ptr) { - - Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); - auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); - - Tensor tOrO_out = convert_type(tOrO); - Tensor 
taccOrO = smem_thr_copy_O.retile_S(tOrO_out); - Tensor taccOsO = smem_thr_copy_O.partition_D(sO); - - cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - - cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); - - Tensor gO = make_tensor(make_gmem_ptr(out_ptr), - Shape, Int>{}, - make_stride(o_head_stride, _1{})); - - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - - Tensor tOsO = gmem_thr_copy_O.partition_S(sO); - Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - - Tensor cO = make_identity_tensor(Shape, Int>{}); - - Tensor tOcO = gmem_thr_copy_O.partition_S(cO); - - if (real_seq >= kBlockM) { - copy(gmem_tiled_copy_O, tOsO, tOgO, tOcO); - } else { - copy(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq); - } - } - + } }; diff --git a/custom_ops/gpu_ops/flash_mask_attn/softmax.hpp b/custom_ops/gpu_ops/flash_mask_attn/softmax.hpp index 5e7fd00b8..386f67d8c 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/softmax.hpp +++ b/custom_ops/gpu_ops/flash_mask_attn/softmax.hpp @@ -1,6 +1,16 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -12,195 +22,245 @@ #include "utils.hpp" - using namespace cute; - -template +template struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ __forceinline__ T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } }; -template<> +template <> struct Allreduce<2> { -template -static __device__ __forceinline__ T run(T x, Operator &op) { + template + static __device__ __forceinline__ T run(T x, Operator &op) { x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; -} + } }; -template -__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); mi++) { - summary(mi) = zero_init ? 
tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - summary(mi) = op(summary(mi), tensor(mi, ni)); - } +template +__device__ __forceinline__ void thread_reduce_( + Tensor const &tensor, + Tensor &summary, + Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); } + } } -template -__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { - CUTE_STATIC_ASSERT_V(size(dst) == size(src)); - #pragma unroll - for (int i = 0; i < size(dst); i++){ - dst(i) = Allreduce<4>::run(src(i), op); - } +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, + Tensor &src, + Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } } -template -__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { - thread_reduce_(tensor, summary, op); - quad_allreduce_(summary, summary, op); +template +__device__ __forceinline__ void reduce_(Tensor const &tensor, + Tensor &summary, + Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); } -template -__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ - MaxOp max_op; - reduce_(tensor, max, max_op); +template +__device__ __forceinline__ void reduce_max( + Tensor const &tensor, Tensor &max) { + MaxOp max_op; + reduce_(tensor, max, max_op); } -template -__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ - SumOp sum_op; - thread_reduce_(tensor, sum, sum_op); - if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } +template +__device__ __forceinline__ void reduce_sum( + Tensor const &tensor, Tensor &sum) { + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { + quad_allreduce_(sum, sum, sum_op); + } } __forceinline__ __device__ __half2 half_exp(__half2 x) { - uint32_t tmp_out, tmp_in; - tmp_in = reinterpret_cast(x); - asm ("ex2.approx.f16x2 %0, %1;\n" - : "=r"(tmp_out) - : "r"(tmp_in)); - __half2 out = reinterpret_cast<__half2&>(tmp_out); - return out; + uint32_t tmp_out, tmp_in; + tmp_in = reinterpret_cast(x); + asm("ex2.approx.f16x2 %0, %1;\n" : "=r"(tmp_out) : "r"(tmp_in)); + __half2 out = reinterpret_cast<__half2 &>(tmp_out); + return out; } // Apply the exp to all the elements. -template -__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - MaxOp max_op; - max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); - #pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - max(mi) = max_op(max(mi), tensor(mi, ni)); - } - max(mi) = Allreduce<4>::run(max(mi), max_op); - const float max_scaled = max(mi) == -INFINITY ? 
0.f : max(mi) * scale; - sum(mi) = 0; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - sum(mi) += tensor(mi, ni); - } +template +__forceinline__ __device__ void max_scale_exp2_sum( + Tensor &tensor, + Tensor &max, + Tensor &sum, + const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); } + max(mi) = Allreduce<4>::run(max(mi), max_op); + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + } } - -template -__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - const float max_scaled = max(mi) * scale; - #pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - } +template +__forceinline__ __device__ void scale_apply_exp2( + Tensor &tensor, + Tensor const &max, + const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const float max_scaled = max(mi) * scale; +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); } + } } - template struct Softmax { + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; - using TensorT = decltype(make_tensor(Shape>{})); - TensorT row_max, row_sum; + CUTLASS_DEVICE Softmax() {}; - CUTLASS_DEVICE Softmax() {}; + template + __forceinline__ __device__ TensorT max(Tensor0 &acc_s, + float softmax_scale_log2) { + Tensor scores = + make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scores_scale; + if constexpr (Is_first) { + reduce_max(scores, row_max); + cute::fill(scores_scale, 1.f); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + reduce_max(scores, row_max); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = row_max(mi); + scores_scale(mi) = + exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale(mi); + } + } + return scores_scale; + }; - template - __forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) { - Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); - static_assert(decltype(size<0>(scores))::value == kNRows); - TensorT scores_scale; - if constexpr (Is_first) { - reduce_max(scores, row_max); - cute::fill(scores_scale, 1.f); 
- } else { - Tensor scores_max_prev = make_fragment_like(row_max); - cute::copy(row_max, scores_max_prev); - reduce_max(scores, row_max); - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - float scores_max_cur = row_max(mi); - scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - row_sum(mi) *= scores_scale(mi); - } - } - return scores_scale; - }; + template + __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, + float softmax_scale_log2) { + Tensor scores = + make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scores_scale; + if constexpr (Is_first) { + reduce_max(scores, row_max); + scale_apply_exp2(scores, row_max, softmax_scale_log2); + reduce_sum(scores, row_sum); + cute::fill(scores_scale, 1.f); + } else { + scale_apply_exp2(scores, row_max, softmax_scale_log2); + reduce_sum(scores, row_sum); + } + return scores_scale; + }; - template - __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) { - Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout())); - static_assert(decltype(size<0>(scores))::value == kNRows); - TensorT scores_scale; - if constexpr (Is_first) { - reduce_max(scores, row_max); - scale_apply_exp2(scores, row_max, softmax_scale_log2); - reduce_sum(scores, row_sum); - cute::fill(scores_scale, 1.f); - } else { - scale_apply_exp2(scores, row_max, softmax_scale_log2); - reduce_sum(scores, row_sum); - } - return scores_scale; - }; - - __forceinline__ __device__ TensorT finalize(float softmax_scale_log2) { - SumOp sum_op; - quad_allreduce_(row_sum, row_sum, sum_op); - TensorT scores_scale; - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - float sum = row_sum(mi); - float inv_sum = 1.0f / sum; - row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); - scores_scale(mi) = inv_sum; - } - return scores_scale; - }; - - template - __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) { - Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout())); - static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); - #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scores_scale(mi); - } - } - }; + __forceinline__ __device__ TensorT finalize(float softmax_scale_log2) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT scores_scale; +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float sum = row_sum(mi); + float inv_sum = 1.0f / sum; + row_sum(mi) = + row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); + scores_scale(mi) = inv_sum; + } + return scores_scale; + }; + template + __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, + TensorT const &scores_scale) { + Tensor acc_o_rowcol = + make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); +#pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale(mi); + } + } + }; }; diff --git a/custom_ops/gpu_ops/flash_mask_attn/utils.hpp b/custom_ops/gpu_ops/flash_mask_attn/utils.hpp index a80022a08..fe98942ce 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/utils.hpp +++ 
b/custom_ops/gpu_ops/flash_mask_attn/utils.hpp @@ -1,13 +1,23 @@ -/****************************************************************************** - * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. - ******************************************************************************/ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include +#include #include #include -#include -#include #include #include @@ -19,8 +29,8 @@ #include #endif -#include #include // For cute::elect_one_sync() +#include #include #include @@ -29,425 +39,468 @@ using namespace cute; -template +template struct PackedHalf; -template<> +template <> struct PackedHalf { - using Type = __half2; + using Type = __half2; }; -template<> +template <> struct PackedHalf { - using Type = nv_bfloat162; + using Type = nv_bfloat162; }; -template +template __forceinline__ __device__ auto float_2_half2(const float x) { - if constexpr (std::is_same::value) { - return __float2half2_rn(x); - } else { - return __float2bfloat162_rn(x); - } + if constexpr (std::is_same::value) { + return __float2half2_rn(x); + } else { + return __float2bfloat162_rn(x); + } } - struct uint16 { - uint4 u; - uint4 v; - uint4 s; - uint4 t; + uint4 u; + uint4 v; + uint4 s; + uint4 t; }; - struct uint8 { - uint4 u; - uint4 v; + uint4 u; + uint4 v; }; -template +template struct BytesToType {}; -template<> +template <> struct BytesToType<64> { - using Type = uint16; - static_assert(sizeof(Type) == 64); + using Type = uint16; + static_assert(sizeof(Type) == 64); }; -template<> +template <> struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); + using Type = uint8; + static_assert(sizeof(Type) == 32); }; -template<> +template <> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); + using Type = uint4; + static_assert(sizeof(Type) == 16); }; -template<> +template <> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); + using Type = uint64_t; + static_assert(sizeof(Type) == 8); }; -template<> +template <> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); + using Type = uint32_t; + static_assert(sizeof(Type) == 4); }; -template<> +template <> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); + using Type = uint16_t; + static_assert(sizeof(Type) == 2); }; -template<> +template <> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); + using Type = uint8_t; + static_assert(sizeof(Type) == 1); }; -template +template struct Vec { + enum { BYTES = NUM_ELT * sizeof(Elt_type) }; - enum { BYTES = NUM_ELT * sizeof(Elt_type) }; + using Vec_type = typename BytesToType::Type; - using Vec_type = typename BytesToType::Type; + using Alias_type = union { + Vec_type vec; + Elt_type elt[NUM_ELT]; + }; - using Alias_type = union { - Vec_type vec; - Elt_type 
elt[NUM_ELT]; - }; + Alias_type data; - Alias_type data; + inline __device__ Vec() {} - inline __device__ Vec() {} - - template - inline __device__ void to(Vec &other) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - other.data.elt[it] = S(this->data.elt[it]); - } + template + inline __device__ void to(Vec &other) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + other.data.elt[it] = S(this->data.elt[it]); } + } - template - inline __device__ void assign(const Op &op) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = op(it); - } + template + inline __device__ void assign(const Op &op) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] = op(it); } + } - inline __device__ void load_from(const void *base_ptr) { - this->data.vec = *reinterpret_cast(base_ptr); + inline __device__ void load_from(const void *base_ptr) { + this->data.vec = *reinterpret_cast(base_ptr); + } + + inline __device__ void store_to(void *base_ptr) { + *reinterpret_cast(base_ptr) = this->data.vec; + } + + inline __device__ void add(const Vec &other) { + static_assert(NUM_ELT % 2 == 0); + using type = typename PackedHalf::Type; +#pragma unroll + for (int it = 0; it < NUM_ELT / 2; it++) { + type b = *reinterpret_cast(other.data.elt + it * 2); + *reinterpret_cast(this->data.elt + it * 2) += b; } + } - - inline __device__ void store_to(void *base_ptr) { - *reinterpret_cast(base_ptr) = this->data.vec; + inline __device__ void fma(const Vec &scale, + const Vec &bias) { + static_assert(NUM_ELT % 2 == 0); + using type = typename PackedHalf::Type; +#pragma unroll + for (int it = 0; it < NUM_ELT / 2; it++) { + type a = *reinterpret_cast(scale.data.elt + it * 2); + type b = *reinterpret_cast(bias.data.elt + it * 2); + *reinterpret_cast(this->data.elt + it * 2) += a * b; } + } - inline __device__ void add(const Vec &other) { - static_assert(NUM_ELT % 2 == 0); - using type = typename PackedHalf::Type; - #pragma unroll - for (int it = 0; it < NUM_ELT / 2; it++) { - type b = *reinterpret_cast(other.data.elt + it * 2); - *reinterpret_cast(this->data.elt + it * 2) += b; - } - } - - inline __device__ void fma(const Vec &scale, const Vec &bias) { - static_assert(NUM_ELT % 2 == 0); - using type = typename PackedHalf::Type; - #pragma unroll - for (int it = 0; it < NUM_ELT / 2; it++) { - type a = *reinterpret_cast(scale.data.elt + it * 2); - type b = *reinterpret_cast(bias.data.elt + it * 2); - *reinterpret_cast(this->data.elt + it * 2) += a * b; - } - } - - inline __device__ void set_zero() { - constexpr int size = sizeof(Vec_type) / sizeof(int); - #pragma unroll - for (int i = 0; i < size; ++i) { - (reinterpret_cast(this->data.elt))[i] = 0; - } + inline __device__ void set_zero() { + constexpr int size = sizeof(Vec_type) / sizeof(int); +#pragma unroll + for (int i = 0; i < size; ++i) { + (reinterpret_cast(this->data.elt))[i] = 0; } + } }; -template -inline __device__ void apply_rotary_embedding(Vec& vec, Vec& cos, Vec& sin) { - static_assert(PackSize % 2 == 0); - #pragma unroll - for (int i = 0; i < PackSize / 2; i++) { - const float cos_inv_freq = cos.data.elt[i]; - const float sin_inv_freq = sin.data.elt[i]; - const float v1 = static_cast(vec.data.elt[2 * i]); - const float v2 = static_cast(vec.data.elt[2 * i + 1]); - vec.data.elt[2 * i] = static_cast(cos_inv_freq * v1 - sin_inv_freq * v2); - vec.data.elt[2 * i + 1] = static_cast(sin_inv_freq * v1 + cos_inv_freq * v2); - } +template +inline __device__ void apply_rotary_embedding(Vec &vec, + Vec &cos, + 
Vec &sin) { + static_assert(PackSize % 2 == 0); +#pragma unroll + for (int i = 0; i < PackSize / 2; i++) { + const float cos_inv_freq = cos.data.elt[i]; + const float sin_inv_freq = sin.data.elt[i]; + const float v1 = static_cast(vec.data.elt[2 * i]); + const float v2 = static_cast(vec.data.elt[2 * i + 1]); + vec.data.elt[2 * i] = static_cast(cos_inv_freq * v1 - sin_inv_freq * v2); + vec.data.elt[2 * i + 1] = + static_cast(sin_inv_freq * v1 + cos_inv_freq * v2); + } } template -__forceinline__ __device__ void app_mask( - Tensor &tSrS, - const int *mask, - const int &mask_row_id, - const int &col_base) { - const float mask_value = -1000000.0f; - for (int i = 0; i < size(tSrS); i+=8) { - const int col = i * 2 + col_base; - if (col >= mask[mask_row_id]) { - tSrS(i) = mask_value; - } - if (col + 1 >= mask[mask_row_id]) { - tSrS(i + 1) = mask_value; - } - if (col >= mask[mask_row_id + 8]) { - tSrS(i + 2) = mask_value; - } - if (col + 1 >= mask[mask_row_id + 8]) { - tSrS(i + 3) = mask_value; - } - if (col + 8 >= mask[mask_row_id]) { - tSrS(i + 4) = mask_value; - } - if (col + 9 >= mask[mask_row_id]) { - tSrS(i + 5) = mask_value; - } - if (col + 8 >= mask[mask_row_id + 8]) { - tSrS(i + 6) = mask_value; - } - if (col + 9 >= mask[mask_row_id + 8]) { - tSrS(i + 7) = mask_value; - } +__forceinline__ __device__ void app_mask(Tensor &tSrS, + const int *mask, + const int &mask_row_id, + const int &col_base) { + const float mask_value = -1000000.0f; + for (int i = 0; i < size(tSrS); i += 8) { + const int col = i * 2 + col_base; + if (col >= mask[mask_row_id]) { + tSrS(i) = mask_value; } + if (col + 1 >= mask[mask_row_id]) { + tSrS(i + 1) = mask_value; + } + if (col >= mask[mask_row_id + 8]) { + tSrS(i + 2) = mask_value; + } + if (col + 1 >= mask[mask_row_id + 8]) { + tSrS(i + 3) = mask_value; + } + if (col + 8 >= mask[mask_row_id]) { + tSrS(i + 4) = mask_value; + } + if (col + 9 >= mask[mask_row_id]) { + tSrS(i + 5) = mask_value; + } + if (col + 8 >= mask[mask_row_id + 8]) { + tSrS(i + 6) = mask_value; + } + if (col + 9 >= mask[mask_row_id + 8]) { + tSrS(i + 7) = mask_value; + } + } } -template +template struct HalfMax; -template<> +template <> struct HalfMax { - inline __device__ __half2 operator()(const __half2 x, const __half2 y) { - __half2 res; - asm volatile("max.f16x2 %0, %1, %2;\n" : - "=r"(*reinterpret_cast(&res)) : - "r"(*reinterpret_cast(&x)), - "r"(*reinterpret_cast(&y))); - return res; - } + inline __device__ __half2 operator()(const __half2 x, const __half2 y) { + __half2 res; + asm volatile("max.f16x2 %0, %1, %2;\n" + : "=r"(*reinterpret_cast(&res)) + : "r"(*reinterpret_cast(&x)), + "r"(*reinterpret_cast(&y))); + return res; + } }; -template<> +template <> struct HalfMax { - inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) { - nv_bfloat162 res; - asm volatile("max.bf16x2 %0, %1, %2;\n" : - "=r"(*reinterpret_cast(&res)) : - "r"(*reinterpret_cast(&x)), - "r"(*reinterpret_cast(&y))); - return res; - } + inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, + const nv_bfloat162 y) { + nv_bfloat162 res; + asm volatile("max.bf16x2 %0, %1, %2;\n" + : "=r"(*reinterpret_cast(&res)) + : "r"(*reinterpret_cast(&x)), + "r"(*reinterpret_cast(&y))); + return res; + } }; -template +template struct HalfMin; -template<> +template <> struct HalfMin { - inline __device__ __half2 operator()(const __half2 x, const __half2 y) { - __half2 res; - asm volatile("min.f16x2 %0, %1, %2;\n" : - "=r"(*reinterpret_cast(&res)) : - "r"(*reinterpret_cast(&x)), - 
"r"(*reinterpret_cast(&y))); - return res; - } + inline __device__ __half2 operator()(const __half2 x, const __half2 y) { + __half2 res; + asm volatile("min.f16x2 %0, %1, %2;\n" + : "=r"(*reinterpret_cast(&res)) + : "r"(*reinterpret_cast(&x)), + "r"(*reinterpret_cast(&y))); + return res; + } }; -template<> +template <> struct HalfMin { - inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) { - nv_bfloat162 res; - asm volatile("min.bf16x2 %0, %1, %2;\n" : - "=r"(*reinterpret_cast(&res)) : - "r"(*reinterpret_cast(&x)), - "r"(*reinterpret_cast(&y))); - return res; - } + inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, + const nv_bfloat162 y) { + nv_bfloat162 res; + asm volatile("min.bf16x2 %0, %1, %2;\n" + : "=r"(*reinterpret_cast(&res)) + : "r"(*reinterpret_cast(&x)), + "r"(*reinterpret_cast(&y))); + return res; + } }; -template +template __forceinline__ __device__ void copy( - TiledCopy tiled_copy, Tensor const &S, - Tensor &D, - Tensor const &identity_MN, - const int max_MN = 0) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); - } - } + TiledCopy tiled_copy, + Tensor const &S, + Tensor &D, + Tensor const &identity_MN, + const int max_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } } + } } template inline __device__ auto convert_type(Tensor const &tensor) { - using From_type = typename Engine::value_type; - constexpr int numel = decltype(size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - auto frag = convert_op(*reinterpret_cast *>(tensor.data())); - return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast *>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); } -template +template __inline__ __device__ T BlockAllReduce(T val) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ T result_broadcast; - T result = BlockReduce(temp_storage).Reduce(val, ReductionOp()); - if (threadIdx.x == 0) { result_broadcast = result; } - __syncthreads(); - return result_broadcast; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ T result_broadcast; + T result = BlockReduce(temp_storage).Reduce(val, ReductionOp()); + if (threadIdx.x == 0) { + result_broadcast = result; + } + __syncthreads(); + return result_broadcast; } -template +template __inline__ __device__ T BlockScanSum(T val) { - typedef cub::BlockScan 
BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; - int aggregate; - BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate); - __syncthreads(); - return val; + int aggregate; + BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate); + __syncthreads(); + return val; } - - -template +template struct MaxOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } + __device__ __forceinline__ T operator()(T const &x, T const &y) { + return x > y ? x : y; + } }; template <> struct MaxOp { -// This is slightly faster -__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } + // This is slightly faster + __device__ __forceinline__ float operator()(float const &x, float const &y) { + return max(x, y); + } }; -template +template struct MinOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; } + __device__ __forceinline__ T operator()(T const &x, T const &y) { + return x < y ? x : y; + } }; template <> struct MinOp { -// This is slightly faster -__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); } + // This is slightly faster + __device__ __forceinline__ float operator()(float const &x, float const &y) { + return min(x, y); + } }; - -template +template struct SumOp { -__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } + __device__ __forceinline__ T operator()(T const &x, T const &y) { + return x + y; + } }; -template +template __forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { - using X = Underscore; - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); - auto l = logical_divide(get<0>(acc_layout), Shape{}); // (2, 2, (2, N / 16))) - return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout))); - } else { // SM80 - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); - static_assert(mma_shape_K == 8 || mma_shape_K == 16); - if constexpr (mma_shape_K == 8) { - return acc_layout; - } else { - auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) - return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); - } - } -}; - -template -__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { - constexpr bool Is_RS = !cute::is_base_of::value; - // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const - if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } - warpgroup_fence_operand(tCrC); - if constexpr (arrive) { - warpgroup_arrive(); - } - if constexpr (zero_init) { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); - tiled_mma.accumulate_ = 
GMMA::ScaleOut::One; - } + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + auto l = logical_divide(get<0>(acc_layout), + Shape{}); // (2, 2, (2, N / 16))) + return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), + get<1>(acc_layout), + make_layout(get<2, 1>(l), get<2>(acc_layout))); + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; } else { - // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - } - if constexpr (commit) { - warpgroup_commit_batch(); - } - if constexpr (wg_wait >= 0) { warpgroup_wait(); } - warpgroup_fence_operand(tCrC); - if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } -} - - -template -__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { - if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 - static_assert(decltype(size<0, 0>(acc_layout))::value == 2); - static_assert(decltype(size<0, 1>(acc_layout))::value == 2); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = acc_layout; - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); - } else { // SM80 - static_assert(decltype(size<0>(acc_layout))::value == 4); - static_assert(decltype(rank(acc_layout))::value == 3); - auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + auto l = logical_divide( + acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout( + make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); } + } }; -template -__inline__ __device__ T WarpAllReduce(T val) { - ReductionOp op; - #pragma unroll - for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { - val = op(val, __shfl_xor_sync(0xffffffff, val, mask)); +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, + Tensor0 const &tCrA, + Tensor1 const &tCrB, + Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take + // const + if constexpr (Is_RS) { + warpgroup_fence_operand(const_cast(tCrA)); + } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; } - return val; + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + 
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+      cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
+      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+    }
+  }
+  if constexpr (commit) {
+    warpgroup_commit_batch();
+  }
+  if constexpr (wg_wait >= 0) {
+    warpgroup_wait<wg_wait>();
+  }
+  warpgroup_fence_operand(tCrC);
+  if constexpr (Is_RS) {
+    warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
+  }
+}
+
+template <typename Layout>
+__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
+  if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
+    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
+    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    auto l = acc_layout;
+    return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
+                       make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
+  } else {  // SM80
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
+    return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
+                       make_layout(get<0, 0>(l), get<2>(l)));
+  }
+};
+
+template <typename T, typename ReductionOp, int thread_group_width = 32>
+__inline__ __device__ T WarpAllReduce(T val) {
+  ReductionOp op;
+#pragma unroll
+  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
+    val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
+  }
+  return val;
+}
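For reference, the Softmax bookkeeping used in the mainloop above (row_max, row_sum, rescale_o, finalize) is the standard online-softmax recurrence evaluated in the log2 domain: scores are scaled once by softmax_scale_log2, exponentiated with exp2f, and whenever the running row maximum grows, both the running denominator and the unnormalized P*V accumulator are corrected by the same factor. Below is a minimal host-side sketch of that recurrence, assuming one scalar value per score; the names online_softmax_reference and kLog2e, and the scale computation inside, are illustrative assumptions and not part of this patch.

// Host-side sketch of the online-softmax recurrence in the log2 domain.
// Assumes scores.size() == values.size(); names are illustrative only.
#include <algorithm>
#include <cmath>
#include <vector>

inline float online_softmax_reference(const std::vector<float>& scores,
                                      const std::vector<float>& values,
                                      int head_dim) {
  const float kLog2e = 1.4426950408889634074f;                       // log2(e), assumed
  const float scale_log2 = kLog2e / std::sqrt(static_cast<float>(head_dim));

  float row_max = -INFINITY;  // running max, cf. Softmax::row_max
  float row_sum = 0.0f;       // running denominator, cf. Softmax::row_sum
  float acc = 0.0f;           // unnormalized P*V accumulator, cf. tOrO

  for (size_t i = 0; i < scores.size(); ++i) {
    const float new_max = std::max(row_max, scores[i]);
    // When the max grows, rescale everything accumulated so far (rescale_o).
    const float correction = std::exp2((row_max - new_max) * scale_log2);
    row_sum *= correction;
    acc *= correction;
    const float p = std::exp2((scores[i] - new_max) * scale_log2);
    row_sum += p;
    acc += p * values[i];
    row_max = new_max;
  }
  return acc / row_sum;  // finalize(): fold in 1 / row_sum
}

The device code amortizes the same correction: rescale_o is applied once per kBlockN tile between the two warpgroup GEMMs rather than once per element, and finalize folds the final 1/row_sum into scores_scale before the epilogue writes the output tile.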