format flash_mask_attn

Author: lizhenyun01
Date:   2025-11-18 11:45:29 +08:00
Parent: d0b3bec585
Commit: cd2c4df64a
6 changed files with 1621 additions and 1239 deletions


@@ -1,7 +1,3 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
#include "kernel_traits.h"
#include "flash_mask_attn_kernel.hpp"
#include "kernel_traits.h"
#include "paddle/extension.h"
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -29,139 +25,165 @@ struct cuteType;
template <>
struct cuteType<phi::dtype::float16> {
using type = cutlass::half_t;
using type = cutlass::half_t;
};
template <>
struct cuteType<phi::dtype::bfloat16> {
using type = cutlass::bfloat16_t;
using type = cutlass::bfloat16_t;
};
template <typename T>
std::vector<paddle::Tensor> DispatchFlashAttentionMask(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::optional<paddle::Tensor>& mask,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& attn_out,
const paddle::optional<paddle::Tensor>& mask,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
const int batch_size = seq_len_encoder.dims()[0];
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
const int batch_size = cu_seq_q.dims()[0];
Flash_mask_params params;
memset(&params, 0, sizeof(Flash_mask_params));
paddle::Tensor out = paddle::empty(
{q_input.dims()[0], head_num * head_dim}, q_input.dtype(), q_input.place());
params.q_ptr = const_cast<T*>(q_input.data<T>());
params.k_ptr = const_cast<T*>(k_input.data<T>());
params.v_ptr = const_cast<T*>(v_input.data<T>());
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
params.head_num = head_num;
params.kv_head_num = kv_head_num;
params.max_seq_len_q = max_enc_len_this_time;
params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time;
params.batch_size = batch_size;
params.gqa_group_size = head_num / kv_head_num;
constexpr float kLog2e = 1.4426950408889634074;
params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;
Flash_mask_params params;
memset(&params, 0, sizeof(Flash_mask_params));
using cute_type = typename cuteType<T>::type;
params.q_ptr = const_cast<T*>(q_input.data<T>());
params.k_ptr = const_cast<T*>(k_input.data<T>());
params.v_ptr = const_cast<T*>(v_input.data<T>());
params.o_ptr = const_cast<T*>(out.data<T>());
params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
params.head_num = head_num;
params.kv_head_num = kv_head_num;
params.max_seq_len_q = max_enc_len_this_time;
params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time;
params.batch_size = batch_size;
params.gqa_group_size = head_num / kv_head_num;
constexpr float kLog2e = 1.4426950408889634074;
params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;
using cute_type = typename cuteType<T>::type;
if (mask) {
params.mask = const_cast<int*>(mask.get().data<int>());
flash_attn_headdim128<kBlockM, kBlockN, true, cute_type>(params, 0);
} else {
flash_attn_headdim128<kBlockM, kBlockN, false, cute_type>(params, 0);
if (mask) {
params.mask = const_cast<int*>(mask.get().data<int>());
if (attn_out.dtype() == paddle::DataType::FLOAT16) {
using out_type = phi::dtype::float16;
params.o_ptr = const_cast<out_type*>(attn_out.data<out_type>());
flash_attn_headdim128<kBlockM,
kBlockN,
true,
cute_type,
typename cuteType<out_type>::type>(
params, q_input.stream());
} else if (attn_out.dtype() == paddle::DataType::BFLOAT16) {
using out_type = phi::dtype::bfloat16;
params.o_ptr = const_cast<out_type*>(attn_out.data<out_type>());
flash_attn_headdim128<kBlockM,
kBlockN,
true,
cute_type,
typename cuteType<out_type>::type>(
params, q_input.stream());
}
} else {
if (attn_out.dtype() == paddle::DataType::FLOAT16) {
using out_type = phi::dtype::float16;
params.o_ptr = const_cast<out_type*>(attn_out.data<out_type>());
flash_attn_headdim128<kBlockM,
kBlockN,
false,
cute_type,
typename cuteType<out_type>::type>(
params, q_input.stream());
} else if (attn_out.dtype() == paddle::DataType::BFLOAT16) {
using out_type = phi::dtype::bfloat16;
params.o_ptr = const_cast<out_type*>(attn_out.data<out_type>());
flash_attn_headdim128<kBlockM,
kBlockN,
false,
cute_type,
typename cuteType<out_type>::type>(
params, q_input.stream());
}
}
return {out};
// cudaDeviceSynchronize();
// auto err = cudaGetLastError();
// printf("mask attn err = %d, str = %s\n", err, cudaGetErrorString(err));
}
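Note on the softmax scale above: folding log2(e) into 1/sqrt(head_dim) lets the kernel exponentiate scores with exp2 instead of exp. A minimal standalone check of that identity (toy score value, head_dim = 128 as in the dispatcher):

```cpp
#include <cassert>
#include <cmath>

// exp(x * scale) == exp2(x * scale * log2(e)); the dispatcher stores
// scale_softmax_log2 = (1 / sqrt(head_dim)) * log2(e) so scores can be
// exponentiated with exp2 inside the softmax.
int main() {
  const int head_dim = 128;
  const float kLog2e = 1.4426950408889634074f;
  const float scale_softmax_log2 = 1.0f / std::sqrt((float)head_dim) * kLog2e;

  const float qk = 3.7f;  // an arbitrary raw attention score
  const float via_exp  = std::exp(qk / std::sqrt((float)head_dim));
  const float via_exp2 = std::exp2(qk * scale_softmax_log2);
  assert(std::fabs(via_exp - via_exp2) < 1e-5f);
  return 0;
}
```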
std::vector<paddle::Tensor> FlashAttentionMask(
const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::optional<paddle::Tensor> &mask,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
if (q_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
return std::move(
DispatchFlashAttentionMask<T>(
q_input,
k_input,
v_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
mask,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time));
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
return std::move(
DispatchFlashAttentionMask<T>(
q_input,
k_input,
v_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
mask,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time));
}
void FlashAttentionMask(const paddle::Tensor& q_input,
const paddle::Tensor& k_input,
const paddle::Tensor& v_input,
const paddle::Tensor& cu_seq_q,
const paddle::Tensor& cu_seq_k,
const paddle::Tensor& seq_len_encoder,
const paddle::Tensor& attn_out,
const paddle::optional<paddle::Tensor>& mask,
const int head_num,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
if (q_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
DispatchFlashAttentionMask<T>(q_input,
k_input,
v_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
attn_out,
mask,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time);
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
DispatchFlashAttentionMask<T>(q_input,
k_input,
v_input,
cu_seq_q,
cu_seq_k,
seq_len_encoder,
attn_out,
mask,
head_num,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time);
}
}
PD_BUILD_STATIC_OP(flash_attention_mask)
.Inputs({
"q_input",
"k_input",
"v_input",
"cu_seq_q",
"cu_seq_k",
"seq_len_encoder",
paddle::Optional("mask")})
.Attrs({
"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_seq_len: int",
"max_enc_len_this_time: int",
"max_dec_len_this_time: int"})
.Outputs({
"out"})
.Inputs({"q_input",
"k_input",
"v_input",
"cu_seq_q",
"cu_seq_k",
"seq_len_encoder",
"attn_out",
paddle::Optional("mask")})
.Attrs({"head_num: int",
"kv_head_num: int",
"head_dim: int",
"max_seq_len: int",
"max_enc_len_this_time: int",
"max_dec_len_this_time: int"})
.Outputs({"out"})
.SetInplaceMap({{"attn_out", "out"}})
.SetKernelFn(PD_KERNEL(FlashAttentionMask));
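With the in-place registration above, the op no longer allocates its output: the caller provides attn_out, which is aliased to "out" via SetInplaceMap, and its dtype (FP16 or BF16) selects the OutputType template argument independently of q_input. A hypothetical caller-side fragment, assuming the declarations in this file and the input tensors are already in scope:

```cpp
// Hypothetical usage sketch: allocate the output buffer up front and let the
// op write into it through the {"attn_out", "out"} inplace mapping.
paddle::Tensor attn_out = paddle::empty(
    {q_input.dims()[0], head_num * head_dim},
    paddle::DataType::BFLOAT16,   // output dtype may differ from q_input's
    q_input.place());
FlashAttentionMask(q_input, k_input, v_input, cu_seq_q, cu_seq_k,
                   seq_len_encoder, attn_out, mask, head_num, kv_head_num,
                   head_dim, max_seq_len, max_enc_len_this_time,
                   max_dec_len_this_time);
```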


@@ -1,18 +1,20 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cluster_launch.hpp"
#include "kernel_traits.h"
#include "mainloop_attn.hpp"
@@ -22,210 +24,247 @@ using namespace cute;
template <int kHeadDim>
auto get_gmem_layout(int token_num, int head_num) {
return make_layout(
make_shape(token_num, kHeadDim, head_num),
make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim));
return make_layout(make_shape(token_num, kHeadDim, head_num),
make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim));
}
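get_gmem_layout describes the packed [token, head_dim, head] view with strides (head_num * kHeadDim, 1, kHeadDim). A plain-C++ sketch of the address arithmetic those strides encode (example indices only):

```cpp
#include <cstdio>

// Offset of element (token t, channel d, head h) under strides
// (head_num * kHeadDim, 1, kHeadDim), matching get_gmem_layout above.
long linear_offset(int t, int d, int h, int head_num, int kHeadDim) {
  return (long)t * head_num * kHeadDim + (long)h * kHeadDim + d;
}

int main() {
  const int head_num = 8, kHeadDim = 128;
  // Token 2, head 3, channel 5 lives 2*8*128 + 3*128 + 5 = 2437 elements in.
  std::printf("%ld\n", linear_offset(2, 5, 3, head_num, kHeadDim));
  return 0;
}
```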
template <typename Ktraits>
__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
__global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
1)
compute_attn_ws(
CUTE_GRID_CONSTANT typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
CUTE_GRID_CONSTANT
typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
CUTE_GRID_CONSTANT Flash_mask_params const data_params) {
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using output_type = typename Ktraits::output_type;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
constexpr int kHeadDim = Ktraits::kHeadDim;
constexpr bool NeedMask = Ktraits::NeedMask;
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
constexpr int kHeadDim = Ktraits::kHeadDim;
constexpr bool NeedMask = Ktraits::NeedMask;
using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;
using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto &shared_storage =
*reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
__align__(16) __shared__ int mask[kBlockM];
__align__(16) __shared__ int mask[kBlockM];
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
if constexpr (NeedMask) {
const int *mask_this_batch =
data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;
if constexpr (NeedMask) {
const int *mask_this_batch = data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;
for (int i = threadIdx.x; i < kBlockM; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
mask[i] = mask_this_batch[i];
}
for (int i = threadIdx.x; i < kBlockM;
i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
mask[i] = mask_this_batch[i];
}
}
const int seq_len_q = data_params.seq_len_encoder[bidb];
const int seq_len_k = data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];
const int seq_len_q = data_params.seq_len_encoder[bidb];
const int seq_len_k =
data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];
if (m_block * kBlockM >= seq_len_q) {
return;
if (m_block * kBlockM >= seq_len_q) {
return;
}
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
int const warp_group_thread_idx =
threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1);
}
MainloopPipeline pipeline_k(
shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(
shared_storage.pipeline_v, pipeline_params, ClusterShape{});
__syncthreads();
CollectiveMainloop collective_mainloop;
const int real_seq = seq_len_q - m_block * kBlockM;
const int n_block_max =
NeedMask
? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN)
: min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q,
kBlockN),
cute::ceil_div(seq_len_k, kBlockN));
;
if (warp_group_idx == 0) { // Producer
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();
int warp_idx_in_warpgroup =
__shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
PipelineState smem_pipe_write_k =
cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v =
cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load(mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_write_k,
smem_pipe_write_v,
shared_storage,
n_block_max,
m_block,
bidh,
bidb,
data_params.cu_seq_q,
data_params.cu_seq_k,
seq_len_q,
seq_len_k);
}
} else { // Consumer
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
typename Ktraits::TiledMma1 tiled_mma1;
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
collective_mainloop.mma_init();
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
}
PipelineState smem_pipe_read_k, smem_pipe_read_v;
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
Tensor tOrO =
partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
PipelineParams pipeline_params;
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0
? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NumMmaThreads;
collective_mainloop.mma(mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_read_k,
smem_pipe_read_v,
tOrO,
softmax,
mask,
n_block_max,
threadIdx.x - NumCopyThreads,
m_block,
seq_len_q,
seq_len_k,
shared_storage);
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1);
}
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{});
__syncthreads();
CollectiveMainloop collective_mainloop;
const int real_seq = seq_len_q - m_block * kBlockM;
const int n_block_max = NeedMask ? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN) : cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN);
if (warp_group_idx == 0) { // Producer
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (warp_idx_in_warpgroup == 0) { // Load Q, K, V
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
collective_mainloop.load(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_write_k,
smem_pipe_write_v,
shared_storage,
n_block_max,
m_block,
bidh,
bidb,
data_params.cu_seq_q,
data_params.cu_seq_k,
seq_len_q,
seq_len_k);
}
} else { // Consumer
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
typename Ktraits::TiledMma1 tiled_mma1;
PipelineState smem_pipe_read_k, smem_pipe_read_v;
Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;
collective_mainloop.mma(
mainloop_params,
pipeline_k,
pipeline_v,
smem_pipe_read_k,
smem_pipe_read_v,
tOrO,
softmax,
mask,
n_block_max,
threadIdx.x - NumCopyThreads,
m_block,
seq_len_q,
seq_len_k,
shared_storage);
const int o_head_stride = data_params.head_num * kHeadDim;
const int store_offset = (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride + bidh * kHeadDim;
collective_mainloop.store<NumMmaThreads>(
mainloop_params,
tOrO,
shared_storage,
tiled_mma1,
threadIdx.x - NumCopyThreads,
o_head_stride,
real_seq,
reinterpret_cast<Element*>(data_params.o_ptr) + store_offset);
}
const int o_head_stride = data_params.head_num * kHeadDim;
const int store_offset =
(data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride +
bidh * kHeadDim;
collective_mainloop.store<NumMmaThreads>(
mainloop_params,
tOrO,
shared_storage,
tiled_mma1,
threadIdx.x - NumCopyThreads,
o_head_stride,
real_seq,
reinterpret_cast<output_type *>(data_params.o_ptr) + store_offset);
}
}
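The n_block_max computed in this kernel is either mask-driven (last visible column of the tile's last valid row) or causal (row limit clamped to seq_len_k). A standalone sketch of both formulas with illustrative sequence lengths and an assumed mask value:

```cpp
#include <algorithm>
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int kBlockM = 128, kBlockN = 128;
  const int seq_len_q = 300, seq_len_k = 1324;  // e.g. 1024 cached + 300 new
  const int m_block = 2;                        // query rows 256..299 are valid

  // Causal path (no explicit mask): rows of this tile may attend up to
  // (m_block + 1) * kBlockM + seq_len_k - seq_len_q columns, clamped to seq_len_k.
  const int n_block_max_causal =
      std::min(ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN),
               ceil_div(seq_len_k, kBlockN));

  // Mask path: the per-row mask gives the last visible column of the tile's
  // last valid row; the value here is assumed for illustration.
  const int real_seq = seq_len_q - m_block * kBlockM;  // 44 valid rows
  const int mask_last_row = 1200;                      // assumed mask value
  const int n_block_max_masked = ceil_div(mask_last_row, kBlockN);

  std::printf("causal: %d tiles, masked: %d tiles (real_seq=%d)\n",
              n_block_max_causal, n_block_max_masked, real_seq);
  return 0;
}
```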
template<typename Kernel_traits>
template <typename Kernel_traits>
void run_flash_mask(Flash_mask_params &params, cudaStream_t stream) {
using Element = typename Kernel_traits::Element;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
using Element = typename Kernel_traits::Element;
using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments({
static_cast<Element const*>(params.q_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_q, params.head_num),
static_cast<Element const*>(params.k_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
static_cast<Element const*>(params.v_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
params.scale_softmax_log2
});
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments(
{static_cast<Element const *>(params.q_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_q * params.batch_size,
params.head_num),
static_cast<Element const *>(params.k_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k * params.batch_size,
params.kv_head_num),
static_cast<Element const *>(params.v_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k * params.batch_size,
params.kv_head_num),
params.scale_softmax_log2});
int num_blocks_m = cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);
int num_blocks_m =
cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) *
size<0>(ClusterShape{});
void *kernel;
kernel = (void *)compute_attn_ws<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
void *kernel;
kernel = (void *)compute_attn_ws<Kernel_traits>;
int smem_size = sizeof(typename Kernel_traits::SharedStorage);
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
dim3 grid_dims;
grid_dims.x = num_blocks_m;
grid_dims.y = params.head_num;
grid_dims.z = params.batch_size;
dim3 grid_dims;
grid_dims.x = num_blocks_m;
grid_dims.y = params.head_num;
grid_dims.z = params.batch_size;
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, params);
static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
dim3 block_dims(ctaSize);
dim3 cluster_dims(size<0>(ClusterShape{}),
size<1>(ClusterShape{}),
size<2>(ClusterShape{}));
cutlass::ClusterLaunchParams launch_params{
grid_dims, block_dims, cluster_dims, smem_size, stream};
cutlass::launch_kernel_on_cluster(
launch_params, kernel, mainloop_params, params);
}
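num_blocks_m is rounded up to a multiple of the cluster's M extent before launch; with the 1x1x1 ClusterShape in these traits that rounding is a no-op, but the formula generalizes. A small sketch of the grid-size arithmetic (assumed max_seq_len_q):

```cpp
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  const int kBlockM = 128;
  const int max_seq_len_q = 1000;                  // assumed value
  int num_blocks_m = ceil_div(max_seq_len_q, kBlockM);  // 8

  const int cluster_m = 1;  // size<0>(ClusterShape{}) for these traits
  num_blocks_m = ceil_div(num_blocks_m, cluster_m) * cluster_m;  // still 8

  // With a hypothetical 2-wide cluster, 7 blocks would round up to 8.
  std::printf("grid.x = %d\n", num_blocks_m);
  return 0;
}
```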
template <int kBlockM, int kBlockN, bool NeedMask, typename InputType>
template <int kBlockM,
int kBlockN,
bool NeedMask,
typename InputType,
typename OutputType>
void flash_attn_headdim128(Flash_mask_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
constexpr static int kNWarps = kBlockM / 16 + 4;
constexpr static int kStages = 2;
constexpr static int Headdim = 128;
constexpr static int kNWarps = kBlockM / 16 + 4;
constexpr static int kStages = 2;
using Ktraits = Flash_mask_kernel_traits<Headdim, kBlockM, kBlockN, kNWarps, kStages, NeedMask, InputType>;
run_flash_mask<Ktraits>(params, stream);
using Ktraits = Flash_mask_kernel_traits<Headdim,
kBlockM,
kBlockN,
kNWarps,
kStages,
NeedMask,
InputType,
OutputType>;
run_flash_mask<Ktraits>(params, stream);
}
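The kNWarps = kBlockM / 16 + 4 rule above, combined with the one-warp-group producer split in the kernel traits, fixes the thread budget per CTA. A quick check for the kBlockM = 128 instantiation used by the dispatcher:

```cpp
#include <cstdio>

int main() {
  const int kBlockM = 128;
  const int kNWarps = kBlockM / 16 + 4;          // 12 warps
  const int kNThreads = kNWarps * 32;            // 384 threads per CTA
  const int NumProducerThreads = 128;            // one TMA producer warp group
  const int NumMmaThreads = kNThreads - NumProducerThreads;  // 256 MMA threads

  std::printf("warps=%d threads=%d producers=%d consumers=%d\n",
              kNWarps, kNThreads, NumProducerThreads, NumMmaThreads);
  return 0;
}
```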


@@ -1,6 +1,16 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
@@ -15,110 +25,155 @@
using namespace cute;
struct Flash_mask_params {
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void * __restrict__ o_ptr;
int * __restrict__ cu_seq_q;
int * __restrict__ cu_seq_k;
int * __restrict__ mask;
int * seq_len_encoder;
int head_num;
int kv_head_num;
int max_seq_len_q;
int max_seq_len_k;
int batch_size;
int gqa_group_size;
float scale_softmax_log2;
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void *__restrict__ o_ptr;
int *__restrict__ cu_seq_q;
int *__restrict__ cu_seq_k;
int *__restrict__ mask;
int *seq_len_encoder;
int head_num;
int kv_head_num;
int max_seq_len_q;
int max_seq_len_k;
int batch_size;
int gqa_group_size;
float scale_softmax_log2;
};
template <int kStages, class Gemm1Type, class Gemm2Type, class OutputType, class SmemLayoutQ,
class SmemLayoutK, class SmemLayoutV, class SmemLayoutO>
template <int kStages,
class Gemm1Type,
class Gemm2Type,
class OutputType,
class SmemLayoutQ,
class SmemLayoutK,
class SmemLayoutV,
class SmemLayoutO>
struct SharedStorageQKVO {
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
union {
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
cutlass::arch::ClusterTransactionBarrier barrier_Q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
};
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
union {
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
};
struct {
cutlass::arch::ClusterTransactionBarrier barrier_Q;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
};
};
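The union of smem_v and smem_o lets the output tile reuse V's shared memory once the last V stage is consumed. A rough budget for the 128x128x128, two-stage, 16-bit configuration (illustrative arithmetic; swizzled layouts and barriers may add a little padding):

```cpp
#include <cstdio>

int main() {
  const int kBlockM = 128, kBlockN = 128, kHeadDim = 128, kStages = 2;
  const int elem_bytes = 2;  // fp16 / bf16

  const int smem_q = kBlockM * kHeadDim * elem_bytes;            // 32 KiB
  const int smem_k = kStages * kBlockN * kHeadDim * elem_bytes;  // 64 KiB
  const int smem_v = kStages * kBlockN * kHeadDim * elem_bytes;  // 64 KiB
  const int smem_o = kBlockM * kHeadDim * elem_bytes;            // 32 KiB

  // V and O share storage via the union, so only the larger of the two counts.
  const int total = smem_q + smem_k + (smem_v > smem_o ? smem_v : smem_o);
  std::printf("~%d KiB dynamic smem (> 48 KiB, hence the "
              "cudaFuncAttributeMaxDynamicSharedMemorySize call in run_flash_mask)\n",
              total / 1024);
  return 0;
}
```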
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool NeedMask_, typename elem_type=cutlass::half_t>
template <int kHeadDim_,
int kBlockM_,
int kBlockN_,
int kNWarps_,
int kStages_,
bool NeedMask_,
typename elem_type = cutlass::half_t,
typename out_type = cutlass::half_t>
struct Flash_mask_kernel_traits {
using Element = elem_type;
using ElementAccum = float;
using index_t = int32_t;
using Element = elem_type;
using output_type = out_type;
using ElementAccum = float;
using index_t = int32_t;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
static constexpr int kStages = kStages_;
static constexpr int NeedMask = NeedMask_;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
static constexpr int kStages = kStages_;
static constexpr int NeedMask = NeedMask_;
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma0 = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma1 = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
GMMA::Major::K, GMMA::Major::MN>(),
AtomLayoutMNK{}));
using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
using TiledMma0 = decltype(cute::make_tiled_mma(
cute::GMMA::
ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
AtomLayoutMNK{}));
using TiledMma1 = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<Element,
Element,
ElementAccum,
decltype(select<0, 2, 1>(TileShape_MNK{})),
GMMA::Major::K,
GMMA::Major::MN>(),
AtomLayoutMNK{}));
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomQ =
decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K,
Element,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ =
decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomK =
decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K,
Element,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}),
shape<2>(TileShape_MNK{}),
Int<kStages>{})));
using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutAtomV =
decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K,
Element,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtomV{},
make_shape(shape<1>(TileShape_MNK{}),
shape<2>(TileShape_MNK{}),
Int<kStages>{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemLayoutAtomO =
decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K,
output_type,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutO =
decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, output_type>;
using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;
using SharedStorage = SharedStorageQKVO<kStages,
Element,
Element,
output_type,
SmemLayoutQ,
SmemLayoutK,
SmemLayoutV,
SmemLayoutO>;
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyOValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}),
LayoutRight{}));
using GmemTiledCopyO = decltype(make_tiled_copy(
TiledCopyOAtom{},
TiledCopyOThrLayout{},
TiledCopyOValLayout{}
));
static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<output_type>);
static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
using TiledCopyOAtom =
cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, output_type>;
using TiledCopyOThrLayout = decltype(cute::make_layout(
cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
LayoutRight{}));
using TiledCopyOValLayout = decltype(cute::make_layout(
cute::make_shape(_1{}, Int<kNumVecElem>{}), LayoutRight{}));
using GmemTiledCopyO = decltype(make_tiled_copy(
TiledCopyOAtom{}, TiledCopyOThrLayout{}, TiledCopyOValLayout{}));
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
using PipelineState = typename cutlass::PipelineState<kStages>;
};
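The epilogue copy layout above is built from 128-bit vectors of the output type: kNumVecElem elements per thread, kNumThreadsPerRow threads per head_dim row. Worked numbers for the 16-bit, 128-head-dim case, taking NumMmaThreads = 256 from the warp split computed earlier:

```cpp
#include <cstdio>

int main() {
  const int kHeadDim = 128;
  const int NumMmaThreads = 256;
  const int out_bits = 16;                                 // fp16 / bf16 output

  const int kNumVecElem = 128 / out_bits;                  // 8 elems per 128-bit copy
  const int kNumThreadsPerRow = kHeadDim / kNumVecElem;    // 16 threads per row
  const int kNumRows = NumMmaThreads / kNumThreadsPerRow;  // 16 rows per pass

  std::printf("vec=%d threads/row=%d rows/pass=%d\n",
              kNumVecElem, kNumThreadsPerRow, kNumRows);
  return 0;
}
```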


@@ -1,13 +1,23 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/tensor.hpp"
@@ -16,416 +26,559 @@
#include "utils.hpp"
using namespace cute;
enum class AttnNamedBarriers {
QueryEmpty = 0,
ValueEmpty = 1,
TileCountSmemEmpty = 2,
TileCountSmemFull = 3,
WarpSchedulerWG1 = 4,
WarpSchedulerWG2 = 5,
WarpSchedulerWG3 = 6,
};
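The WarpSchedulerWG* entries are consumed as id = WarpSchedulerWG1 - 1 + canonical_warp_group_idx() in warp_scheduler_barrier_sync below, so the two consumer warp groups serialize their MMA issue on named barriers 4 and 5. A tiny sketch of that mapping (warp-group indices assumed 1-based, as CUTLASS reports them):

```cpp
#include <cstdio>

enum class AttnNamedBarriers { WarpSchedulerWG1 = 4, WarpSchedulerWG2 = 5,
                               WarpSchedulerWG3 = 6 };

int main() {
  for (int wg = 1; wg <= 2; ++wg) {  // the two consumer warp groups
    const int barrier_id =
        static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + wg;
    std::printf("warp group %d -> named barrier %d\n", wg, barrier_id);
  }
  return 0;
}
```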
template <typename Ktraits>
struct CollectiveMainloopAttn {
using Element = typename Ktraits::Element;
using output_type = typename Ktraits::output_type;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
using Element = typename Ktraits::Element;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr bool NeedMask = Ktraits::NeedMask;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr bool NeedMask = Ktraits::NeedMask;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
using StrideT = cute::Shape<int32_t, _1, int32_t>;
using LayoutT = cute::Layout<ShapeT, StrideT>;
using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
using GmemTiledCopyKV =
decltype(cutlass::gemm::collective::detail::
sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;
using SmemLayoutAtomQ =
decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K,
Element,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ =
decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;
using SmemLayoutAtomK =
decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GMMA::Major::K,
Element,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}),
shape<2>(TileShape_MNK{}),
Int<kStages>{})));
using SmemLayoutV = SmemLayoutK;
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVt = decltype(cute::composition(
SmemLayoutV{},
make_layout(
make_shape(
get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
make_stride(get<1>(TileShape_MNK{}),
_1{},
Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
using SmemLayoutO = typename Ktraits::SmemLayoutO;
using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;
using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));
using TMA_Q = decltype(make_tma_copy(
GmemTiledCopyQ{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}),
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for Q
using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutK =
decltype(tile_to_shape(SmemLayoutAtomK{},
make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));
using SmemLayoutV = SmemLayoutK;
// Note this is the transpose in terms of the view, not in terms of memory.
using SmemLayoutVt =
decltype(cute::composition(SmemLayoutV{},
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
using SmemLayoutO = typename Ktraits::SmemLayoutO;
using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;
using TMA_KV = decltype(make_tma_copy(
GmemTiledCopyKV{},
make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}),
take<0, 2>(SmemLayoutK{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Q = decltype(make_tma_copy(
GmemTiledCopyQ{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{})); // no mcast for Q
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
using TMA_KV = decltype(make_tma_copy(
// Set the bytes transferred in this TMA transaction (may involve multiple
// issues)
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(
size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(
size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
// Host side kernel arguments
struct Arguments {
Element const* ptr_Q;
LayoutT layout_Q;
Element const* ptr_K;
LayoutT layout_K;
Element const* ptr_V;
LayoutT layout_V;
float const softmax_scale_log2;
};
// Device side kernel params
struct Params {
LayoutT layout_Q;
LayoutT layout_K;
LayoutT layout_V;
cutlass::FastDivmod qhead_per_khead_divmod;
TMA_Q tma_load_Q;
TMA_KV tma_load_K, tma_load_V;
float const softmax_scale_log2;
};
static Params to_underlying_arguments(Arguments const& args) {
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
TMA_Q tma_load_Q = make_tma_copy(GmemTiledCopyQ{},
mQ,
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{});
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
TMA_KV tma_load_K = make_tma_copy(
GmemTiledCopyKV{},
make_tensor(
make_gmem_ptr(static_cast<Element const*>(nullptr)),
repeat_like(StrideT{}, int32_t(0)),
StrideT{}
),
take<0, 2>(SmemLayoutK{}),
mK,
SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
TMA_KV tma_load_V = make_tma_copy(
GmemTiledCopyKV{},
mV,
SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {args.layout_Q,
args.layout_K,
args.layout_V,
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()),
get<2>(args.layout_K.shape()))),
tma_load_Q,
tma_load_K,
tma_load_V,
args.softmax_scale_log2};
}
static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
using MainloopPipeline = typename Ktraits::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
/// performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_V.get_tma_descriptor());
}
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor,
const Shape& tile_shape,
const int* cu_seq_len,
const int bidh,
const int bidb,
const int actual_seq_len) const {
auto g_offset = local_tile(m_tensor(_, _, bidh),
cute::make_shape(1, get<1>(tile_shape)),
make_coord(cu_seq_len[bidb], _0{}));
auto g_sequence = make_tensor(
g_offset.data(),
make_layout(cute::make_shape(actual_seq_len, get<1>(tile_shape)),
g_offset.stride()));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
}
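get_local_tile_tensor slices one batch entry out of the packed, variable-length token dimension by offsetting with the cumulative-sequence array before tiling. A small sketch of that indexing with illustrative cu_seq_q offsets:

```cpp
#include <cstdio>

int main() {
  // Packed layout: all batches' tokens are concatenated along dim 0, and
  // cu_seq[b] is the starting row of batch b (cu_seq[b+1] - cu_seq[b] = length).
  const int cu_seq_q[] = {0, 300, 812, 1324};  // illustrative offsets
  const int bidb = 1;

  const int row_begin = cu_seq_q[bidb];                       // 300
  const int seq_len_q = cu_seq_q[bidb + 1] - cu_seq_q[bidb];  // 512

  const int kBlockM = 128, m_block = 3;
  // Global row range covered by this block's query tile:
  std::printf("rows [%d, %d) of the packed tensor\n",
              row_begin + m_block * kBlockM,
              row_begin + (m_block + 1) * kBlockM);
  return 0;
}
```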
static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;
template <typename SharedStorage>
CUTLASS_DEVICE void load(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k,
PipelineState& smem_pipe_write_v,
SharedStorage& shared_storage,
const int n_block_max,
const int m_block,
const int bidh,
const int bidb,
const int* cu_seq_q,
const int* cu_seq_k,
const int seq_len_q,
const int seq_len_k) {
Tensor sQ =
make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK =
make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV =
make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
// Host side kernel arguments
struct Arguments {
Element const* ptr_Q;
LayoutT layout_Q;
Element const* ptr_K;
LayoutT layout_K;
Element const* ptr_V;
LayoutT layout_V;
float const softmax_scale_log2;
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(
mainloop_params.layout_Q.shape());
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(
mainloop_params.layout_K.shape());
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(
mainloop_params.layout_V.shape());
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
Tensor gQ = get_local_tile_tensor(
mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(
_, _, m_block);
Tensor gK = get_local_tile_tensor(
mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor gV = get_local_tile_tensor(
mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor sQ_x =
make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
Tensor gQ_x =
make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q,
_0{},
Layout<_1>{},
group_modes<0, 2>(sQ_x),
group_modes<0, 2>(gQ_x));
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K,
_0{},
Layout<_1>{},
group_modes<0, 2>(sK),
group_modes<0, 2>(gK));
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V,
_0{},
Layout<_1>{},
group_modes<0, 2>(sV),
group_modes<0, 2>(gV));
uint16_t mcast_mask_kv = 0;
int n_block = n_block_max - 1;
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(mainloop_params.tma_load_Q.with(
reinterpret_cast<
cutlass::arch::ClusterTransactionBarrier::ValueType&>(
shared_storage.barrier_Q),
0 /*mcast_mask*/),
tQgQ,
tQsQ);
}
if (lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(
*pipeline_k.producer_get_barrier(smem_pipe_write_k),
mcast_mask_kv),
tKgK(_, n_block),
tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
}
if (lane_predicate) {
#pragma unroll 2
for (; n_block > 0; --n_block) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(
*pipeline_k.producer_get_barrier(smem_pipe_write_k),
mcast_mask_kv),
tKgK(_, n_block - 1),
tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(
*pipeline_v.producer_get_barrier(smem_pipe_write_v),
mcast_mask_kv),
tVgV(_, n_block),
tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
if (lane_predicate) {
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(
*pipeline_v.producer_get_barrier(smem_pipe_write_v),
mcast_mask_kv),
tVgV(_, n_block),
tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
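The producer path above issues K one tile ahead of V: K(n_max-1) first, then (K(n-1), V(n)) pairs while walking down, and V(0) last, so the QK^T GEMM for a tile can start while that tile's V copy is still in flight. A tiny trace of the resulting issue order for n_block_max = 3 (illustrative):

```cpp
#include <cstdio>

int main() {
  const int n_block_max = 3;
  int n = n_block_max - 1;
  std::printf("K%d ", n);                 // prologue: first K tile
  for (; n > 0; --n)                      // steady state: next K, current V
    std::printf("K%d V%d ", n - 1, n);
  std::printf("V0\n");                    // epilogue: last V tile
  // Prints: K2 K1 V2 K0 V1 V0
  return 0;
}
```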
CUTLASS_DEVICE void warp_scheduler_barrier_sync() {
if constexpr (UseSchedulerBarrier) {
cutlass::arch::NamedBarrier::sync(
NumMmaThreads,
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
cutlass::canonical_warp_group_idx() /*id*/);
}
}
CUTLASS_DEVICE void mma_init() {
if constexpr (!UseSchedulerBarrier) {
return;
}
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup ||
NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if (cutlass::canonical_warp_group_idx() > 1) {
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads,
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
}
if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
if (cutlass::canonical_warp_group_idx() > 2) {
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads,
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
2 /*id*/);
}
}
}
CUTLASS_DEVICE void warp_scheduler_barrier_arrive() {
if constexpr (!UseSchedulerBarrier) {
return;
}
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup ||
NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads,
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
(3 - cutlass::canonical_warp_group_idx()) /*id*/);
} else {
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads,
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
(cutlass::canonical_warp_group_idx() <= 2
? cutlass::canonical_warp_group_idx() + 1
: cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads,
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
(cutlass::canonical_warp_group_idx() <= 1
? cutlass::canonical_warp_group_idx() + 2
: cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
}
}
template <typename SharedStorage, typename FrgTensorO, typename Softmax>
CUTLASS_DEVICE void mma(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_read_k,
PipelineState& smem_pipe_read_v,
FrgTensorO& tOrO,
Softmax& softmax,
const int* mask,
const int n_block_max,
const int thread_idx,
const int m_block,
const int seq_len_q,
const int seq_len_k,
SharedStorage& shared_storage) {
Tensor sQ =
make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK =
make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()),
SmemLayoutVt{});
typename Ktraits::TiledMma0 tiled_mma0;
typename Ktraits::TiledMma1 tiled_mma1;
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
Tensor tSrK = threadMma0.partition_fragment_B(sK);
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
// Device side kernel params
struct Params {
LayoutT layout_Q;
LayoutT layout_K;
LayoutT layout_V;
cutlass::FastDivmod qhead_per_khead_divmod;
TMA_Q tma_load_Q;
TMA_KV tma_load_K, tma_load_V;
float const softmax_scale_log2;
};
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
int n_block = n_block_max - 1;
static Params
to_underlying_arguments(Arguments const& args) {
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
TMA_Q tma_load_Q = make_tma_copy(
GmemTiledCopyQ{},
mQ,
SmemLayoutQ{},
select<0, 2>(TileShape_MNK{}),
_1{});
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
TMA_KV tma_load_K = make_tma_copy(
GmemTiledCopyKV{},
mK,
SmemLayoutK{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
TMA_KV tma_load_V = make_tma_copy(
GmemTiledCopyKV{},
mV,
SmemLayoutV{}(_, _, _0{}),
select<1, 2>(TileShape_MNK{}),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
return {args.layout_Q, args.layout_K, args.layout_V,
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
tma_load_Q, tma_load_K, tma_load_V,
args.softmax_scale_log2};
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(
shared_storage.barrier_Q.try_wait(0));
if (barrier_token == cutlass::BarrierStatus::WaitAgain) {
shared_storage.barrier_Q.wait(0);
}
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
Tensor tSrS =
partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
warp_scheduler_barrier_sync();
gemm</*zero_init=*/true, /*wg_wait=*/-1>(
tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
warp_scheduler_barrier_arrive();
warpgroup_wait<0>();
pipeline_k.consumer_release(smem_pipe_read_k);
++smem_pipe_read_k;
int mask_start_idx;
int mask_row_id;
int col_base;
if constexpr (NeedMask) {
const int lane_id = thread_idx % 32;
mask_start_idx = mask[0] / kBlockN - 1;
mask_row_id = thread_idx / 32 * 16 + lane_id / 4;
col_base = thread_idx % 4 * 2;
app_mask(tSrS, mask, mask_row_id, col_base + n_block * kBlockN);
} else {
auto col_limit_causal = [&](int row, int n_block) {
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q +
m_block * kBlockM;
};
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if (int(get<1>(tScS(i))) >=
std::min(seq_len_k - n_block * kBlockN,
col_limit_causal(int(get<0>(tScS(i))), n_block))) {
tSrS(i) = -INFINITY;
}
}
}
template <typename MTensor, typename Shape>
CUTLASS_DEVICE auto get_local_tile_tensor(
const MTensor &m_tensor,
const Shape &tile_shape,
const int *cu_seq_len,
const int bidh,
const int bidb,
const int actual_seq_len) const {
auto g_offset = local_tile(
m_tensor(_, _, bidh),
cute::make_shape(1, get<1>(tile_shape)),
make_coord(cu_seq_len[bidb], _0{}));
auto g_sequence = make_tensor(
g_offset.data(),
make_layout(
cute::make_shape(actual_seq_len, get<1>(tile_shape)),
g_offset.stride()
));
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
return g_tensor;
softmax.template online_softmax</*Is_first=*/true>(
tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(
convert_type<Element>(tSrS).data(),
convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
Tensor scores_scale = make_fragment_like(softmax.row_max);
clear(scores_scale);
#pragma unroll 2
for (; n_block > 0; --n_block) {
Tensor tSrS =
partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
warp_scheduler_barrier_sync();
if constexpr (NeedMask) {
if (n_block >= mask_start_idx) {
app_mask(tSrS, mask, mask_row_id, col_base + n_block * kBlockN);
}
}
gemm</*zero_init=*/true, /*wg_wait=*/-1>(
tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(
tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
warp_scheduler_barrier_arrive();
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
cute::copy(softmax.template max</*Is_first=*/false>(
tSrS, mainloop_params.softmax_scale_log2),
scores_scale);
softmax.template online_softmax</*Is_first=*/false>(
tSrS, mainloop_params.softmax_scale_log2);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v); // release V
++smem_pipe_read_k;
++smem_pipe_read_v;
cute::copy(
make_tensor(convert_type<Element>(tSrS).data(),
convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(
tSrS.layout())),
tOrP);
}
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
template <typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_write_k,
PipelineState& smem_pipe_write_v,
SharedStorage &shared_storage,
const int n_block_max,
const int m_block,
const int bidh,
const int bidb,
const int *cu_seq_q,
const int *cu_seq_k,
const int seq_len_q,
const int seq_len_k) {
gemm</*zero_init=*/false, /*wg_wait=*/-1>(
tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2),
scores_scale);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v);
++smem_pipe_read_v;
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
softmax.rescale_o(tOrO, scores_scale);
}
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
template <int NumMmaThreads,
typename SharedStorage,
typename FrgTensorO,
typename TiledMma,
typename T>
CUTLASS_DEVICE void store(Params const& mainloop_params,
FrgTensorO const& tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
const int o_head_stride,
const int real_seq,
T* out_ptr) {
Tensor sO =
make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor gQ = get_local_tile_tensor(
mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block);
Tensor gK = get_local_tile_tensor(
mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor gV = get_local_tile_tensor(
mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
Tensor tOrO_out = convert_type<output_type>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
    auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));
    auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{}, group_modes<0, 2>(sK), group_modes<0, 2>(gK));
    auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{}, group_modes<0, 2>(sV), group_modes<0, 2>(gV));
cutlass::arch::NamedBarrier::sync(
NumMmaThreads, static_cast<int>(AttnNamedBarriers::ValueEmpty) /*id*/);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
  cutlass::arch::fence_view_async_shared();  // ensure smem writes are visible to TMA
cutlass::arch::NamedBarrier::arrive(
NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
uint16_t mcast_mask_kv = 0;
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(o_head_stride, _1{}));
int n_block = n_block_max - 1;
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
int lane_predicate = cute::elect_one_sync();
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
if (lane_predicate) {
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
}
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
if (lane_predicate) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
}
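    // Note (added): producer loop below — a single elected lane issues TMA copies
    // for K block (n_block - 1) and V block n_block each iteration, so the V
    // stream trails K by one block, matching the consumer's pipelined GEMMs.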
if (lane_predicate) {
#pragma unroll 2
for (; n_block > 0; --n_block) {
pipeline_k.producer_acquire(smem_pipe_write_k);
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
++smem_pipe_write_k;
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
}
if (lane_predicate) {
pipeline_v.producer_acquire(smem_pipe_write_v);
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
++smem_pipe_write_v;
}
if (real_seq >= kBlockM) {
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
} else {
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
}
template <typename SharedStorage, typename FrgTensorO, typename Softmax>
CUTLASS_DEVICE void
mma(Params const& mainloop_params,
MainloopPipeline pipeline_k,
MainloopPipeline pipeline_v,
PipelineState& smem_pipe_read_k,
PipelineState& smem_pipe_read_v,
FrgTensorO& tOrO,
Softmax& softmax,
const int *mask,
const int n_block_max,
const int thread_idx,
const int m_block,
const int seq_len_q,
const int seq_len_k,
SharedStorage& shared_storage) {
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
typename Ktraits::TiledMma0 tiled_mma0;
typename Ktraits::TiledMma1 tiled_mma1;
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
Tensor tSrK = threadMma0.partition_fragment_B(sK);
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
};
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
int n_block = n_block_max - 1;
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(0));
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); }
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
warpgroup_wait<0>();
pipeline_k.consumer_release(smem_pipe_read_k);
++smem_pipe_read_k;
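    // Note (added): when a custom mask is used, derive this thread's position in
    // the score tile from the SM90 MMA accumulator layout: each thread of a warp
    // owns rows mask_row_id and mask_row_id + 8 and column pairs starting at
    // col_base (see app_mask).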
int mask_start_idx;
int mask_row_id;
int col_base;
if constexpr (NeedMask) {
const int lane_id = thread_idx % 32;
mask_start_idx = mask[0] / kBlockN - 1;
mask_row_id = thread_idx / 32 * 16 + lane_id / 4;
col_base = thread_idx % 4 * 2;
app_mask(
tSrS,
mask,
mask_row_id,
col_base + n_block * kBlockN);
} else {
auto col_limit_causal = [&](int row, int n_block) {
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM;
};
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
Tensor tScS = threadMma0.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if (int(get<1>(tScS(i))) >=
std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) {
tSrS(i) = -INFINITY;
}
}
}
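    // Note (added): first softmax pass over the initial K block — establishes the
    // running row_max/row_sum and converts scores to probabilities (tOrP), which
    // become the A operand of the PV GEMM.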
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
Tensor scores_scale = make_fragment_like(softmax.row_max);
clear(scores_scale);
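    // Note (added): main K loop, software pipelined — the QK^T GEMM for the
    // current block is issued asynchronously while the PV GEMM still consumes the
    // previous block's probabilities; warpgroup_wait<1>/<0> order the two
    // in-flight GEMMs before K and V buffers are released.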
#pragma unroll 1
for (; n_block > 0; --n_block) {
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
consumer_wait(pipeline_k, smem_pipe_read_k);
if constexpr (NeedMask) {
if (n_block >= mask_start_idx) {
app_mask(
tSrS,
mask,
mask_row_id,
col_base + n_block * kBlockN);
}
}
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
warpgroup_wait<1>();
pipeline_k.consumer_release(smem_pipe_read_k); // release K
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v); // release V
++smem_pipe_read_k;
++smem_pipe_read_v;
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
}
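    // Note (added): tail — apply the last block's probabilities to V, then rescale
    // the output with the final normalization factors from softmax.finalize().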
softmax.rescale_o(tOrO, scores_scale);
consumer_wait(pipeline_v, smem_pipe_read_v);
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale);
warpgroup_wait<0>();
pipeline_v.consumer_release(smem_pipe_read_v);
++smem_pipe_read_v;
softmax.rescale_o(tOrO, scores_scale);
return;
}
template <int NumMmaThreads, typename SharedStorage, typename FrgTensorO, typename TiledMma, typename T>
CUTLASS_DEVICE void
store(Params const& mainloop_params,
FrgTensorO const& tOrO,
SharedStorage& shared_storage,
TiledMma tiled_mma,
int thread_idx,
const int o_head_stride,
const int real_seq,
T * out_ptr) {
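    // Note (added): epilogue store — stage the accumulator through shared memory
    // (sO) so the global-memory write can use wide, coalesced copies; the final
    // partial block of a sequence takes the predicated path when real_seq < kBlockM.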
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOrO_out = convert_type<Element>(tOrO);
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(o_head_stride, _1{}));
GmemTiledCopyO gmem_tiled_copy_O;
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
if (real_seq >= kBlockM) {
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
} else {
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
}
}
}
};

View File

@@ -1,6 +1,16 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
@@ -12,195 +22,245 @@
#include "utils.hpp"
using namespace cute;
template<int THREADS>
template <int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
template<>
template <>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
}
};
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
template <bool zero_init = true,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Operator>
__device__ __forceinline__ void thread_reduce_(
Tensor<Engine0, Layout0> const &tensor,
Tensor<Engine1, Layout1> &summary,
Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
template <typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst,
Tensor<Engine1, Layout1> &src,
Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++) {
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
template <bool zero_init = true,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const &tensor,
Tensor<Engine1, Layout1> &summary,
Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
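// Note (added): reduce_max / reduce_sum below specialize reduce_ with MaxOp /
// SumOp; the sum variant can skip the quad shuffle (warp_reduce = false) when
// the caller performs the cross-thread reduction later.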
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
template <bool zero_init = true,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1>
__device__ __forceinline__ void reduce_max(
Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &max) {
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
template <bool zero_init = true,
bool warp_reduce = true,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1>
__device__ __forceinline__ void reduce_sum(
Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &sum) {
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
if constexpr (warp_reduce) {
quad_allreduce_(sum, sum, sum_op);
}
}
__forceinline__ __device__ __half2 half_exp(__half2 x) {
uint32_t tmp_out, tmp_in;
tmp_in = reinterpret_cast<uint32_t&>(x);
asm ("ex2.approx.f16x2 %0, %1;\n"
: "=r"(tmp_out)
: "r"(tmp_in));
__half2 out = reinterpret_cast<__half2&>(tmp_out);
return out;
uint32_t tmp_out, tmp_in;
tmp_in = reinterpret_cast<uint32_t &>(x);
asm("ex2.approx.f16x2 %0, %1;\n" : "=r"(tmp_out) : "r"(tmp_in));
__half2 out = reinterpret_cast<__half2 &>(tmp_out);
return out;
}
// Apply the exp to all the elements.
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
template <bool zero_init = false,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(
Tensor<Engine0, Layout0> &tensor,
Tensor<Engine1, Layout1> &max,
Tensor<Engine1, Layout1> &sum,
const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
}
}
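// Note (added): scale_apply_exp2 exponentiates scores with exp2f, folding the
// softmax scale (which already includes log2(e)) into the exponent and
// subtracting the pre-scaled row max for numerical stability.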
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const float max_scaled = max(mi) * scale;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
template <typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(
Tensor<Engine0, Layout0> &tensor,
Tensor<Engine1, Layout1> const &max,
const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const float max_scaled = max(mi) * scale;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
CUTLASS_DEVICE Softmax() {};
CUTLASS_DEVICE Softmax() {};
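  // Note (added): row_max / row_sum carry the running statistics of the online
  // softmax for the kNRows accumulator rows owned by this thread.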
template <bool Is_first, bool Check_inf = false, typename Tensor0>
__forceinline__ __device__ TensorT max(Tensor0 &acc_s,
float softmax_scale_log2) {
Tensor scores =
make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
cute::fill(scores_scale, 1.f);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
reduce_max</*zero_init=*/false>(scores, row_max);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = row_max(mi);
scores_scale(mi) =
exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale(mi);
}
}
return scores_scale;
};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
cute::fill(scores_scale, 1.f);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
reduce_max</*zero_init=*/false>(scores, row_max);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = row_max(mi);
scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale(mi);
}
}
return scores_scale;
};
template <bool Is_first, typename Tensor0>
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s,
float softmax_scale_log2) {
Tensor scores =
make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
cute::fill(scores_scale, 1.f);
} else {
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
}
return scores_scale;
};
template<bool Is_first, typename Tensor0>
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
cute::fill(scores_scale, 1.f);
} else {
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
}
return scores_scale;
};
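  // Note (added): finalize reduces row_sum across the quad, stores each row's
  // log-sum-exp (row_max * scale * ln2 + log(sum)) back into row_sum, and
  // returns 1/sum as the factor used to normalize the output accumulator.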
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT scores_scale;
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi);
float inv_sum = 1.0f / sum;
row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
scores_scale(mi) = inv_sum;
}
return scores_scale;
};
template<typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale(mi);
}
}
};
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT scores_scale;
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float sum = row_sum(mi);
float inv_sum = 1.0f / sum;
row_sum(mi) =
row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
scores_scale(mi) = inv_sum;
}
return scores_scale;
};
template <typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o,
TensorT const &scores_scale) {
Tensor acc_o_rowcol =
make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
acc_o_rowcol(mi, ni) *= scores_scale(mi);
}
}
};
};

View File

@@ -1,13 +1,23 @@
/******************************************************************************
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <fstream>
#include <iostream>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <assert.h>
#include <stdint.h>
@@ -19,8 +29,8 @@
#include <cuda_bf16.h>
#endif
#include <cute/tensor.hpp>
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
@@ -29,425 +39,468 @@
using namespace cute;
template<typename T>
template <typename T>
struct PackedHalf;
template<>
template <>
struct PackedHalf<cutlass::half_t> {
using Type = __half2;
using Type = __half2;
};
template<>
template <>
struct PackedHalf<cutlass::bfloat16_t> {
using Type = nv_bfloat162;
using Type = nv_bfloat162;
};
template<typename T>
template <typename T>
__forceinline__ __device__ auto float_2_half2(const float x) {
if constexpr (std::is_same<T, cutlass::half_t>::value) {
return __float2half2_rn(x);
} else {
return __float2bfloat162_rn(x);
}
if constexpr (std::is_same<T, cutlass::half_t>::value) {
return __float2half2_rn(x);
} else {
return __float2bfloat162_rn(x);
}
}
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
struct uint8 {
uint4 u;
uint4 v;
uint4 u;
uint4 v;
};
template<int BYTES>
template <int BYTES>
struct BytesToType {};
template<>
template <>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
template <>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
template <>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
template <>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
template <>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
template <>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
template <>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
template<typename Elt_type, uint32_t NUM_ELT>
template <typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using Vec_type = typename BytesToType<BYTES>::Type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
Alias_type data;
inline __device__ Vec() {}
inline __device__ Vec() {}
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
template <typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
template <typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for (int it = 0; it < NUM_ELT; it++) {
this->data.elt[it] = op(it);
}
}
inline __device__ void load_from(const void *base_ptr) {
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
inline __device__ void load_from(const void *base_ptr) {
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
}
inline __device__ void store_to(void *base_ptr) {
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
}
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
}
}
inline __device__ void store_to(void *base_ptr) {
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale,
const Vec<Elt_type, NUM_ELT> &bias) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
}
}
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
}
}
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
static_assert(NUM_ELT % 2 == 0);
using type = typename PackedHalf<Elt_type>::Type;
#pragma unroll
for (int it = 0; it < NUM_ELT / 2; it++) {
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
}
}
inline __device__ void set_zero() {
constexpr int size = sizeof(Vec_type) / sizeof(int);
#pragma unroll
for (int i = 0; i < size; ++i) {
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
}
inline __device__ void set_zero() {
constexpr int size = sizeof(Vec_type) / sizeof(int);
#pragma unroll
for (int i = 0; i < size; ++i) {
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
}
}
};
template<typename T, int PackSize>
inline __device__ void apply_rotary_embedding(Vec<T, PackSize>& vec, Vec<float, PackSize / 2>& cos, Vec<float, PackSize / 2>& sin) {
static_assert(PackSize % 2 == 0);
#pragma unroll
for (int i = 0; i < PackSize / 2; i++) {
const float cos_inv_freq = cos.data.elt[i];
const float sin_inv_freq = sin.data.elt[i];
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
}
template <typename T, int PackSize>
inline __device__ void apply_rotary_embedding(Vec<T, PackSize> &vec,
Vec<float, PackSize / 2> &cos,
Vec<float, PackSize / 2> &sin) {
static_assert(PackSize % 2 == 0);
#pragma unroll
for (int i = 0; i < PackSize / 2; i++) {
const float cos_inv_freq = cos.data.elt[i];
const float sin_inv_freq = sin.data.elt[i];
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
vec.data.elt[2 * i + 1] =
static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
}
}
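// Note (added): app_mask applies a per-row column mask to the score accumulator.
// Each thread covers rows mask_row_id and mask_row_id + 8 and column pairs
// (col, col + 1) and (col + 8, col + 9); entries at or beyond a row's mask
// boundary are set to a large negative value so they vanish after softmax.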
template <typename Tensor>
__forceinline__ __device__ void app_mask(
Tensor &tSrS,
const int *mask,
const int &mask_row_id,
const int &col_base) {
const float mask_value = -1000000.0f;
for (int i = 0; i < size(tSrS); i+=8) {
const int col = i * 2 + col_base;
if (col >= mask[mask_row_id]) {
tSrS(i) = mask_value;
}
if (col + 1 >= mask[mask_row_id]) {
tSrS(i + 1) = mask_value;
}
if (col >= mask[mask_row_id + 8]) {
tSrS(i + 2) = mask_value;
}
if (col + 1 >= mask[mask_row_id + 8]) {
tSrS(i + 3) = mask_value;
}
if (col + 8 >= mask[mask_row_id]) {
tSrS(i + 4) = mask_value;
}
if (col + 9 >= mask[mask_row_id]) {
tSrS(i + 5) = mask_value;
}
if (col + 8 >= mask[mask_row_id + 8]) {
tSrS(i + 6) = mask_value;
}
if (col + 9 >= mask[mask_row_id + 8]) {
tSrS(i + 7) = mask_value;
}
__forceinline__ __device__ void app_mask(Tensor &tSrS,
const int *mask,
const int &mask_row_id,
const int &col_base) {
const float mask_value = -1000000.0f;
for (int i = 0; i < size(tSrS); i += 8) {
const int col = i * 2 + col_base;
if (col >= mask[mask_row_id]) {
tSrS(i) = mask_value;
}
if (col + 1 >= mask[mask_row_id]) {
tSrS(i + 1) = mask_value;
}
if (col >= mask[mask_row_id + 8]) {
tSrS(i + 2) = mask_value;
}
if (col + 1 >= mask[mask_row_id + 8]) {
tSrS(i + 3) = mask_value;
}
if (col + 8 >= mask[mask_row_id]) {
tSrS(i + 4) = mask_value;
}
if (col + 9 >= mask[mask_row_id]) {
tSrS(i + 5) = mask_value;
}
if (col + 8 >= mask[mask_row_id + 8]) {
tSrS(i + 6) = mask_value;
}
if (col + 9 >= mask[mask_row_id + 8]) {
tSrS(i + 7) = mask_value;
}
}
}
template<typename T>
template <typename T>
struct HalfMax;
template<>
template <>
struct HalfMax<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("max.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("max.f16x2 %0, %1, %2;\n"
: "=r"(*reinterpret_cast<uint32_t *>(&res))
: "r"(*reinterpret_cast<const uint32_t *>(&x)),
"r"(*reinterpret_cast<const uint32_t *>(&y)));
return res;
}
};
template<>
template <>
struct HalfMax<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("max.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x,
const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("max.bf16x2 %0, %1, %2;\n"
: "=r"(*reinterpret_cast<uint32_t *>(&res))
: "r"(*reinterpret_cast<const uint32_t *>(&x)),
"r"(*reinterpret_cast<const uint32_t *>(&y)));
return res;
}
};
template<typename T>
template <typename T>
struct HalfMin;
template<>
template <>
struct HalfMin<cutlass::half_t> {
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("min.f16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
__half2 res;
asm volatile("min.f16x2 %0, %1, %2;\n"
: "=r"(*reinterpret_cast<uint32_t *>(&res))
: "r"(*reinterpret_cast<const uint32_t *>(&x)),
"r"(*reinterpret_cast<const uint32_t *>(&y)));
return res;
}
};
template<>
template <>
struct HalfMin<cutlass::bfloat16_t> {
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("min.bf16x2 %0, %1, %2;\n" :
"=r"(*reinterpret_cast<uint32_t*>(&res)) :
"r"(*reinterpret_cast<const uint32_t*>(&x)),
"r"(*reinterpret_cast<const uint32_t*>(&y)));
return res;
}
inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x,
const nv_bfloat162 y) {
nv_bfloat162 res;
asm volatile("min.bf16x2 %0, %1, %2;\n"
: "=r"(*reinterpret_cast<uint32_t *>(&res))
: "r"(*reinterpret_cast<const uint32_t *>(&x)),
"r"(*reinterpret_cast<const uint32_t *>(&y)));
return res;
}
};
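// Note (added): copy performs a tiled copy with an optional row bound — when
// Is_even_MN is false, rows whose identity coordinate reaches max_MN are
// skipped, which handles the final partial tile of a sequence.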
template <bool Is_even_MN=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
template <bool Is_even_MN = true,
typename TiledCopy,
typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1,
typename Engine2,
typename Layout2>
__forceinline__ __device__ void copy(
TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &identity_MN,
const int max_MN = 0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
}
}
TiledCopy tiled_copy,
Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &identity_MN,
const int max_MN = 0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
}
}
}
}
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag =
convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(
tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
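// Note (added): BlockAllReduce does a block-wide reduction via cub::BlockReduce;
// thread 0 writes the result to shared memory and the __syncthreads() broadcast
// makes it available to every thread in the block.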
template<typename T, typename ReductionOp, int block_size>
template <typename T, typename ReductionOp, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
if (threadIdx.x == 0) { result_broadcast = result; }
__syncthreads();
return result_broadcast;
typedef cub::BlockReduce<T, block_size> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T result_broadcast;
T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
if (threadIdx.x == 0) {
result_broadcast = result;
}
__syncthreads();
return result_broadcast;
}
template<typename T, int block_size>
template <typename T, int block_size>
__inline__ __device__ T BlockScanSum(T val) {
typedef cub::BlockScan<int, block_size> BlockScanT;
__shared__ typename BlockScanT::TempStorage temp_storage;
typedef cub::BlockScan<int, block_size> BlockScanT;
__shared__ typename BlockScanT::TempStorage temp_storage;
int aggregate;
BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate);
__syncthreads();
return val;
int aggregate;
BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate);
__syncthreads();
return val;
}
template<typename T>
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
__device__ __forceinline__ T operator()(T const &x, T const &y) {
return x > y ? x : y;
}
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) {
return max(x, y);
}
};
template<typename T>
template <typename T>
struct MinOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; }
__device__ __forceinline__ T operator()(T const &x, T const &y) {
return x < y ? x : y;
}
};
template <>
struct MinOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) {
return min(x, y);
}
};
template<typename T>
template <typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
__device__ __forceinline__ T operator()(T const &x, T const &y) {
return x + y;
}
};
template<typename MMA_traits, typename Layout>
template <typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
        auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{});  // (2, 2, (2, N / 16))
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
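// Note (added): gemm is a thin wrapper over cute::gemm for warpgroup MMAs —
// zero_init starts the accumulator from zero (ScaleOut::Zero on the first
// k_block), wg_wait >= 0 inserts warpgroup_wait<wg_wait>() before returning,
// and arrive/commit control warpgroup_arrive()/warpgroup_commit_batch().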
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2,
typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
auto l = logical_divide(get<0>(acc_layout),
                            Shape<X, X, _2>{});  // (2, 2, (2, N / 16))
return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)),
get<1>(acc_layout),
make_layout(get<2, 1>(l), get<2>(acc_layout)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
auto l = logical_divide(
        acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
return make_layout(
make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
template<typename T, typename ReductionOp, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
ReductionOp op;
#pragma unroll
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
template <bool zero_init = false,
int wg_wait = 0,
bool arrive = true,
bool commit = true,
typename Tensor0,
typename Tensor1,
typename Tensor2,
typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma,
Tensor0 const &tCrA,
Tensor1 const &tCrB,
Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take
// const
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
}
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
return val;
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) {
warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
}
}
template <typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
make_layout(get<0, 0>(l), get<2>(l)));
}
};
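// Note (added): WarpAllReduce is a butterfly reduction with __shfl_xor_sync;
// after log2(thread_group_width) exchange steps every lane in the group holds
// the fully reduced value.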
template <typename T, typename ReductionOp, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
ReductionOp op;
#pragma unroll
for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
}
return val;
}