Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)

Commit: format flash_mask_attn
@@ -1,7 +1,3 @@
-/******************************************************************************
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
- ******************************************************************************/
-
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "flash_mask_attn_kernel.hpp"
#include "kernel_traits.h"
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
@@ -29,139 +25,165 @@ struct cuteType;

template <>
struct cuteType<phi::dtype::float16> {
  using type = cutlass::half_t;
};

template <>
struct cuteType<phi::dtype::bfloat16> {
  using type = cutlass::bfloat16_t;
};

template <typename T>
-std::vector<paddle::Tensor> DispatchFlashAttentionMask(
-    const paddle::Tensor& q_input,
-    const paddle::Tensor& k_input,
-    const paddle::Tensor& v_input,
-    const paddle::Tensor& cu_seq_q,
-    const paddle::Tensor& cu_seq_k,
-    const paddle::Tensor& seq_len_encoder,
-    const paddle::optional<paddle::Tensor>& mask,
-    const int head_num,
-    const int kv_head_num,
-    const int head_dim,
-    const int max_seq_len,
-    const int max_enc_len_this_time,
-    const int max_dec_len_this_time) {
+void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
+                                const paddle::Tensor& k_input,
+                                const paddle::Tensor& v_input,
+                                const paddle::Tensor& cu_seq_q,
+                                const paddle::Tensor& cu_seq_k,
+                                const paddle::Tensor& seq_len_encoder,
+                                const paddle::Tensor& attn_out,
+                                const paddle::optional<paddle::Tensor>& mask,
+                                const int head_num,
+                                const int kv_head_num,
+                                const int head_dim,
+                                const int max_seq_len,
+                                const int max_enc_len_this_time,
+                                const int max_dec_len_this_time) {
  constexpr int kBlockM = 128;
  constexpr int kBlockN = 128;
-  const int batch_size = seq_len_encoder.dims()[0];
+  const int batch_size = cu_seq_q.dims()[0];

  Flash_mask_params params;
  memset(&params, 0, sizeof(Flash_mask_params));

-  paddle::Tensor out = paddle::empty(
-      {q_input.dims()[0], head_num * head_dim}, q_input.dtype(), q_input.place());

  params.q_ptr = const_cast<T*>(q_input.data<T>());
  params.k_ptr = const_cast<T*>(k_input.data<T>());
  params.v_ptr = const_cast<T*>(v_input.data<T>());
-  params.o_ptr = const_cast<T*>(out.data<T>());
  params.cu_seq_q = const_cast<int*>(cu_seq_q.data<int>());
  params.cu_seq_k = const_cast<int*>(cu_seq_k.data<int>());
  params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
  params.head_num = head_num;
  params.kv_head_num = kv_head_num;
  params.max_seq_len_q = max_enc_len_this_time;
  params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time;
  params.batch_size = batch_size;
  params.gqa_group_size = head_num / kv_head_num;
  constexpr float kLog2e = 1.4426950408889634074;
  params.scale_softmax_log2 = 1.0f / std::sqrt(head_dim) * kLog2e;

  using cute_type = typename cuteType<T>::type;

  if (mask) {
    params.mask = const_cast<int*>(mask.get().data<int>());
-    flash_attn_headdim128<kBlockM, kBlockN, true, cute_type>(params, 0);
+    if (attn_out.dtype() == paddle::DataType::FLOAT16) {
+      using out_type = phi::dtype::float16;
+      params.o_ptr = const_cast<out_type*>(attn_out.data<out_type>());
+      flash_attn_headdim128<kBlockM,
+                            kBlockN,
+                            true,
+                            cute_type,
+                            typename cuteType<out_type>::type>(
+          params, q_input.stream());
+    } else if (attn_out.dtype() == paddle::DataType::BFLOAT16) {
+      using out_type = phi::dtype::bfloat16;
+      params.o_ptr = const_cast<out_type*>(attn_out.data<out_type>());
+      flash_attn_headdim128<kBlockM,
+                            kBlockN,
+                            true,
+                            cute_type,
+                            typename cuteType<out_type>::type>(
+          params, q_input.stream());
+    }
  } else {
-    flash_attn_headdim128<kBlockM, kBlockN, false, cute_type>(params, 0);
+    if (attn_out.dtype() == paddle::DataType::FLOAT16) {
+      using out_type = phi::dtype::float16;
+      params.o_ptr = const_cast<out_type*>(attn_out.data<out_type>());
+      flash_attn_headdim128<kBlockM,
+                            kBlockN,
+                            false,
+                            cute_type,
+                            typename cuteType<out_type>::type>(
+          params, q_input.stream());
+    } else if (attn_out.dtype() == paddle::DataType::BFLOAT16) {
+      using out_type = phi::dtype::bfloat16;
+      params.o_ptr = const_cast<out_type*>(attn_out.data<out_type>());
+      flash_attn_headdim128<kBlockM,
+                            kBlockN,
+                            false,
+                            cute_type,
+                            typename cuteType<out_type>::type>(
+          params, q_input.stream());
+    }
  }
-  return {out};
  // cudaDeviceSynchronize();
  // auto err = cudaGetLastError();
  // printf("mask attn err = %d, str = %s\n", err, cudaGetErrorString(err));
}
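Note on scale_softmax_log2 above: the host folds the usual 1/sqrt(head_dim) attention scale together with log2(e) so that the kernel can evaluate the softmax with exp2 instead of exp. A minimal, standalone sketch of that identity (illustrative values only, not taken from this repository):

#include <cassert>
#include <cmath>

// Sketch only: exp(x * scale) == exp2(x * scale * log2(e)), which is why the
// host code premultiplies the softmax scale by kLog2e before passing it to the
// kernel (exp2f is cheaper than expf on the GPU).
int main() {
  const float kLog2e = 1.4426950408889634074f;
  const int head_dim = 128;  // matches the Headdim used by this op
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  const float x = 3.7f;      // an arbitrary logit
  const float a = std::exp(x * scale);
  const float b = std::exp2(x * scale * kLog2e);
  assert(std::fabs(a - b) < 1e-5f);
  return 0;
}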
-std::vector<paddle::Tensor> FlashAttentionMask(
-    const paddle::Tensor& q_input,
-    const paddle::Tensor& k_input,
-    const paddle::Tensor& v_input,
-    const paddle::Tensor& cu_seq_q,
-    const paddle::Tensor& cu_seq_k,
-    const paddle::Tensor& seq_len_encoder,
-    const paddle::optional<paddle::Tensor>& mask,
-    const int head_num,
-    const int kv_head_num,
-    const int head_dim,
-    const int max_seq_len,
-    const int max_enc_len_this_time,
-    const int max_dec_len_this_time) {
-  if (q_input.dtype() == paddle::DataType::FLOAT16) {
-    using T = phi::dtype::float16;
-    return std::move(DispatchFlashAttentionMask<T>(
-        q_input, k_input, v_input, cu_seq_q, cu_seq_k, seq_len_encoder, mask,
-        head_num, kv_head_num, head_dim, max_seq_len, max_enc_len_this_time,
-        max_dec_len_this_time));
-  } else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
-    using T = phi::dtype::bfloat16;
-    return std::move(DispatchFlashAttentionMask<T>(
-        q_input, k_input, v_input, cu_seq_q, cu_seq_k, seq_len_encoder, mask,
-        head_num, kv_head_num, head_dim, max_seq_len, max_enc_len_this_time,
-        max_dec_len_this_time));
-  }
+void FlashAttentionMask(const paddle::Tensor& q_input,
+                        const paddle::Tensor& k_input,
+                        const paddle::Tensor& v_input,
+                        const paddle::Tensor& cu_seq_q,
+                        const paddle::Tensor& cu_seq_k,
+                        const paddle::Tensor& seq_len_encoder,
+                        const paddle::Tensor& attn_out,
+                        const paddle::optional<paddle::Tensor>& mask,
+                        const int head_num,
+                        const int kv_head_num,
+                        const int head_dim,
+                        const int max_seq_len,
+                        const int max_enc_len_this_time,
+                        const int max_dec_len_this_time) {
+  if (q_input.dtype() == paddle::DataType::FLOAT16) {
+    using T = phi::dtype::float16;
+    DispatchFlashAttentionMask<T>(q_input,
+                                  k_input,
+                                  v_input,
+                                  cu_seq_q,
+                                  cu_seq_k,
+                                  seq_len_encoder,
+                                  attn_out,
+                                  mask,
+                                  head_num,
+                                  kv_head_num,
+                                  head_dim,
+                                  max_seq_len,
+                                  max_enc_len_this_time,
+                                  max_dec_len_this_time);
+  } else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
+    using T = phi::dtype::bfloat16;
+    DispatchFlashAttentionMask<T>(q_input,
+                                  k_input,
+                                  v_input,
+                                  cu_seq_q,
+                                  cu_seq_k,
+                                  seq_len_encoder,
+                                  attn_out,
+                                  mask,
+                                  head_num,
+                                  kv_head_num,
+                                  head_dim,
+                                  max_seq_len,
+                                  max_enc_len_this_time,
+                                  max_dec_len_this_time);
+  }
}
PD_BUILD_STATIC_OP(flash_attention_mask)
-    .Inputs({"q_input",
-             "k_input",
-             "v_input",
-             "cu_seq_q",
-             "cu_seq_k",
-             "seq_len_encoder",
-             paddle::Optional("mask")})
+    .Inputs({"q_input",
+             "k_input",
+             "v_input",
+             "cu_seq_q",
+             "cu_seq_k",
+             "seq_len_encoder",
+             "attn_out",
+             paddle::Optional("mask")})
    .Attrs({"head_num: int",
            "kv_head_num: int",
            "head_dim: int",
            "max_seq_len: int",
            "max_enc_len_this_time: int",
            "max_dec_len_this_time: int"})
    .Outputs({"out"})
+    .SetInplaceMap({{"attn_out", "out"}})
    .SetKernelFn(PD_KERNEL(FlashAttentionMask));
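The SetInplaceMap entry added above is what turns attn_out into an in-place output: the framework reuses the caller-provided tensor instead of allocating "out". A self-contained sketch of the same registration pattern, using a hypothetical toy op (scale_inplace, ScaleInplace are made-up names; the builder calls are the ones already used above):

#include "paddle/extension.h"

// Hypothetical toy op that scales a float tensor in place. "x" and "out" are
// mapped onto the same storage via SetInplaceMap, so the kernel returns void
// and no fresh output tensor is allocated.
void ScaleInplace(const paddle::Tensor& x, const float factor) {
  auto* data = const_cast<float*>(x.data<float>());
  for (int64_t i = 0; i < x.numel(); ++i) {
    data[i] *= factor;
  }
}

PD_BUILD_OP(scale_inplace)
    .Inputs({"x"})
    .Attrs({"factor: float"})
    .Outputs({"out"})
    .SetInplaceMap({{"x", "out"}})
    .SetKernelFn(PD_KERNEL(ScaleInplace));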
@@ -1,18 +1,20 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

#include "kernel_traits.h"
#include "mainloop_attn.hpp"
@@ -22,210 +24,247 @@ using namespace cute;

template <int kHeadDim>
auto get_gmem_layout(int token_num, int head_num) {
  return make_layout(make_shape(token_num, kHeadDim, head_num),
                     make_stride(head_num * kHeadDim, cute::_1{}, kHeadDim));
}
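For reference, the layout above addresses element (token, d, head) at linear offset token * head_num * kHeadDim + head * kHeadDim + d, i.e. packed [token, head, head_dim] row-major storage viewed as (token, d, head). A small host-side sketch of that index arithmetic in plain C++ (no CuTe), with illustrative sizes:

#include <cassert>

// Plain-index equivalent of get_gmem_layout<kHeadDim>(token_num, head_num):
// shape (token, d, head), strides (head_num * kHeadDim, 1, kHeadDim).
constexpr int offset(int token, int d, int head, int head_num, int kHeadDim) {
  return token * head_num * kHeadDim + head * kHeadDim + d;
}

int main() {
  constexpr int kHeadDim = 128, head_num = 8;
  // Consecutive d of the same (token, head) are adjacent in memory,
  // consecutive heads of a token are kHeadDim apart,
  // consecutive tokens are a full head_num * kHeadDim row apart.
  static_assert(offset(0, 1, 0, head_num, kHeadDim) -
                    offset(0, 0, 0, head_num, kHeadDim) == 1);
  static_assert(offset(0, 0, 1, head_num, kHeadDim) -
                    offset(0, 0, 0, head_num, kHeadDim) == kHeadDim);
  static_assert(offset(1, 0, 0, head_num, kHeadDim) -
                    offset(0, 0, 0, head_num, kHeadDim) == head_num * kHeadDim);
  return 0;
}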
template <typename Ktraits>
__global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
                                  1)
    compute_attn_ws(
        CUTE_GRID_CONSTANT
        typename CollectiveMainloopAttn<Ktraits>::Params const mainloop_params,
        CUTE_GRID_CONSTANT Flash_mask_params const data_params) {
  using Element = typename Ktraits::Element;
  using ElementAccum = typename Ktraits::ElementAccum;
+  using output_type = typename Ktraits::output_type;
  using SoftType = ElementAccum;
  using TileShape_MNK = typename Ktraits::TileShape_MNK;
  using ClusterShape = typename Ktraits::ClusterShape_MNK;

  static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
  static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup;
  static constexpr int kBlockM = Ktraits::kBlockM;
  static constexpr int kBlockN = Ktraits::kBlockN;
  constexpr int kHeadDim = Ktraits::kHeadDim;
  constexpr bool NeedMask = Ktraits::NeedMask;

  using CollectiveMainloop = CollectiveMainloopAttn<Ktraits>;
  using MainloopPipeline = typename Ktraits::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;

  extern __shared__ char shared_memory[];
  auto &shared_storage =
      *reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);

  __align__(16) __shared__ int mask[kBlockM];

  const int m_block = blockIdx.x;
  const int bidh = blockIdx.y;
  const int bidb = blockIdx.z;

  if constexpr (NeedMask) {
    const int *mask_this_batch =
        data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;
    for (int i = threadIdx.x; i < kBlockM;
         i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
      mask[i] = mask_this_batch[i];
    }
  }

  const int seq_len_q = data_params.seq_len_encoder[bidb];
  const int seq_len_k =
      data_params.cu_seq_k[bidb + 1] - data_params.cu_seq_k[bidb];

  if (m_block * kBlockM >= seq_len_q) {
    return;
  }

  int const lane_predicate = cute::elect_one_sync();
  int const warp_idx = cutlass::canonical_warp_idx_sync();

  if (warp_idx == 0 && lane_predicate) {
    CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
  }

  int const warp_group_thread_idx =
      threadIdx.x % cutlass::NumThreadsPerWarpGroup;

  PipelineParams pipeline_params;
  pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
  int warp_group_idx = cutlass::canonical_warp_group_idx();
  pipeline_params.role = warp_group_idx == 0
                             ? MainloopPipeline::ThreadCategory::Producer
                             : MainloopPipeline::ThreadCategory::Consumer;
  pipeline_params.is_leader = warp_group_thread_idx == 0;
  pipeline_params.num_consumers = NumMmaThreads;

  if (warp_idx == 0 && lane_predicate) {
    shared_storage.barrier_Q.init(1);
  }

  MainloopPipeline pipeline_k(
      shared_storage.pipeline_k, pipeline_params, ClusterShape{});
  MainloopPipeline pipeline_v(
      shared_storage.pipeline_v, pipeline_params, ClusterShape{});

  __syncthreads();

  CollectiveMainloop collective_mainloop;

  const int real_seq = seq_len_q - m_block * kBlockM;

-  const int n_block_max = NeedMask ? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN) : cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN);
+  const int n_block_max =
+      NeedMask
+          ? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN)
+          : min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q,
+                               kBlockN),
+                cute::ceil_div(seq_len_k, kBlockN));
  if (warp_group_idx == 0) {  // Producer
    cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 8 ? 56 : 24>();

    int warp_idx_in_warpgroup =
        __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
    if (warp_idx_in_warpgroup == 0) {  // Load Q, K, V
      PipelineState smem_pipe_write_k =
          cutlass::make_producer_start_state<MainloopPipeline>();
      PipelineState smem_pipe_write_v =
          cutlass::make_producer_start_state<MainloopPipeline>();

      collective_mainloop.load(mainloop_params,
                               pipeline_k,
                               pipeline_v,
                               smem_pipe_write_k,
                               smem_pipe_write_v,
                               shared_storage,
                               n_block_max,
                               m_block,
                               bidh,
                               bidb,
                               data_params.cu_seq_q,
                               data_params.cu_seq_k,
                               seq_len_q,
                               seq_len_k);
    }
  } else {  // Consumer
    cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 8 ? 256 : 240>();
    typename Ktraits::TiledMma1 tiled_mma1;

    collective_mainloop.mma_init();

    PipelineState smem_pipe_read_k, smem_pipe_read_v;

    Tensor tOrO =
        partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
    Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;

    collective_mainloop.mma(mainloop_params,
                            pipeline_k,
                            pipeline_v,
                            smem_pipe_read_k,
                            smem_pipe_read_v,
                            tOrO,
                            softmax,
                            mask,
                            n_block_max,
                            threadIdx.x - NumCopyThreads,
                            m_block,
                            seq_len_q,
                            seq_len_k,
                            shared_storage);

    const int o_head_stride = data_params.head_num * kHeadDim;
    const int store_offset =
        (data_params.cu_seq_q[bidb] + m_block * kBlockM) * o_head_stride +
        bidh * kHeadDim;

    collective_mainloop.store<NumMmaThreads>(
        mainloop_params,
        tOrO,
        shared_storage,
        tiled_mma1,
        threadIdx.x - NumCopyThreads,
        o_head_stride,
        real_seq,
-        reinterpret_cast<Element *>(data_params.o_ptr) + store_offset);
+        reinterpret_cast<output_type *>(data_params.o_ptr) + store_offset);
  }
}
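To make the n_block_max bound computed in this kernel concrete, here is a small host-side check of the same ceil_div arithmetic (sizes chosen purely for illustration):

#include <algorithm>
#include <cassert>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int kBlockM = 128, kBlockN = 128;
  // Unmasked case: a query tile ending at row (m_block + 1) * kBlockM may only
  // attend to keys up to that diagonal, clamped to the real key length.
  const int seq_len_q = 300, seq_len_k = 556, m_block = 1;
  const int n_block_max = std::min(
      ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN),
      ceil_div(seq_len_k, kBlockN));
  assert(n_block_max == 4);  // keys 0..511 span 4 tiles of 128
  // Masked case: the per-row mask stores the last key index each query row may
  // see, and the block bound comes from the last row of the query tile.
  const int mask_last_row = 384;
  assert(ceil_div(mask_last_row, kBlockN) == 3);
  return 0;
}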

template <typename Kernel_traits>
void run_flash_mask(Flash_mask_params &params, cudaStream_t stream) {
  using Element = typename Kernel_traits::Element;
  using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
  using ClusterShape = typename Kernel_traits::ClusterShape_MNK;

  using CollectiveMainloop = CollectiveMainloopAttn<Kernel_traits>;
  constexpr int kHeadDim = Kernel_traits::kHeadDim;

  typename CollectiveMainloop::Params mainloop_params =
      CollectiveMainloop::to_underlying_arguments(
          {static_cast<Element const *>(params.q_ptr),
-           get_gmem_layout<kHeadDim>(params.max_seq_len_q, params.head_num),
+           get_gmem_layout<kHeadDim>(params.max_seq_len_q * params.batch_size,
+                                     params.head_num),
           static_cast<Element const *>(params.k_ptr),
-           get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
+           get_gmem_layout<kHeadDim>(params.max_seq_len_k * params.batch_size,
+                                     params.kv_head_num),
           static_cast<Element const *>(params.v_ptr),
-           get_gmem_layout<kHeadDim>(params.max_seq_len_k, params.kv_head_num),
+           get_gmem_layout<kHeadDim>(params.max_seq_len_k * params.batch_size,
+                                     params.kv_head_num),
           params.scale_softmax_log2});

  int num_blocks_m =
      cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);
  num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) *
                 size<0>(ClusterShape{});

  void *kernel;
  kernel = (void *)compute_attn_ws<Kernel_traits>;
  int smem_size = sizeof(typename Kernel_traits::SharedStorage);

  if (smem_size >= 48 * 1024) {
    cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
  }

  dim3 grid_dims;
  grid_dims.x = num_blocks_m;
  grid_dims.y = params.head_num;
  grid_dims.z = params.batch_size;

  static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
  dim3 block_dims(ctaSize);
  dim3 cluster_dims(size<0>(ClusterShape{}),
                    size<1>(ClusterShape{}),
                    size<2>(ClusterShape{}));
  cutlass::ClusterLaunchParams launch_params{
      grid_dims, block_dims, cluster_dims, smem_size, stream};
  cutlass::launch_kernel_on_cluster(
      launch_params, kernel, mainloop_params, params);
}

-template <int kBlockM, int kBlockN, bool NeedMask, typename InputType>
+template <int kBlockM,
+          int kBlockN,
+          bool NeedMask,
+          typename InputType,
+          typename OutputType>
void flash_attn_headdim128(Flash_mask_params &params, cudaStream_t stream) {
  constexpr static int Headdim = 128;
  constexpr static int kNWarps = kBlockM / 16 + 4;
  constexpr static int kStages = 2;

-  using Ktraits = Flash_mask_kernel_traits<Headdim, kBlockM, kBlockN, kNWarps, kStages, NeedMask, InputType>;
+  using Ktraits = Flash_mask_kernel_traits<Headdim,
+                                           kBlockM,
+                                           kBlockN,
+                                           kNWarps,
+                                           kStages,
+                                           NeedMask,
+                                           InputType,
+                                           OutputType>;
  run_flash_mask<Ktraits>(params, stream);
}
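A quick host-side check of the launch geometry computed in run_flash_mask, assuming kBlockM = 128, a 1x1x1 cluster, and illustrative problem sizes (not taken from a real run):

#include <cassert>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int kBlockM = 128, kNWarps = 12, cluster_m = 1;
  const int max_seq_len_q = 3000, head_num = 32, batch_size = 4;
  int num_blocks_m = ceil_div(max_seq_len_q, kBlockM);           // 24 M-tiles
  num_blocks_m = ceil_div(num_blocks_m, cluster_m) * cluster_m;  // cluster pad
  assert(num_blocks_m == 24);
  // grid = (M-tiles, query heads, batches); block = kNWarps * 32 threads.
  const int grid_x = num_blocks_m, grid_y = head_num, grid_z = batch_size;
  const int cta_size = kNWarps * 32;
  assert(grid_x * grid_y * grid_z == 24 * 32 * 4);
  assert(cta_size == 384);
  return 0;
}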
@@ -1,6 +1,16 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
@@ -15,110 +25,155 @@
using namespace cute;

struct Flash_mask_params {
  void *__restrict__ q_ptr;
  void *__restrict__ k_ptr;
  void *__restrict__ v_ptr;
  void *__restrict__ o_ptr;
  int *__restrict__ cu_seq_q;
  int *__restrict__ cu_seq_k;
  int *__restrict__ mask;
  int *seq_len_encoder;
  int head_num;
  int kv_head_num;
  int max_seq_len_q;
  int max_seq_len_k;
  int batch_size;
  int gqa_group_size;
  float scale_softmax_log2;
};
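cu_seq_q and cu_seq_k above are the usual varlen cumulative-offset arrays: entry b is where batch b starts in the packed token dimension, and entry b+1 minus entry b is its length, which is exactly how the kernel recovers seq_len_k. A minimal sketch with made-up lengths:

#include <cassert>
#include <vector>

int main() {
  // Hypothetical per-batch key lengths; the real arrays come from the runtime.
  std::vector<int> seq_len_k = {3, 5, 2};
  std::vector<int> cu_seq_k(seq_len_k.size() + 1, 0);
  for (size_t b = 0; b < seq_len_k.size(); ++b) {
    cu_seq_k[b + 1] = cu_seq_k[b] + seq_len_k[b];  // {0, 3, 8, 10}
  }
  // A batch's length is cu_seq_k[bidb + 1] - cu_seq_k[bidb] and its first row
  // in the packed [token, head, dim] tensor is cu_seq_k[bidb].
  assert(cu_seq_k[2] - cu_seq_k[1] == 5);
  assert(cu_seq_k.back() == 10);  // total packed tokens
  return 0;
}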

template <int kStages,
          class Gemm1Type,
          class Gemm2Type,
          class OutputType,
          class SmemLayoutQ,
          class SmemLayoutK,
          class SmemLayoutV,
          class SmemLayoutO>
struct SharedStorageQKVO {
  cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
  cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
  union {
    cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
    cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
  };
  struct {
    cutlass::arch::ClusterTransactionBarrier barrier_Q;
    typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
    typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
  };
};
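The anonymous union above lets the V staging buffers and the O write-back buffer share the same shared-memory bytes, since V is no longer needed once the output tile is being stored. A standalone sketch of the size effect, with illustrative 128 x 128 fp16 tiles and 2 K/V pipeline stages (each such stage is 32 KiB, which is also the per-stage TMA transaction size used below):

#include <cassert>
#include <cstdint>

// Stand-ins for the cute::array_aligned members (illustrative sizes only).
struct Separate {
  uint16_t smem_v[2 * 128 * 128];  // 2-stage V buffer, 64 KiB
  uint16_t smem_o[128 * 128];      // O tile, 32 KiB
};
struct Shared {
  union {
    uint16_t smem_v[2 * 128 * 128];
    uint16_t smem_o[128 * 128];
  };
};

int main() {
  // Aliasing V and O saves a full 32 KiB output tile of shared memory per CTA.
  assert(sizeof(Separate) - sizeof(Shared) == 128 * 128 * sizeof(uint16_t));
  return 0;
}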
-template <int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool NeedMask_, typename elem_type = cutlass::half_t>
+template <int kHeadDim_,
+          int kBlockM_,
+          int kBlockN_,
+          int kNWarps_,
+          int kStages_,
+          bool NeedMask_,
+          typename elem_type = cutlass::half_t,
+          typename out_type = cutlass::half_t>
struct Flash_mask_kernel_traits {
  using Element = elem_type;
+  using output_type = out_type;
  using ElementAccum = float;
  using index_t = int32_t;

  static constexpr int kNWarps = kNWarps_;
  static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp;

  static constexpr int kBlockM = kBlockM_;
  static constexpr int kBlockN = kBlockN_;
  static constexpr int kHeadDim = kHeadDim_;
  static_assert(kHeadDim % 32 == 0);
  using TileShape_MNK = Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
  using ClusterShape_MNK = Shape<Int<1>, Int<1>, Int<1>>;
  static constexpr int kStages = kStages_;
  static constexpr int NeedMask = NeedMask_;

  using AtomLayoutMNK = Layout<Shape<Int<kBlockM / 64>, _1, _1>>;
  using TiledMma0 = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>(),
      AtomLayoutMNK{}));
  using TiledMma1 = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<Element,
                                 Element,
                                 ElementAccum,
                                 decltype(select<0, 2, 1>(TileShape_MNK{})),
                                 GMMA::Major::K,
                                 GMMA::Major::MN>(),
      AtomLayoutMNK{}));

  using SmemLayoutAtomQ =
      decltype(cutlass::gemm::collective::detail::ss_smem_selector<
               GMMA::Major::K,
               Element,
               decltype(cute::get<0>(TileShape_MNK{})),
               decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutQ =
      decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

  using SmemLayoutAtomK =
      decltype(cutlass::gemm::collective::detail::ss_smem_selector<
               GMMA::Major::K,
               Element,
               decltype(cute::get<1>(TileShape_MNK{})),
               decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutK =
      decltype(tile_to_shape(SmemLayoutAtomK{},
                             make_shape(shape<1>(TileShape_MNK{}),
                                        shape<2>(TileShape_MNK{}),
                                        Int<kStages>{})));

  using SmemLayoutAtomV =
      decltype(cutlass::gemm::collective::detail::ss_smem_selector<
               GMMA::Major::K,
               Element,
               decltype(cute::get<1>(TileShape_MNK{})),
               decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutV =
      decltype(tile_to_shape(SmemLayoutAtomV{},
                             make_shape(shape<1>(TileShape_MNK{}),
                                        shape<2>(TileShape_MNK{}),
                                        Int<kStages>{})));

  using SmemLayoutAtomO =
      decltype(cutlass::gemm::collective::detail::ss_smem_selector<
               GMMA::Major::K,
-               Element,
+               output_type,
               decltype(cute::get<0>(TileShape_MNK{})),
               decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutO =
      decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

  using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;
-  using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, Element>;
+  using SmemCopyAtomO = Copy_Atom<cute::SM90_U32x4_STSM_N, output_type>;

-  using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ, SmemLayoutK, SmemLayoutV, SmemLayoutO>;
+  using SharedStorage = SharedStorageQKVO<kStages,
+                                          Element,
+                                          Element,
+                                          output_type,
+                                          SmemLayoutQ,
+                                          SmemLayoutK,
+                                          SmemLayoutV,
+                                          SmemLayoutO>;

  static constexpr int NumProducerThreads = cutlass::NumThreadsPerWarpGroup;
  static constexpr int NumMmaThreads = kNThreads - NumProducerThreads;
-  static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<Element>);
+  static constexpr int kNumVecElem = ceil_div(128, sizeof_bits_v<output_type>);
  static constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
  static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
  static constexpr int kNumRows = NumMmaThreads / kNumThreadsPerRow;
-  using TiledCopyOAtom = cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, Element>;
+  using TiledCopyOAtom =
+      cute::Copy_Atom<cute::UniversalCopy<cutlass::uint128_t>, output_type>;
  using TiledCopyOThrLayout = decltype(cute::make_layout(
      cute::make_shape(Int<kNumRows>{}, Int<kNumThreadsPerRow>{}),
      LayoutRight{}));
  using TiledCopyOValLayout = decltype(cute::make_layout(
      cute::make_shape(_1{}, Int<kNumVecElem>{}), LayoutRight{}));
  using GmemTiledCopyO = decltype(make_tiled_copy(
      TiledCopyOAtom{}, TiledCopyOThrLayout{}, TiledCopyOValLayout{}));

  using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
  using PipelineState = typename cutlass::PipelineState<kStages>;
};
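With the values used by flash_attn_headdim128 (kBlockM = 128, head dim 128, a 16-bit output type), the derived warp and copy-vectorization counts work out as in this standalone check; the assumptions are restated in the comments:

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int kBlockM = 128, kHeadDim = 128;
  constexpr int kNWarps = kBlockM / 16 + 4;       // 12 warps
  constexpr int kNThreads = kNWarps * 32;         // 384 threads per CTA
  constexpr int NumProducerThreads = 128;         // one warp group just loads
  constexpr int NumMmaThreads = kNThreads - NumProducerThreads;  // 256 compute
  // Output copy vectorization for a 16-bit output type: one 128-bit lane holds
  // 8 elements, so 16 threads cover a 128-wide head-dim row.
  constexpr int kNumVecElem = ceil_div(128, 16);
  constexpr int kNumThreadsPerRow = kHeadDim / kNumVecElem;
  static_assert(kNWarps == 12 && kNThreads == 384 && NumMmaThreads == 256);
  static_assert(kNumVecElem == 8 && kNumThreadsPerRow == 16);
  static_assert(NumMmaThreads % kNumThreadsPerRow == 0);
  return 0;
}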
@@ -1,13 +1,23 @@
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "cutlass/pipeline/pipeline.hpp"

#include "cute/tensor.hpp"
@@ -16,416 +26,559 @@
#include "utils.hpp"

using namespace cute;

enum class AttnNamedBarriers {
  QueryEmpty = 0,
  ValueEmpty = 1,
  TileCountSmemEmpty = 2,
  TileCountSmemFull = 3,
  WarpSchedulerWG1 = 4,
  WarpSchedulerWG2 = 5,
  WarpSchedulerWG3 = 6,
};
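One detail worth keeping in mind for the mainloop below: qhead_per_khead_divmod encodes the GQA ratio (query heads per key/value head), and the loader divides the query-head index by it to pick the shared KV head. The same mapping in plain integer arithmetic, with an illustrative 32-query-head / 8-KV-head configuration:

#include <cassert>

int main() {
  const int head_num = 32, kv_head_num = 8;            // illustrative GQA setup
  const int qhead_per_khead = head_num / kv_head_num;  // gqa_group_size = 4
  // Query heads 0..3 read KV head 0, heads 4..7 read KV head 1, and so on,
  // which is what bidh_kv = qhead_per_khead_divmod.divide(bidh) computes.
  for (int bidh = 0; bidh < head_num; ++bidh) {
    const int bidh_kv = bidh / qhead_per_khead;
    assert(bidh_kv == bidh / 4);
    assert(bidh_kv < kv_head_num);
  }
  return 0;
}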

template <typename Ktraits>
struct CollectiveMainloopAttn {
  using Element = typename Ktraits::Element;
+  using output_type = typename Ktraits::output_type;
  using TileShape_MNK = typename Ktraits::TileShape_MNK;
  using ClusterShape = typename Ktraits::ClusterShape_MNK;

  static constexpr int kStages = Ktraits::kStages;
  static constexpr int kHeadDim = Ktraits::kHeadDim;
  static constexpr int kBlockM = Ktraits::kBlockM;
  static constexpr int kBlockN = Ktraits::kBlockN;
  static constexpr bool NeedMask = Ktraits::NeedMask;

  using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
  using StrideT = cute::Shape<int32_t, _1, int32_t>;
  using LayoutT = cute::Layout<ShapeT, StrideT>;

  using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
  using GmemTiledCopyKV =
      decltype(cutlass::gemm::collective::detail::
                   sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{})));
  using GmemTiledCopyO = typename Ktraits::GmemTiledCopyO;

  using SmemLayoutAtomQ =
      decltype(cutlass::gemm::collective::detail::ss_smem_selector<
               GMMA::Major::K,
               Element,
               decltype(cute::get<0>(TileShape_MNK{})),
               decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutQ =
      decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{})));

  using SmemLayoutAtomK =
      decltype(cutlass::gemm::collective::detail::ss_smem_selector<
               GMMA::Major::K,
               Element,
               decltype(cute::get<1>(TileShape_MNK{})),
               decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutK =
      decltype(tile_to_shape(SmemLayoutAtomK{},
                             make_shape(shape<1>(TileShape_MNK{}),
                                        shape<2>(TileShape_MNK{}),
                                        Int<kStages>{})));
  using SmemLayoutV = SmemLayoutK;
  // Note this is the transpose in terms of the view, not in terms of memory.
  using SmemLayoutVt = decltype(cute::composition(
      SmemLayoutV{},
      make_layout(
          make_shape(
              get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
          make_stride(get<1>(TileShape_MNK{}),
                      _1{},
                      Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
  using SmemLayoutO = typename Ktraits::SmemLayoutO;
  using SmemCopyAtomO = typename Ktraits::SmemCopyAtomO;

  using TMA_Q = decltype(make_tma_copy(
      GmemTiledCopyQ{},
      make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
                  repeat_like(StrideT{}, int32_t(0)),
                  StrideT{}),
      SmemLayoutQ{},
      select<0, 2>(TileShape_MNK{}),
      _1{}));  // no mcast for Q

  using TMA_KV = decltype(make_tma_copy(
      GmemTiledCopyKV{},
      make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)),
                  repeat_like(StrideT{}, int32_t(0)),
                  StrideT{}),
      take<0, 2>(SmemLayoutK{}),
      select<1, 2>(TileShape_MNK{}),
      size<0>(ClusterShape{})));  // mcast along M mode for this N load, if any

  static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
  using MainloopPipeline = typename Ktraits::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;

  // Set the bytes transferred in this TMA transaction (may involve multiple
  // issues)
  static constexpr uint32_t TmaTransactionBytesQ = static_cast<uint32_t>(
      size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<Element> / 8);
  static constexpr uint32_t TmaTransactionBytesK = static_cast<uint32_t>(
      size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v<Element> / 8);

  static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;

  // Host side kernel arguments
  struct Arguments {
    Element const* ptr_Q;
    LayoutT layout_Q;
    Element const* ptr_K;
    LayoutT layout_K;
    Element const* ptr_V;
    LayoutT layout_V;
    float const softmax_scale_log2;
  };

  // Device side kernel params
  struct Params {
    LayoutT layout_Q;
    LayoutT layout_K;
    LayoutT layout_V;
    cutlass::FastDivmod qhead_per_khead_divmod;
    TMA_Q tma_load_Q;
    TMA_KV tma_load_K, tma_load_V;
    float const softmax_scale_log2;
  };

  static Params to_underlying_arguments(Arguments const& args) {
    Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
    TMA_Q tma_load_Q = make_tma_copy(GmemTiledCopyQ{},
                                     mQ,
                                     SmemLayoutQ{},
                                     select<0, 2>(TileShape_MNK{}),
                                     _1{});
    Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
    TMA_KV tma_load_K = make_tma_copy(
        GmemTiledCopyKV{},
        mK,
        SmemLayoutK{}(_, _, _0{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
    Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
    TMA_KV tma_load_V = make_tma_copy(
        GmemTiledCopyKV{},
        mV,
        SmemLayoutV{}(_, _, _0{}),
        select<1, 2>(TileShape_MNK{}),
        size<0>(ClusterShape{}));  // mcast along M mode for this N load, if any
    return {args.layout_Q,
            args.layout_K,
            args.layout_V,
            cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()),
                                               get<2>(args.layout_K.shape()))),
            tma_load_Q,
            tma_load_K,
            tma_load_V,
            args.softmax_scale_log2};
  }

  /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
  /// performance
  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& mainloop_params) {
    cute::prefetch_tma_descriptor(
        mainloop_params.tma_load_Q.get_tma_descriptor());
    cute::prefetch_tma_descriptor(
        mainloop_params.tma_load_K.get_tma_descriptor());
    cute::prefetch_tma_descriptor(
        mainloop_params.tma_load_V.get_tma_descriptor());
  }

  template <typename MTensor, typename Shape>
  CUTLASS_DEVICE auto get_local_tile_tensor(const MTensor& m_tensor,
                                            const Shape& tile_shape,
                                            const int* cu_seq_len,
                                            const int bidh,
                                            const int bidb,
                                            const int actual_seq_len) const {
    auto g_offset = local_tile(m_tensor(_, _, bidh),
                               cute::make_shape(1, get<1>(tile_shape)),
                               make_coord(cu_seq_len[bidb], _0{}));
    auto g_sequence = make_tensor(
        g_offset.data(),
        make_layout(cute::make_shape(actual_seq_len, get<1>(tile_shape)),
                    g_offset.stride()));
    auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
    return g_tensor;
  }

  template <typename SharedStorage>
  CUTLASS_DEVICE void load(Params const& mainloop_params,
                           MainloopPipeline pipeline_k,
                           MainloopPipeline pipeline_v,
                           PipelineState& smem_pipe_write_k,
                           PipelineState& smem_pipe_write_v,
                           SharedStorage& shared_storage,
                           const int n_block_max,
                           const int m_block,
                           const int bidh,
                           const int bidb,
                           const int* cu_seq_q,
                           const int* cu_seq_k,
                           const int seq_len_q,
                           const int seq_len_k) {
    Tensor sQ =
        make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
    Tensor sK =
        make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
    Tensor sV =
        make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

    Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(
        mainloop_params.layout_Q.shape());
    Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(
        mainloop_params.layout_K.shape());
    Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(
        mainloop_params.layout_V.shape());
    int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);

    Tensor gQ = get_local_tile_tensor(
        mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(
        _, _, m_block);
    Tensor gK = get_local_tile_tensor(
        mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
    Tensor gV = get_local_tile_tensor(
        mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);

    Tensor sQ_x =
        make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
    Tensor gQ_x =
        make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
    auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q,
                                      _0{},
                                      Layout<_1>{},
                                      group_modes<0, 2>(sQ_x),
                                      group_modes<0, 2>(gQ_x));
    auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K,
                                      _0{},
                                      Layout<_1>{},
                                      group_modes<0, 2>(sK),
                                      group_modes<0, 2>(gK));
    auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V,
                                      _0{},
                                      Layout<_1>{},
                                      group_modes<0, 2>(sV),
                                      group_modes<0, 2>(gV));

    uint16_t mcast_mask_kv = 0;

    int n_block = n_block_max - 1;

    int lane_predicate = cute::elect_one_sync();

    if (lane_predicate) {
      shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
      copy(mainloop_params.tma_load_Q.with(
               reinterpret_cast<
                   cutlass::arch::ClusterTransactionBarrier::ValueType&>(
                   shared_storage.barrier_Q),
               0 /*mcast_mask*/),
           tQgQ,
           tQsQ);
    }

    if (lane_predicate) {
      pipeline_k.producer_acquire(smem_pipe_write_k);
      copy(mainloop_params.tma_load_K.with(
               *pipeline_k.producer_get_barrier(smem_pipe_write_k),
               mcast_mask_kv),
           tKgK(_, n_block),
           tKsK(_, smem_pipe_write_k.index()));
      ++smem_pipe_write_k;
    }

    if (lane_predicate) {
#pragma unroll 2
      for (; n_block > 0; --n_block) {
        pipeline_k.producer_acquire(smem_pipe_write_k);
        copy(mainloop_params.tma_load_K.with(
                 *pipeline_k.producer_get_barrier(smem_pipe_write_k),
                 mcast_mask_kv),
             tKgK(_, n_block - 1),
             tKsK(_, smem_pipe_write_k.index()));
        ++smem_pipe_write_k;
        pipeline_v.producer_acquire(smem_pipe_write_v);
        copy(mainloop_params.tma_load_V.with(
                 *pipeline_v.producer_get_barrier(smem_pipe_write_v),
                 mcast_mask_kv),
             tVgV(_, n_block),
             tVsV(_, smem_pipe_write_v.index()));
        ++smem_pipe_write_v;
      }
    }
    if (lane_predicate) {
      pipeline_v.producer_acquire(smem_pipe_write_v);
      copy(mainloop_params.tma_load_V.with(
               *pipeline_v.producer_get_barrier(smem_pipe_write_v),
               mcast_mask_kv),
           tVgV(_, n_block),
           tVsV(_, smem_pipe_write_v.index()));
      ++smem_pipe_write_v;
    }
  }
CUTLASS_DEVICE void warp_scheduler_barrier_sync() {
|
||||
if constexpr (UseSchedulerBarrier) {
|
||||
cutlass::arch::NamedBarrier::sync(
|
||||
NumMmaThreads,
|
||||
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
|
||||
cutlass::canonical_warp_group_idx() /*id*/);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void mma_init() {
|
||||
if constexpr (!UseSchedulerBarrier) {
|
||||
return;
|
||||
}
|
||||
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup ||
|
||||
NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
|
||||
if (cutlass::canonical_warp_group_idx() > 1) {
|
||||
cutlass::arch::NamedBarrier::arrive(
|
||||
NumMmaThreads,
|
||||
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
|
||||
}
|
||||
if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
|
||||
if (cutlass::canonical_warp_group_idx() > 2) {
|
||||
cutlass::arch::NamedBarrier::arrive(
|
||||
NumMmaThreads,
|
||||
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
|
||||
2 /*id*/);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void warp_scheduler_barrier_arrive() {
|
||||
if constexpr (!UseSchedulerBarrier) {
|
||||
return;
|
||||
}
|
||||
static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup ||
|
||||
NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
|
||||
if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
|
||||
cutlass::arch::NamedBarrier::arrive(
|
||||
NumMmaThreads,
|
||||
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
|
||||
(3 - cutlass::canonical_warp_group_idx()) /*id*/);
|
||||
} else {
|
||||
cutlass::arch::NamedBarrier::arrive(
|
||||
NumMmaThreads,
|
||||
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
|
||||
(cutlass::canonical_warp_group_idx() <= 2
|
||||
? cutlass::canonical_warp_group_idx() + 1
|
||||
: cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
|
||||
cutlass::arch::NamedBarrier::arrive(
|
||||
NumMmaThreads,
|
||||
static_cast<int>(AttnNamedBarriers::WarpSchedulerWG1) - 1 +
|
||||
(cutlass::canonical_warp_group_idx() <= 1
|
||||
? cutlass::canonical_warp_group_idx() + 2
|
||||
: cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SharedStorage, typename FrgTensorO, typename Softmax>
|
||||
CUTLASS_DEVICE void mma(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_k,
|
||||
MainloopPipeline pipeline_v,
|
||||
PipelineState& smem_pipe_read_k,
|
||||
PipelineState& smem_pipe_read_v,
|
||||
FrgTensorO& tOrO,
|
||||
Softmax& softmax,
|
||||
const int* mask,
|
||||
const int n_block_max,
|
||||
const int thread_idx,
|
||||
const int m_block,
|
||||
const int seq_len_q,
|
||||
const int seq_len_k,
|
||||
SharedStorage& shared_storage) {
|
||||
Tensor sQ =
|
||||
make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK =
|
||||
make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
|
||||
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()),
|
||||
SmemLayoutVt{});
|
||||
|
||||
typename Ktraits::TiledMma0 tiled_mma0;
|
||||
typename Ktraits::TiledMma1 tiled_mma1;
|
||||
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
|
||||
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMma0.partition_fragment_B(sK);
|
||||
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
LayoutT layout_Q;
|
||||
LayoutT layout_K;
|
||||
LayoutT layout_V;
|
||||
cutlass::FastDivmod qhead_per_khead_divmod;
|
||||
TMA_Q tma_load_Q;
|
||||
TMA_KV tma_load_K, tma_load_V;
|
||||
float const softmax_scale_log2;
|
||||
};
|
||||
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
|
||||
static Params
|
||||
to_underlying_arguments(Arguments const& args) {
|
||||
Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.layout_Q);
|
||||
TMA_Q tma_load_Q = make_tma_copy(
|
||||
GmemTiledCopyQ{},
|
||||
mQ,
|
||||
SmemLayoutQ{},
|
||||
select<0, 2>(TileShape_MNK{}),
|
||||
_1{});
|
||||
Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.layout_K);
|
||||
TMA_KV tma_load_K = make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
mK,
|
||||
SmemLayoutK{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
|
||||
TMA_KV tma_load_V = make_tma_copy(
|
||||
GmemTiledCopyKV{},
|
||||
mV,
|
||||
SmemLayoutV{}(_, _, _0{}),
|
||||
select<1, 2>(TileShape_MNK{}),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {args.layout_Q, args.layout_K, args.layout_V,
|
||||
cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
|
||||
tma_load_Q, tma_load_K, tma_load_V,
|
||||
args.softmax_scale_log2};
|
||||
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(
|
||||
shared_storage.barrier_Q.try_wait(0));
|
||||
if (barrier_token == cutlass::BarrierStatus::WaitAgain) {
|
||||
shared_storage.barrier_Q.wait(0);
|
||||
}
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor());
|
||||
Tensor tSrS =
|
||||
partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
warp_scheduler_barrier_sync();
|
||||
gemm</*zero_init=*/true, /*wg_wait=*/-1>(
|
||||
tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
warp_scheduler_barrier_arrive();
|
||||
warpgroup_wait<0>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k);
|
||||
++smem_pipe_read_k;
|
||||
|
||||
int mask_start_idx;
|
||||
int mask_row_id;
|
||||
int col_base;
|
||||
|
||||
if constexpr (NeedMask) {
|
||||
const int lane_id = thread_idx % 32;
|
||||
mask_start_idx = mask[0] / kBlockN - 1;
|
||||
|
||||
mask_row_id = thread_idx / 32 * 16 + lane_id / 4;
|
||||
|
||||
col_base = thread_idx % 4 * 2;
|
||||
|
||||
app_mask(tSrS, mask, mask_row_id, col_base + n_block * kBlockN);
|
||||
} else {
|
||||
auto col_limit_causal = [&](int row, int n_block) {
|
||||
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q +
|
||||
m_block * kBlockM;
|
||||
};
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
|
||||
Tensor tScS = threadMma0.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
if (int(get<1>(tScS(i))) >=
|
||||
std::min(seq_len_k - n_block * kBlockN,
|
||||
col_limit_causal(int(get<0>(tScS(i))), n_block))) {
|
||||
tSrS(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename MTensor, typename Shape>
|
||||
CUTLASS_DEVICE auto get_local_tile_tensor(
|
||||
const MTensor &m_tensor,
|
||||
const Shape &tile_shape,
|
||||
const int *cu_seq_len,
|
||||
const int bidh,
|
||||
const int bidb,
|
||||
const int actual_seq_len) const {
|
||||
auto g_offset = local_tile(
|
||||
m_tensor(_, _, bidh),
|
||||
cute::make_shape(1, get<1>(tile_shape)),
|
||||
make_coord(cu_seq_len[bidb], _0{}));
|
||||
auto g_sequence = make_tensor(
|
||||
g_offset.data(),
|
||||
make_layout(
|
||||
cute::make_shape(actual_seq_len, get<1>(tile_shape)),
|
||||
g_offset.stride()
|
||||
));
|
||||
auto g_tensor = local_tile(g_sequence, tile_shape, make_coord(_, _0{}));
|
||||
return g_tensor;
|
||||
softmax.template online_softmax</*Is_first=*/true>(
|
||||
tSrS, mainloop_params.softmax_scale_log2);
|
||||
|
||||
Tensor tOrP = make_tensor(
|
||||
convert_type<Element>(tSrS).data(),
|
||||
convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
|
||||
Tensor scores_scale = make_fragment_like(softmax.row_max);
|
||||
clear(scores_scale);
|
||||
|
||||
#pragma unroll 2
|
||||
for (; n_block > 0; --n_block) {
|
||||
Tensor tSrS =
|
||||
partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
warp_scheduler_barrier_sync();
|
||||
|
||||
if constexpr (NeedMask) {
|
||||
if (n_block >= mask_start_idx) {
|
||||
app_mask(tSrS, mask, mask_row_id, col_base + n_block * kBlockN);
|
||||
}
|
||||
}
|
||||
|
||||
gemm</*zero_init=*/true, /*wg_wait=*/-1>(
|
||||
tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
gemm</*zero_init=*/false, /*wg_wait=*/-1>(
|
||||
tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
warp_scheduler_barrier_arrive();
|
||||
warpgroup_wait<1>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k); // release K
|
||||
cute::copy(softmax.template max</*Is_first=*/false>(
|
||||
tSrS, mainloop_params.softmax_scale_log2),
|
||||
scores_scale);
|
||||
softmax.template online_softmax</*Is_first=*/false>(
|
||||
tSrS, mainloop_params.softmax_scale_log2);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_v.consumer_release(smem_pipe_read_v); // release V
|
||||
++smem_pipe_read_k;
|
||||
++smem_pipe_read_v;
|
||||
cute::copy(
|
||||
make_tensor(convert_type<Element>(tSrS).data(),
|
||||
convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(
|
||||
tSrS.layout())),
|
||||
tOrP);
|
||||
}
|
||||
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
|
||||
template <typename SharedStorage>
|
||||
CUTLASS_DEVICE void
|
||||
load(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_k,
|
||||
MainloopPipeline pipeline_v,
|
||||
PipelineState& smem_pipe_write_k,
|
||||
PipelineState& smem_pipe_write_v,
|
||||
SharedStorage &shared_storage,
|
||||
const int n_block_max,
|
||||
const int m_block,
|
||||
const int bidh,
|
||||
const int bidb,
|
||||
const int *cu_seq_q,
|
||||
const int *cu_seq_k,
|
||||
const int seq_len_q,
|
||||
const int seq_len_k) {
|
||||
gemm</*zero_init=*/false, /*wg_wait=*/-1>(
|
||||
tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2),
|
||||
scores_scale);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_v.consumer_release(smem_pipe_read_v);
|
||||
++smem_pipe_read_v;
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
|
||||
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
}
|
||||
|
||||
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
|
||||
Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
|
||||
Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());
|
||||
int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
|
||||
template <int NumMmaThreads,
|
||||
typename SharedStorage,
|
||||
typename FrgTensorO,
|
||||
typename TiledMma,
|
||||
typename T>
|
||||
CUTLASS_DEVICE void store(Params const& mainloop_params,
|
||||
FrgTensorO const& tOrO,
|
||||
SharedStorage& shared_storage,
|
||||
TiledMma tiled_mma,
|
||||
int thread_idx,
|
||||
const int o_head_stride,
|
||||
const int real_seq,
|
||||
T* out_ptr) {
|
||||
Tensor sO =
|
||||
make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor gQ = get_local_tile_tensor(
|
||||
mQ, select<0, 2>(TileShape_MNK{}), cu_seq_q, bidh, bidb, seq_len_q)(_, _, m_block);
|
||||
Tensor gK = get_local_tile_tensor(
|
||||
mK, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
|
||||
Tensor gV = get_local_tile_tensor(
|
||||
mV, select<1, 2>(TileShape_MNK{}), cu_seq_k, bidh_kv, bidb, seq_len_k);
|
||||
Tensor tOrO_out = convert_type<output_type>(tOrO);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
|
||||
|
||||
Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
|
||||
Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
|
||||
auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{},group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x));
|
||||
auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, _0{}, Layout<_1>{},group_modes<0, 2>(sK), group_modes<0, 2>(gK));
|
||||
auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, _0{}, Layout<_1>{},group_modes<0, 2>(sV), group_modes<0, 2>(gV));
|
||||
cutlass::arch::NamedBarrier::sync(
|
||||
NumMmaThreads, static_cast<int>(AttnNamedBarriers::ValueEmpty) /*id*/);
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible
|
||||
// to TMA
|
||||
cutlass::arch::NamedBarrier::arrive(
|
||||
NumMmaThreads + cutlass::NumThreadsPerWarp,
|
||||
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
||||
|
||||
uint16_t mcast_mask_kv = 0;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(o_head_stride, _1{}));
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
|
||||
if (lane_predicate) {
|
||||
shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
|
||||
copy(mainloop_params.tma_load_Q.with(reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ);
|
||||
}
|
||||
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
|
||||
|
||||
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
|
||||
|
||||
if (lane_predicate) {
|
||||
pipeline_k.producer_acquire(smem_pipe_write_k);
|
||||
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
|
||||
tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index()));
|
||||
++smem_pipe_write_k;
|
||||
}
|
||||
|
||||
if (lane_predicate) {
|
||||
#pragma unroll 2
|
||||
for (; n_block > 0; --n_block) {
|
||||
pipeline_k.producer_acquire(smem_pipe_write_k);
|
||||
copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv),
|
||||
tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index()));
|
||||
++smem_pipe_write_k;
|
||||
pipeline_v.producer_acquire(smem_pipe_write_v);
|
||||
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
|
||||
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
|
||||
++smem_pipe_write_v;
|
||||
}
|
||||
}
|
||||
if (lane_predicate) {
|
||||
pipeline_v.producer_acquire(smem_pipe_write_v);
|
||||
copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
|
||||
tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
|
||||
++smem_pipe_write_v;
|
||||
}
|
||||
if (real_seq >= kBlockM) {
|
||||
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
|
||||
} else {
|
||||
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
|
||||
}
|
||||
|
||||
template <typename SharedStorage, typename FrgTensorO, typename Softmax>
|
||||
CUTLASS_DEVICE void
|
||||
mma(Params const& mainloop_params,
|
||||
MainloopPipeline pipeline_k,
|
||||
MainloopPipeline pipeline_v,
|
||||
PipelineState& smem_pipe_read_k,
|
||||
PipelineState& smem_pipe_read_v,
|
||||
FrgTensorO& tOrO,
|
||||
Softmax& softmax,
|
||||
const int *mask,
|
||||
const int n_block_max,
|
||||
const int thread_idx,
|
||||
const int m_block,
|
||||
const int seq_len_q,
|
||||
const int seq_len_k,
|
||||
SharedStorage& shared_storage) {
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
|
||||
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
|
||||
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});
|
||||
|
||||
typename Ktraits::TiledMma0 tiled_mma0;
|
||||
typename Ktraits::TiledMma1 tiled_mma1;
|
||||
auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx);
|
||||
auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
|
||||
Tensor tSrK = threadMma0.partition_fragment_B(sK);
|
||||
Tensor tOrV = threadMma1.partition_fragment_B(sVt);
|
||||
|
||||
auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
};
|
||||
|
||||
tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
|
||||
cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(0));
|
||||
if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(0); }
|
||||
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k);
|
||||
++smem_pipe_read_k;
|
||||
|
||||
int mask_start_idx;
|
||||
int mask_row_id;
|
||||
int col_base;
|
||||
|
||||
if constexpr (NeedMask) {
|
||||
const int lane_id = thread_idx % 32;
|
||||
mask_start_idx = mask[0] / kBlockN - 1;
|
||||
|
||||
mask_row_id = thread_idx / 32 * 16 + lane_id / 4;
|
||||
|
||||
col_base = thread_idx % 4 * 2;
|
||||
|
||||
app_mask(
|
||||
tSrS,
|
||||
mask,
|
||||
mask_row_id,
|
||||
col_base + n_block * kBlockN);
|
||||
} else {
|
||||
auto col_limit_causal = [&](int row, int n_block) {
|
||||
return row + 1 + seq_len_k - n_block * kBlockN - seq_len_q + m_block * kBlockM;
|
||||
};
|
||||
Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
|
||||
Tensor tScS = threadMma0.partition_C(cS);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tSrS); ++i) {
|
||||
if (int(get<1>(tScS(i))) >=
|
||||
std::min(seq_len_k - n_block * kBlockN, col_limit_causal(int(get<0>(tScS(i))), n_block))) {
|
||||
tSrS(i) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
|
||||
|
||||
Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
|
||||
Tensor scores_scale = make_fragment_like(softmax.row_max);
|
||||
clear(scores_scale);
|
||||
|
||||
#pragma unroll 1
|
||||
for (; n_block > 0; --n_block) {
|
||||
Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
|
||||
consumer_wait(pipeline_k, smem_pipe_read_k);
|
||||
|
||||
if constexpr (NeedMask) {
|
||||
if (n_block >= mask_start_idx) {
|
||||
app_mask(
|
||||
tSrS,
|
||||
mask,
|
||||
mask_row_id,
|
||||
col_base + n_block * kBlockN);
|
||||
}
|
||||
}
|
||||
|
||||
gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
warpgroup_wait<1>();
|
||||
pipeline_k.consumer_release(smem_pipe_read_k); // release K
|
||||
cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
|
||||
softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_v.consumer_release(smem_pipe_read_v); // release V
|
||||
++smem_pipe_read_k;
|
||||
++smem_pipe_read_v;
|
||||
cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
|
||||
}
|
||||
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
consumer_wait(pipeline_v, smem_pipe_read_v);
|
||||
|
||||
gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
|
||||
cute::copy(softmax.finalize(mainloop_params.softmax_scale_log2), scores_scale);
|
||||
warpgroup_wait<0>();
|
||||
pipeline_v.consumer_release(smem_pipe_read_v);
|
||||
++smem_pipe_read_v;
|
||||
|
||||
softmax.rescale_o(tOrO, scores_scale);
|
||||
return;
|
||||
}
|
||||
|
||||
template <int NumMmaThreads, typename SharedStorage, typename FrgTensorO, typename TiledMma, typename T>
|
||||
CUTLASS_DEVICE void
|
||||
store(Params const& mainloop_params,
|
||||
FrgTensorO const& tOrO,
|
||||
SharedStorage& shared_storage,
|
||||
TiledMma tiled_mma,
|
||||
int thread_idx,
|
||||
const int o_head_stride,
|
||||
const int real_seq,
|
||||
T * out_ptr) {
|
||||
|
||||
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tOrO_out = convert_type<Element>(tOrO);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out);
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO);
|
||||
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
|
||||
cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0);
|
||||
|
||||
Tensor gO = make_tensor(make_gmem_ptr(out_ptr),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(o_head_stride, _1{}));
|
||||
|
||||
GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
|
||||
Tensor cO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
|
||||
|
||||
Tensor tOcO = gmem_thr_copy_O.partition_S(cO);
|
||||
|
||||
if (real_seq >= kBlockM) {
|
||||
copy<true>(gmem_tiled_copy_O, tOsO, tOgO, tOcO);
|
||||
} else {
|
||||
copy<false>(gmem_tiled_copy_O, tOsO, tOgO, tOcO, real_seq);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,6 +1,16 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -12,195 +22,245 @@

#include "utils.hpp"


using namespace cute;


template<int THREADS>
template <int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};

template<>
template <>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
template <typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
}
};
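// Illustrative sketch (not part of this diff): Allreduce<N> performs a
// butterfly reduction over N consecutive lanes via __shfl_xor_sync. The
// kernel below is hypothetical and assumes the MaxOp<float> functor defined
// in these utilities; launch it with a multiple of 32 threads so the full
// warp mask is valid.
__global__ void quad_max_demo(const float* in, float* out) {
  MaxOp<float> max_op;
  float v = in[threadIdx.x];
  // After this call, lanes {0..3}, {4..7}, ... each hold the max of their
  // group of four values, which is exactly how quad_allreduce_ below uses
  // Allreduce<4>.
  v = Allreduce<4>::run(v, max_op);
  out[threadIdx.x] = v;
}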
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
template <bool zero_init = true,
|
||||
typename Engine0,
|
||||
typename Layout0,
|
||||
typename Engine1,
|
||||
typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(
|
||||
Tensor<Engine0, Layout0> const &tensor,
|
||||
Tensor<Engine1, Layout1> &summary,
|
||||
Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
template <typename Engine0,
|
||||
typename Layout0,
|
||||
typename Engine1,
|
||||
typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst,
|
||||
Tensor<Engine1, Layout1> &src,
|
||||
Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++) {
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
template <bool zero_init = true,
|
||||
typename Engine0,
|
||||
typename Layout0,
|
||||
typename Engine1,
|
||||
typename Layout1,
|
||||
typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const &tensor,
|
||||
Tensor<Engine1, Layout1> &summary,
|
||||
Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
template <bool zero_init = true,
|
||||
typename Engine0,
|
||||
typename Layout0,
|
||||
typename Engine1,
|
||||
typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(
|
||||
Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &max) {
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
|
||||
template <bool zero_init = true,
|
||||
bool warp_reduce = true,
|
||||
typename Engine0,
|
||||
typename Layout0,
|
||||
typename Engine1,
|
||||
typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(
|
||||
Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &sum) {
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
if constexpr (warp_reduce) {
|
||||
quad_allreduce_(sum, sum, sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
__forceinline__ __device__ __half2 half_exp(__half2 x) {
|
||||
uint32_t tmp_out, tmp_in;
|
||||
tmp_in = reinterpret_cast<uint32_t&>(x);
|
||||
asm ("ex2.approx.f16x2 %0, %1;\n"
|
||||
: "=r"(tmp_out)
|
||||
: "r"(tmp_in));
|
||||
__half2 out = reinterpret_cast<__half2&>(tmp_out);
|
||||
return out;
|
||||
uint32_t tmp_out, tmp_in;
|
||||
tmp_in = reinterpret_cast<uint32_t &>(x);
|
||||
asm("ex2.approx.f16x2 %0, %1;\n" : "=r"(tmp_out) : "r"(tmp_in));
|
||||
__half2 out = reinterpret_cast<__half2 &>(tmp_out);
|
||||
return out;
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
template <bool zero_init = false,
|
||||
typename Engine0,
|
||||
typename Layout0,
|
||||
typename Engine1,
|
||||
typename Layout1>
|
||||
__forceinline__ __device__ void max_scale_exp2_sum(
|
||||
Tensor<Engine0, Layout0> &tensor,
|
||||
Tensor<Engine1, Layout1> &max,
|
||||
Tensor<Engine1, Layout1> &sum,
|
||||
const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const float max_scaled = max(mi) * scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
}
|
||||
template <typename Engine0,
typename Layout0,
typename Engine1,
typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(
Tensor<Engine0, Layout0> &tensor,
Tensor<Engine1, Layout1> const &max,
const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const float max_scaled = max(mi) * scale;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
}
}
}
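// Note: the softmax helpers here work in base 2 throughout. The caller is
// expected to pass scale = softmax_scale * log2(e) (the softmax_scale_log2
// used elsewhere in this file; finalize() converts back with M_LN2), so
// exp2f(x * scale - max_scaled) equals exp(softmax_scale * (x - max)) while
// using the cheaper ex2 hardware path.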
|
||||
|
||||
|
||||
template <int kNRows>
|
||||
struct Softmax {
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
CUTLASS_DEVICE Softmax() {};
|
||||
|
||||
CUTLASS_DEVICE Softmax() {};
|
||||
template <bool Is_first, bool Check_inf = false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT max(Tensor0 &acc_s,
|
||||
float softmax_scale_log2) {
|
||||
Tensor scores =
|
||||
make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scores_scale;
|
||||
if constexpr (Is_first) {
|
||||
reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
cute::fill(scores_scale, 1.f);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = row_max(mi);
|
||||
scores_scale(mi) =
|
||||
exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
row_sum(mi) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scores_scale;
|
||||
if constexpr (Is_first) {
|
||||
reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
cute::fill(scores_scale, 1.f);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = row_max(mi);
|
||||
scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
row_sum(mi) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
template <bool Is_first, typename Tensor0>
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s,
float softmax_scale_log2) {
Tensor scores =
make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scores_scale;
if constexpr (Is_first) {
reduce_max</*zero_init=*/true>(scores, row_max);
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
cute::fill(scores_scale, 1.f);
} else {
scale_apply_exp2(scores, row_max, softmax_scale_log2);
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
}
return scores_scale;
};
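// Per-row recurrence implemented by max() / online_softmax() / rescale_o()
// above (scalar form, base-2 exponentials):
//   m_new = max(m_old, max_j s_j)
//   scale = exp2f((m_old - m_new) * softmax_scale_log2)
//   p_j   = exp2f((s_j - m_new) * softmax_scale_log2)
//   l_new = l_old * scale + sum_j p_j
// and the output accumulator is multiplied by `scale` before the next P*V
// GEMM is added, as done in the mainloop above.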
|
||||
|
||||
template<bool Is_first, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) {
|
||||
Tensor scores = make_tensor(acc_s.data(), convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
TensorT scores_scale;
|
||||
if constexpr (Is_first) {
|
||||
reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
|
||||
cute::fill(scores_scale, 1.f);
|
||||
} else {
|
||||
scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT scores_scale;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = 1.0f / sum;
|
||||
row_sum(mi) = row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
|
||||
scores_scale(mi) = inv_sum;
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template<typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
__forceinline__ __device__ TensorT finalize(float softmax_scale_log2) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT scores_scale;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = 1.0f / sum;
|
||||
row_sum(mi) =
|
||||
row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum);
|
||||
scores_scale(mi) = inv_sum;
|
||||
}
|
||||
return scores_scale;
|
||||
};
|
||||
|
||||
template <typename Tensor1>
|
||||
__forceinline__ __device__ void rescale_o(Tensor1 &acc_o,
|
||||
TensorT const &scores_scale) {
|
||||
Tensor acc_o_rowcol =
|
||||
make_tensor(acc_o.data(), convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
|
||||
acc_o_rowcol(mi, ni) *= scores_scale(mi);
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,13 +1,23 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
||||
******************************************************************************/
|
||||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/host_vector.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
@@ -19,8 +29,8 @@
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <cute/arch/cluster_sm90.hpp> // For cute::elect_one_sync()
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
@@ -29,425 +39,468 @@
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
struct PackedHalf;
|
||||
|
||||
template<>
|
||||
template <>
|
||||
struct PackedHalf<cutlass::half_t> {
|
||||
using Type = __half2;
|
||||
using Type = __half2;
|
||||
};
|
||||
|
||||
template<>
|
||||
template <>
|
||||
struct PackedHalf<cutlass::bfloat16_t> {
|
||||
using Type = nv_bfloat162;
|
||||
using Type = nv_bfloat162;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
template <typename T>
|
||||
__forceinline__ __device__ auto float_2_half2(const float x) {
|
||||
if constexpr (std::is_same<T, cutlass::half_t>::value) {
|
||||
return __float2half2_rn(x);
|
||||
} else {
|
||||
return __float2bfloat162_rn(x);
|
||||
}
|
||||
if constexpr (std::is_same<T, cutlass::half_t>::value) {
|
||||
return __float2half2_rn(x);
|
||||
} else {
|
||||
return __float2bfloat162_rn(x);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
struct uint16 {
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
uint4 s;
|
||||
uint4 t;
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
uint4 s;
|
||||
uint4 t;
|
||||
};
|
||||
|
||||
|
||||
struct uint8 {
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
};
|
||||
|
||||
template<int BYTES>
|
||||
template <int BYTES>
|
||||
struct BytesToType {};
|
||||
|
||||
template<>
|
||||
template <>
|
||||
struct BytesToType<64> {
|
||||
using Type = uint16;
|
||||
static_assert(sizeof(Type) == 64);
|
||||
using Type = uint16;
|
||||
static_assert(sizeof(Type) == 64);
|
||||
};
|
||||
|
||||
template<>
|
||||
template <>
|
||||
struct BytesToType<32> {
|
||||
using Type = uint8;
|
||||
static_assert(sizeof(Type) == 32);
|
||||
using Type = uint8;
|
||||
static_assert(sizeof(Type) == 32);
|
||||
};
|
||||
|
||||
template<>
|
||||
template <>
|
||||
struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<>
|
||||
template <>
|
||||
struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<>
|
||||
template <>
|
||||
struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<>
|
||||
template <>
|
||||
struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<>
|
||||
template <>
|
||||
struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
template<typename Elt_type, uint32_t NUM_ELT>
|
||||
template <typename Elt_type, uint32_t NUM_ELT>
|
||||
struct Vec {
|
||||
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
|
||||
|
||||
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
|
||||
using Vec_type = typename BytesToType<BYTES>::Type;
|
||||
|
||||
using Vec_type = typename BytesToType<BYTES>::Type;
|
||||
using Alias_type = union {
|
||||
Vec_type vec;
|
||||
Elt_type elt[NUM_ELT];
|
||||
};
|
||||
|
||||
using Alias_type = union {
|
||||
Vec_type vec;
|
||||
Elt_type elt[NUM_ELT];
|
||||
};
|
||||
Alias_type data;
|
||||
|
||||
Alias_type data;
|
||||
inline __device__ Vec() {}
|
||||
|
||||
inline __device__ Vec() {}
|
||||
|
||||
template<typename S>
|
||||
inline __device__ void to(Vec<S, NUM_ELT> &other) {
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
other.data.elt[it] = S(this->data.elt[it]);
|
||||
}
|
||||
template <typename S>
|
||||
inline __device__ void to(Vec<S, NUM_ELT> &other) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT; it++) {
|
||||
other.data.elt[it] = S(this->data.elt[it]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
inline __device__ void assign(const Op &op) {
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
this->data.elt[it] = op(it);
|
||||
}
|
||||
template <typename Op>
|
||||
inline __device__ void assign(const Op &op) {
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT; it++) {
|
||||
this->data.elt[it] = op(it);
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void load_from(const void *base_ptr) {
|
||||
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
|
||||
inline __device__ void load_from(const void *base_ptr) {
|
||||
this->data.vec = *reinterpret_cast<const Vec_type *>(base_ptr);
|
||||
}
|
||||
|
||||
inline __device__ void store_to(void *base_ptr) {
|
||||
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
|
||||
}
|
||||
|
||||
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
|
||||
static_assert(NUM_ELT % 2 == 0);
|
||||
using type = typename PackedHalf<Elt_type>::Type;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT / 2; it++) {
|
||||
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
|
||||
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline __device__ void store_to(void *base_ptr) {
|
||||
*reinterpret_cast<Vec_type *>(base_ptr) = this->data.vec;
|
||||
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale,
|
||||
const Vec<Elt_type, NUM_ELT> &bias) {
|
||||
static_assert(NUM_ELT % 2 == 0);
|
||||
using type = typename PackedHalf<Elt_type>::Type;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT / 2; it++) {
|
||||
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
|
||||
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
|
||||
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void add(const Vec<Elt_type, NUM_ELT> &other) {
|
||||
static_assert(NUM_ELT % 2 == 0);
|
||||
using type = typename PackedHalf<Elt_type>::Type;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT / 2; it++) {
|
||||
type b = *reinterpret_cast<const type *>(other.data.elt + it * 2);
|
||||
*reinterpret_cast<type *>(this->data.elt + it * 2) += b;
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void fma(const Vec<Elt_type, NUM_ELT> &scale, const Vec<Elt_type, NUM_ELT> &bias) {
|
||||
static_assert(NUM_ELT % 2 == 0);
|
||||
using type = typename PackedHalf<Elt_type>::Type;
|
||||
#pragma unroll
|
||||
for (int it = 0; it < NUM_ELT / 2; it++) {
|
||||
type a = *reinterpret_cast<const type *>(scale.data.elt + it * 2);
|
||||
type b = *reinterpret_cast<const type *>(bias.data.elt + it * 2);
|
||||
*reinterpret_cast<type *>(this->data.elt + it * 2) += a * b;
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void set_zero() {
|
||||
constexpr int size = sizeof(Vec_type) / sizeof(int);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size; ++i) {
|
||||
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
|
||||
}
|
||||
inline __device__ void set_zero() {
|
||||
constexpr int size = sizeof(Vec_type) / sizeof(int);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size; ++i) {
|
||||
(reinterpret_cast<int *>(this->data.elt))[i] = 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, int PackSize>
|
||||
inline __device__ void apply_rotary_embedding(Vec<T, PackSize>& vec, Vec<float, PackSize / 2>& cos, Vec<float, PackSize / 2>& sin) {
|
||||
static_assert(PackSize % 2 == 0);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < PackSize / 2; i++) {
|
||||
const float cos_inv_freq = cos.data.elt[i];
|
||||
const float sin_inv_freq = sin.data.elt[i];
|
||||
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
|
||||
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
|
||||
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
|
||||
vec.data.elt[2 * i + 1] = static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
|
||||
}
|
||||
template <typename T, int PackSize>
|
||||
inline __device__ void apply_rotary_embedding(Vec<T, PackSize> &vec,
|
||||
Vec<float, PackSize / 2> &cos,
|
||||
Vec<float, PackSize / 2> &sin) {
|
||||
static_assert(PackSize % 2 == 0);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < PackSize / 2; i++) {
|
||||
const float cos_inv_freq = cos.data.elt[i];
|
||||
const float sin_inv_freq = sin.data.elt[i];
|
||||
const float v1 = static_cast<float>(vec.data.elt[2 * i]);
|
||||
const float v2 = static_cast<float>(vec.data.elt[2 * i + 1]);
|
||||
vec.data.elt[2 * i] = static_cast<T>(cos_inv_freq * v1 - sin_inv_freq * v2);
|
||||
vec.data.elt[2 * i + 1] =
|
||||
static_cast<T>(sin_inv_freq * v1 + cos_inv_freq * v2);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Tensor>
|
||||
__forceinline__ __device__ void app_mask(
|
||||
Tensor &tSrS,
|
||||
const int *mask,
|
||||
const int &mask_row_id,
|
||||
const int &col_base) {
|
||||
const float mask_value = -1000000.0f;
|
||||
for (int i = 0; i < size(tSrS); i+=8) {
|
||||
const int col = i * 2 + col_base;
|
||||
if (col >= mask[mask_row_id]) {
|
||||
tSrS(i) = mask_value;
|
||||
}
|
||||
if (col + 1 >= mask[mask_row_id]) {
|
||||
tSrS(i + 1) = mask_value;
|
||||
}
|
||||
if (col >= mask[mask_row_id + 8]) {
|
||||
tSrS(i + 2) = mask_value;
|
||||
}
|
||||
if (col + 1 >= mask[mask_row_id + 8]) {
|
||||
tSrS(i + 3) = mask_value;
|
||||
}
|
||||
if (col + 8 >= mask[mask_row_id]) {
|
||||
tSrS(i + 4) = mask_value;
|
||||
}
|
||||
if (col + 9 >= mask[mask_row_id]) {
|
||||
tSrS(i + 5) = mask_value;
|
||||
}
|
||||
if (col + 8 >= mask[mask_row_id + 8]) {
|
||||
tSrS(i + 6) = mask_value;
|
||||
}
|
||||
if (col + 9 >= mask[mask_row_id + 8]) {
|
||||
tSrS(i + 7) = mask_value;
|
||||
}
|
||||
__forceinline__ __device__ void app_mask(Tensor &tSrS,
const int *mask,
const int &mask_row_id,
const int &col_base) {
const float mask_value = -1000000.0f;
for (int i = 0; i < size(tSrS); i += 8) {
const int col = i * 2 + col_base;
if (col >= mask[mask_row_id]) {
tSrS(i) = mask_value;
}
if (col + 1 >= mask[mask_row_id]) {
tSrS(i + 1) = mask_value;
}
if (col >= mask[mask_row_id + 8]) {
tSrS(i + 2) = mask_value;
}
if (col + 1 >= mask[mask_row_id + 8]) {
tSrS(i + 3) = mask_value;
}
if (col + 8 >= mask[mask_row_id]) {
tSrS(i + 4) = mask_value;
}
if (col + 9 >= mask[mask_row_id]) {
tSrS(i + 5) = mask_value;
}
if (col + 8 >= mask[mask_row_id + 8]) {
tSrS(i + 6) = mask_value;
}
if (col + 9 >= mask[mask_row_id + 8]) {
tSrS(i + 7) = mask_value;
}
}
}
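// The index pattern above assumes the accumulator fragment layout produced by
// the warpgroup MMA used in this kernel: each group of eight registers covers
// columns {col, col + 1, col + 8, col + 9} on rows mask_row_id and
// mask_row_id + 8, hence the two mask lookups per column pair.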
|
||||
|
||||
template<typename T>
template <typename T>
struct HalfMax;

template<>
template <>
struct HalfMax<cutlass::half_t> {
  inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
    __half2 res;
    asm volatile("max.f16x2 %0, %1, %2;\n" :
                 "=r"(*reinterpret_cast<uint32_t*>(&res)) :
                 "r"(*reinterpret_cast<const uint32_t*>(&x)),
                 "r"(*reinterpret_cast<const uint32_t*>(&y)));
    return res;
  }
  inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
    __half2 res;
    asm volatile("max.f16x2 %0, %1, %2;\n"
                 : "=r"(*reinterpret_cast<uint32_t *>(&res))
                 : "r"(*reinterpret_cast<const uint32_t *>(&x)),
                   "r"(*reinterpret_cast<const uint32_t *>(&y)));
    return res;
  }
};

template<>
template <>
struct HalfMax<cutlass::bfloat16_t> {
  inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
    nv_bfloat162 res;
    asm volatile("max.bf16x2 %0, %1, %2;\n" :
                 "=r"(*reinterpret_cast<uint32_t*>(&res)) :
                 "r"(*reinterpret_cast<const uint32_t*>(&x)),
                 "r"(*reinterpret_cast<const uint32_t*>(&y)));
    return res;
  }
  inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x,
                                            const nv_bfloat162 y) {
    nv_bfloat162 res;
    asm volatile("max.bf16x2 %0, %1, %2;\n"
                 : "=r"(*reinterpret_cast<uint32_t *>(&res))
                 : "r"(*reinterpret_cast<const uint32_t *>(&x)),
                   "r"(*reinterpret_cast<const uint32_t *>(&y)));
    return res;
  }
};

template<typename T>
template <typename T>
struct HalfMin;

template<>
template <>
struct HalfMin<cutlass::half_t> {
  inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
    __half2 res;
    asm volatile("min.f16x2 %0, %1, %2;\n" :
                 "=r"(*reinterpret_cast<uint32_t*>(&res)) :
                 "r"(*reinterpret_cast<const uint32_t*>(&x)),
                 "r"(*reinterpret_cast<const uint32_t*>(&y)));
    return res;
  }
  inline __device__ __half2 operator()(const __half2 x, const __half2 y) {
    __half2 res;
    asm volatile("min.f16x2 %0, %1, %2;\n"
                 : "=r"(*reinterpret_cast<uint32_t *>(&res))
                 : "r"(*reinterpret_cast<const uint32_t *>(&x)),
                   "r"(*reinterpret_cast<const uint32_t *>(&y)));
    return res;
  }
};

template<>
template <>
struct HalfMin<cutlass::bfloat16_t> {
  inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x, const nv_bfloat162 y) {
    nv_bfloat162 res;
    asm volatile("min.bf16x2 %0, %1, %2;\n" :
                 "=r"(*reinterpret_cast<uint32_t*>(&res)) :
                 "r"(*reinterpret_cast<const uint32_t*>(&x)),
                 "r"(*reinterpret_cast<const uint32_t*>(&y)));
    return res;
  }
  inline __device__ nv_bfloat162 operator()(const nv_bfloat162 x,
                                            const nv_bfloat162 y) {
    nv_bfloat162 res;
    asm volatile("min.bf16x2 %0, %1, %2;\n"
                 : "=r"(*reinterpret_cast<uint32_t *>(&res))
                 : "r"(*reinterpret_cast<const uint32_t *>(&x)),
                   "r"(*reinterpret_cast<const uint32_t *>(&y)));
    return res;
  }
};

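// HalfMax / HalfMin above: element-wise max / min on packed 16-bit pairs
// (__half2 for fp16, nv_bfloat162 for bf16), each issued as a single PTX
// instruction (max.f16x2 / min.f16x2 / max.bf16x2 / min.bf16x2) rather than
// unpacking to float.
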
template <bool Is_even_MN=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2>
template <bool Is_even_MN = true,
          typename TiledCopy,
          typename Engine0,
          typename Layout0,
          typename Engine1,
          typename Layout1,
          typename Engine2,
          typename Layout2>
__forceinline__ __device__ void copy(
    TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
    Tensor<Engine1, Layout1> &D,
    Tensor<Engine2, Layout2> const &identity_MN,
    const int max_MN = 0) {
  CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
  CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
#pragma unroll
  for (int m = 0; m < size<1>(S); ++m) {
    if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
      for (int k = 0; k < size<2>(S); ++k) {
        cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
      }
    }
    TiledCopy tiled_copy,
    Tensor<Engine0, Layout0> const &S,
    Tensor<Engine1, Layout1> &D,
    Tensor<Engine2, Layout2> const &identity_MN,
    const int max_MN = 0) {
  CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
  CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
  CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
  CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
  CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
#pragma unroll
  for (int m = 0; m < size<1>(S); ++m) {
    if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
      for (int k = 0; k < size<2>(S); ++k) {
        cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
      }
    }
  }
}

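// copy() above is a predicated tile copy: with Is_even_MN the whole tile is
// copied unconditionally; otherwise identity_MN supplies each fragment row's
// global index and rows at or beyond max_MN are skipped, which handles
// partial tiles at ragged sequence boundaries.
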
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
  using From_type = typename Engine::value_type;
  constexpr int numel = decltype(size(tensor))::value;
  cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
  auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
  return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
  using From_type = typename Engine::value_type;
  constexpr int numel = decltype(size(tensor))::value;
  cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
  auto frag =
      convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(
          tensor.data()));
  return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

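// convert_type() above reinterprets a register-resident fragment as a
// cutlass::Array and converts its element type in registers (for example,
// float accumulator -> half_t) without a round trip through memory.
// Illustrative call (the tensor name is an example only):
//   auto rP = convert_type<cutlass::half_t>(acc_s);
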
template<typename T, typename ReductionOp, int block_size>
template <typename T, typename ReductionOp, int block_size>
__inline__ __device__ T BlockAllReduce(T val) {
  typedef cub::BlockReduce<T, block_size> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T result_broadcast;
  T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
  if (threadIdx.x == 0) { result_broadcast = result; }
  __syncthreads();
  return result_broadcast;
  typedef cub::BlockReduce<T, block_size> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ T result_broadcast;
  T result = BlockReduce(temp_storage).Reduce(val, ReductionOp());
  if (threadIdx.x == 0) {
    result_broadcast = result;
  }
  __syncthreads();
  return result_broadcast;
}

template<typename T, int block_size>
template <typename T, int block_size>
__inline__ __device__ T BlockScanSum(T val) {
  typedef cub::BlockScan<int, block_size> BlockScanT;
  __shared__ typename BlockScanT::TempStorage temp_storage;
  typedef cub::BlockScan<int, block_size> BlockScanT;
  __shared__ typename BlockScanT::TempStorage temp_storage;

  int aggregate;
  BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate);
  __syncthreads();
  return val;
  int aggregate;
  BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate);
  __syncthreads();
  return val;
}

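// BlockAllReduce above reduces one value per thread across the thread block
// with cub::BlockReduce and broadcasts the result to all threads through
// shared memory; BlockScanSum returns each thread's exclusive prefix sum via
// cub::BlockScan. Illustrative calls (types and sizes are examples only):
//   float block_max = BlockAllReduce<float, MaxOp<float>, 128>(local_max);
//   int offset = BlockScanSum<int, 128>(local_count);
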
template<typename T>
template <typename T>
struct MaxOp {
  __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
  __device__ __forceinline__ T operator()(T const &x, T const &y) {
    return x > y ? x : y;
  }
};

template <>
struct MaxOp<float> {
  // This is slightly faster
  __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
  // This is slightly faster
  __device__ __forceinline__ float operator()(float const &x, float const &y) {
    return max(x, y);
  }
};

template<typename T>
template <typename T>
struct MinOp {
  __device__ __forceinline__ T operator()(T const & x, T const & y) { return x < y ? x : y; }
  __device__ __forceinline__ T operator()(T const &x, T const &y) {
    return x < y ? x : y;
  }
};

template <>
struct MinOp<float> {
  // This is slightly faster
  __device__ __forceinline__ float operator()(float const &x, float const &y) { return min(x, y); }
  // This is slightly faster
  __device__ __forceinline__ float operator()(float const &x, float const &y) {
    return min(x, y);
  }
};

template<typename T>
template <typename T>
struct SumOp {
  __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
  __device__ __forceinline__ T operator()(T const &x, T const &y) {
    return x + y;
  }
};

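// MaxOp / MinOp / SumOp above are the reduction functors meant to be passed
// as the ReductionOp template argument of BlockAllReduce / WarpAllReduce; the
// float specializations call max() / min() directly, which is slightly
// faster than the generic ternary form.
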
template<typename MMA_traits, typename Layout>
template <typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
  using X = Underscore;
  if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
    auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{});  // (2, 2, (2, N / 16)))
    return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout)));
  } else {  // SM80
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    if constexpr (mma_shape_K == 8) {
      return acc_layout;
    } else {
      auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
      return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
    }
  }
};

template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2,
          typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
  constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
  // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
  if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
  warpgroup_fence_operand(tCrC);
  if constexpr (arrive) {
    warpgroup_arrive();
  }
  if constexpr (zero_init) {
    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
  using X = Underscore;
  if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
    auto l = logical_divide(get<0>(acc_layout),
                            Shape<X, X, _2>{});  // (2, 2, (2, N / 16)))
    return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)),
                       get<1>(acc_layout),
                       make_layout(get<2, 1>(l), get<2>(acc_layout)));
  } else {  // SM80
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    if constexpr (mma_shape_K == 8) {
      return acc_layout;
    } else {
      // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
      }
    }
  }
  if constexpr (commit) {
    warpgroup_commit_batch();
  }
  if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
  warpgroup_fence_operand(tCrC);
  if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}

template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
  if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = acc_layout;
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
  } else {  // SM80
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
    auto l = logical_divide(
        acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
    return make_layout(
        make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
  }
}
};

template<typename T, typename ReductionOp, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
  ReductionOp op;
#pragma unroll
  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
    val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
template <bool zero_init = false,
          int wg_wait = 0,
          bool arrive = true,
          bool commit = true,
          typename Tensor0,
          typename Tensor1,
          typename Tensor2,
          typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma,
                                     Tensor0 const &tCrA,
                                     Tensor1 const &tCrB,
                                     Tensor2 &tCrC) {
  constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator,
                                           typename TiledMma::FrgTypeA>::value;
  // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take
  // const
  if constexpr (Is_RS) {
    warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
  }
  warpgroup_fence_operand(tCrC);
  if constexpr (arrive) {
    warpgroup_arrive();
  }
  if constexpr (zero_init) {
    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
  return val;
  } else {
    // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
    // Unroll the K mode manually to set scale D to 1
    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
      cute::gemm(tiled_mma, tCrA(_, _, k_block), tCrB(_, _, k_block), tCrC);
      tiled_mma.accumulate_ = GMMA::ScaleOut::One;
    }
  }
  if constexpr (commit) {
    warpgroup_commit_batch();
  }
  if constexpr (wg_wait >= 0) {
    warpgroup_wait<wg_wait>();
  }
  warpgroup_fence_operand(tCrC);
  if constexpr (Is_RS) {
    warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA));
  }
}

template <typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
  if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
    static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
    static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = acc_layout;
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
                       make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
  } else {  // SM80
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
                       make_layout(get<0, 0>(l), get<2>(l)));
  }
};

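// convert_layout_acc_rowcol above reshapes the MMA accumulator layout into a
// (row, col) view so per-row softmax statistics (running max and sum) can be
// computed with simple row-wise loops.
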
template <typename T, typename ReductionOp, int thread_group_width = 32>
__inline__ __device__ T WarpAllReduce(T val) {
  ReductionOp op;
#pragma unroll
  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
    val = op(val, __shfl_xor_sync(0xffffffff, val, mask));
  }
  return val;
}

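// WarpAllReduce above is a butterfly reduction: each __shfl_xor_sync step
// halves the stride, so after log2(thread_group_width) steps every lane in
// the group holds the fully reduced value.
// Illustrative call (values are examples only):
//   float row_max = WarpAllReduce<float, MaxOp<float>, 4>(partial_max);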