Sync v2.0 version of code to github repo

Jiang-Jia-Jun
2025-06-29 23:29:37 +00:00
parent d151496038
commit 92c2cfa2e7
597 changed files with 78776 additions and 22905 deletions

View File

@@ -0,0 +1,250 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Architecture-specific operators on memory added for SM80
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/complex.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/arch/cache_operation.h"
namespace cutlass {
namespace arch {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Initiates an asynchronous copy from global memory to shared memory.
///
/// cp.async
///
template <
/// Size of the access in bytes
int SizeInBytes,
/// Cache operation
CacheOperation::Kind cache_op = CacheOperation::Always,
bool GlobalToShared = true>
struct copy;
/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate
/// the entire transfer, zeros are written to SMEM if the guard predicate is false.
///
/// cp.async
///
template <
/// Size of the access in bytes
int SizeInBytes,
/// Cache operation
CacheOperation::Kind cache_op = CacheOperation::Always,
bool GlobalToShared = true>
struct copy_zfill;
/// Blocks until all but <N> previous cp.async.commit_group operations have committed.
///
/// cp.async
///
template <int N, bool GlobalToShared = true>
struct copy_wait;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Always, true> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
cp_async<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Always, false> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Always, true> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
cp_async_zfill<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Always, false> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Global, true> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
cp_async<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy<SizeInBytes, CacheOperation::Global, false> {
/// Copy
CUTLASS_DEVICE
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Global, true> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
cp_async_zfill<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
}
};
/// Partial specialization
template <
/// Size of the access in bytes
int SizeInBytes>
struct copy_zfill<SizeInBytes, CacheOperation::Global, false> {
/// Copy with zero fill
CUTLASS_DEVICE
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
using AccessType = Array<uint8_t, SizeInBytes>;
if (pred_guard) {
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
}
else {
AccessType zeros;
zeros.clear();
*static_cast<AccessType *>(smem_ptr) = zeros;
}
}
};
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
template <bool GlobalToShared>
CUTLASS_DEVICE
void copy_fence() {}
template <>
CUTLASS_DEVICE
void copy_fence<true>() {
cp_async_fence();
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization
template <int N>
struct copy_wait<N, false> {
CUTLASS_DEVICE
copy_wait() {}
};
/// Partial specialization
template <int N>
struct copy_wait<N, true> {
CUTLASS_DEVICE
copy_wait() { cp_async_wait<N>(); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace arch
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
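// A minimal usage sketch of the wrappers above. It assumes an SM80+ target and a
// 16-byte-aligned global source with at least four floats; the kernel name and
// buffer sizes are placeholders, not part of CUTLASS.
static __global__ void example_cp_async_prefetch(float const *gmem, bool guard) {
  __shared__ alignas(16) float smem[4];
  // 16-byte global->shared cp.async; zero-fills smem when the guard is false.
  cutlass::arch::copy_zfill<16, cutlass::arch::CacheOperation::Always, true>(
      smem, gmem, guard);
  cutlass::arch::copy_fence<true>();            // cp.async.commit_group
  cutlass::arch::copy_wait<0, true> wait_all;   // cp.async.wait_group 0
  (void)wait_all;
  __syncthreads();
}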

View File

@@ -0,0 +1,460 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace cutlass::epilogue::fusion {
using namespace cute;
using namespace detail;
// Row vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
struct SharedStorage {
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct Arguments {
const Element* const* ptr_row_array = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcastArray() { }
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
: params(params)
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
Params params;
Element *smem = nullptr;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
}
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
int group, Params const& params_)
: tGS_gRow(tGS_gRow_)
, tGS_sRow(tGS_sRow_)
, tGS_cRow(tGS_cRow_)
, tiled_G2S(tiled_g2s_)
, tSR_sRow(tSR_sRow_)
, tSR_rRow(tSR_rRow_)
, tCcRow(tCcRow_)
, residue_tCcRow(residue_tCcRow_)
, group(group)
, params(params_) {}
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue residue_tCcRow; // (m, n)
ThrNum thr_num;
int group;
Params const& params;
CUTLASS_DEVICE void
begin() {
if (!params.row_broadcast) {
fill(tSR_rRow, *(params.ptr_row_array[group]));
return;
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM
}
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
tGS_sRow_flt(i) = tGS_gRow_flt(i);
}
else {
tGS_sRow_flt(i) = Element(0); // Zero-fill when OOB so the LDS can be issued without predicates.
}
}
synchronize();
}
CUTLASS_DEVICE void
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_row;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
}
return frg_row;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
tGS_sRow,
tGS_cRow, tiled_g2s,
tSR_sRow,
tSR_rRow,
args.tCcD,
args.residue_cD,
ThreadCount{},
l,
params);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcastArray {
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct Arguments {
const Element* const* ptr_col_array = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcastArray() { }
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Params params;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GTensor&& tCgCol,
RTensor&& tCrCol,
CTensor&& tCcCol,
ProblemShape problem_shape,
int group,
Params const& params
):
tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
tCcCol(cute::forward<CTensor>(tCcCol)),
m(get<0>(problem_shape)),
group(group),
params(params) {}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol;
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;
int m;
int group;
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col_array[group]));
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_if(pred, filter(tCgCol), filter(tCrCol));
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}
return frg_col;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(
cute::move(tCgCol),
cute::move(tCrCol),
cute::move(tCcCol),
args.problem_shape_mnkl,
l,
params
);
}
};
}  // namespace cutlass::epilogue::fusion
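// A minimal sketch, not from this file, of how the grouped Arguments above are
// meant to be filled: each group supplies its own device pointer, and
// row_broadcast = false marks every pointer as a single per-tensor scalar
// rather than a length-N row vector. Names below are placeholders.
template <class CtaTileShapeMNK>
auto example_row_or_scalar_array_args(float const *const *d_scale_ptrs,
                                      bool per_tensor) {
  using Broadcast = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
      0, CtaTileShapeMNK, float>;
  typename Broadcast::Arguments args{d_scale_ptrs, !per_tensor, {}};
  return args;
}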

View File

@@ -0,0 +1,500 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cute/tensor.hpp"
namespace cutlass::epilogue::threadblock {
using namespace cute;
using namespace detail;
template<
class ThreadMap,
class Element,
class StrideMNL
>
struct VisitorRowOrScalarBroadcast {
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_row = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage {};
// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gRow,
RTensor&& tC_rRow,
CTensor&& tC_cRow,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gRow(cute::forward<GTensor>(tC_gRow)),
tC_rRow(cute::forward<RTensor>(tC_rRow)),
tC_cRow(cute::forward<CTensor>(tC_cRow)),
n(get<1>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gRow;
RTensor tC_rRow;
CTensor tC_cRow;
Params const* params_ptr;
int n;
// This function is modified from VisitorRowBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rRow);
auto src_v = filter(tC_gRow);
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);
if (params_ptr->row_broadcast) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are loading from a scalar and broadcasting
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
return rRow_frg(column_idx);
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mRow = make_tensor(
make_gmem_ptr(params_ptr->ptr_row),
problem_shape,
params_ptr->dRow);
// VECTOR, FRAGMENT_COLUMN
Tensor tC_gRow = recast<VecType>(
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
)(_,_,_0{},_0{},_0{},_0{});
Tensor tC_rRow = make_tensor_like(tC_gRow);
// Generate the pred tensor
Tensor cRow = make_identity_tensor(mRow.shape());
Tensor tC_cRow = outer_partition(
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
Shape<Int<VecLength>>{},
(_0{})
);
return Callbacks<
decltype(tC_gRow), decltype(tC_rRow),
decltype(tC_cRow), ProblemShape>(
cute::move(tC_gRow),
cute::move(tC_rRow),
cute::move(tC_cRow),
problem_shape,
params_ptr
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
template<
class ThreadMap,
class Element,
class StrideMNL
>
struct VisitorRowOrZeroBroadcast {
// This struct has been modified to remove null_default (because it's always 0)
struct Arguments {
Element const* ptr_row = nullptr;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage {};
// Global load type
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
CUTLASS_HOST_DEVICE
VisitorRowOrZeroBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gRow,
RTensor&& tC_rRow,
CTensor&& tC_cRow,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gRow(cute::forward<GTensor>(tC_gRow)),
tC_rRow(cute::forward<RTensor>(tC_rRow)),
tC_cRow(cute::forward<CTensor>(tC_cRow)),
n(get<1>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gRow;
RTensor tC_rRow;
CTensor tC_cRow;
Params const* params_ptr;
int n;
// This function is modified from VisitorRowBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rRow);
auto src_v = filter(tC_gRow);
auto coord_v = filter(tC_cRow);
auto dst_v = filter(tC_rRow);
if (params_ptr->ptr_row != nullptr) {
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
bool guard = get<1>(coord_v(i)) < n;
cutlass::arch::global_load<VecType, sizeof(VecType)>(
dst_v(i), (void const*)&src_v(i), guard);
}
} else {
// In this case we are broadcasting 0
VecType filled_vec;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < VecLength; i++) {
reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(src_v); ++i) {
if (get<1>(coord_v(i)) < n) {
dst_v(i) = filled_vec;
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
return rRow_frg(column_idx);
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mRow = make_tensor(
make_gmem_ptr(params_ptr->ptr_row),
problem_shape,
params_ptr->dRow);
// VECTOR, FRAGMENT_COLUMN
Tensor tC_gRow = recast<VecType>(
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
)(_,_,_0{},_0{},_0{},_0{});
Tensor tC_rRow = make_tensor_like(tC_gRow);
// Generate the pred tensor
Tensor cRow = make_identity_tensor(mRow.shape());
Tensor tC_cRow = outer_partition(
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
Shape<Int<VecLength>>{},
(_0{})
);
return Callbacks<
decltype(tC_gRow), decltype(tC_rRow),
decltype(tC_cRow), ProblemShape>(
cute::move(tC_gRow),
cute::move(tC_rRow),
cute::move(tC_cRow),
problem_shape,
params_ptr
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
class ThreadMap,
class Element,
class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast.
struct Arguments {
Element const* ptr_col = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
struct SharedStorage { };
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params_ptr(&params) { }
Params const* params_ptr;
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
struct Callbacks : EmptyCallbacks {
CUTLASS_DEVICE
Callbacks(
GTensor&& tC_gCol,
RTensor&& tC_rCol,
CTensor&& tC_cCol,
ProblemShape problem_shape,
Params const* params_ptr
):
tC_gCol(cute::forward<GTensor>(tC_gCol)),
tC_rCol(cute::forward<RTensor>(tC_rCol)),
tC_cCol(cute::forward<CTensor>(tC_cCol)),
m(get<0>(problem_shape)),
params_ptr(params_ptr) { }
GTensor tC_gCol;
RTensor tC_rCol;
CTensor tC_cCol;
Params const* params_ptr;
int m;
// This function is modified from VisitorColBroadcast
CUTLASS_DEVICE void
begin_epilogue() {
clear(tC_rCol);
Tensor pred = make_tensor<bool>(shape(tC_gCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tC_cCol(i)) < m;
}
if (params_ptr->col_broadcast) {
// In this case we are loading from a column vector and broadcasting
copy_if(pred, tC_gCol, tC_rCol);
} else {
// In this case we are loading from a scalar and broadcasting
auto dst_v = filter(tC_rCol);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(dst_v); ++i) {
if (pred(i)) {
dst_v(i) = *(params_ptr->ptr_col);
}
}
}
}
template <class ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE auto // returns an Array
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
Array<Element, FragmentSize> frg_col;
frg_col.fill(tC_rCol(row_idx,iter_idx));
return frg_col;
}
};
template <class ProblemShape>
CUTLASS_DEVICE auto
get_callbacks(
gemm::GemmCoord threadblock_tile_offset,
int thread_idx,
ProblemShape problem_shape
) {
Tensor mCol = make_tensor(
make_gmem_ptr(params_ptr->ptr_col),
problem_shape,
params_ptr->dCol);
// VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
Tensor tC_gCol = group_modes<1,4>(
ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
Tensor tC_rCol = make_tensor_like(tC_gCol);
// Generate the pred tensor
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tC_cCol = group_modes<1,4>(
ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
return Callbacks<
decltype(tC_gCol), decltype(tC_rCol),
decltype(tC_cCol), ProblemShape>(
cute::move(tC_gCol),
cute::move(tC_rCol),
cute::move(tC_cCol),
problem_shape,
params_ptr
);
}
};
}  // namespace cutlass::epilogue::threadblock
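// A minimal sketch, not from this file, of how the row_broadcast flag is meant
// to be derived when building Arguments: a one-element tensor is treated as a
// scalar to broadcast, anything larger as a genuine row vector. ThreadMapT and
// the function name are placeholders supplied by the surrounding GEMM.
template <class ThreadMapT>
auto example_row_or_scalar_c2x_args(float const *ptr, long long numel) {
  using Visitor = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
      ThreadMapT, float, cute::Stride<cute::_0, cute::_1, cute::_0>>;
  typename Visitor::Arguments args{ptr, numel != 1, {}};
  return args;
}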

View File

@@ -0,0 +1,450 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
#pragma once
// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
namespace cutlass::epilogue::fusion {
using namespace cute;
using namespace detail;
// Row vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
struct SharedStorage {
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
};
// This struct has been modified to have a bool indicating that ptr_row is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_row is null.
struct Arguments {
Element const* ptr_row = nullptr;
bool row_broadcast = true;
StrideMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params)
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
Params params;
Element *smem = nullptr;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
}
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
: tGS_gRow(tGS_gRow_)
, tGS_sRow(tGS_sRow_)
, tGS_cRow(tGS_cRow_)
, tiled_G2S(tiled_g2s_)
, tSR_sRow(tSR_sRow_)
, tSR_rRow(tSR_rRow_)
, tCcRow(tCcRow_)
, residue_tCcRow(residue_tCcRow_)
, params(params_) {}
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue residue_tCcRow; // (m, n)
ThrNum thr_num;
Params const& params;
CUTLASS_DEVICE void
begin() {
if (!params.row_broadcast) {
fill(tSR_rRow, *(params.ptr_row));
return;
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM
}
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
tGS_sRow_flt(i) = tGS_gRow_flt(i);
}
else {
tGS_sRow_flt(i) = Element(0); // Zero-fill when OOB so the LDS can be issued without predicates.
}
}
synchronize();
}
CUTLASS_DEVICE void
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_row;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
}
return frg_row;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
tGS_sRow,
tGS_cRow, tiled_g2s,
tSR_sRow,
tSR_rRow,
args.tCcD,
args.residue_cD,
ThreadCount{},
params);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template<
int Stages,
class CtaTileShapeMNK,
class Element,
class StrideMNL = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };
// This struct has been modified to have a bool indicating that ptr_col is a
// scalar that must be broadcast, instead of containing a scalar that is
// valid if ptr_col is null.
struct Arguments {
Element const* ptr_col = nullptr;
bool col_broadcast = true;
StrideMNL dCol = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return (!params.col_broadcast && *(params.ptr_col) == Element(0));
}
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast() { }
CUTLASS_HOST_DEVICE
Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Params params;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GTensor&& tCgCol,
RTensor&& tCrCol,
CTensor&& tCcCol,
ProblemShape problem_shape,
Params const& params
):
tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
tCcCol(cute::forward<CTensor>(tCcCol)),
m(get<0>(problem_shape)),
params(params) {}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol;
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;
int m;
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col));
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_if(pred, filter(tCgCol), filter(tCrCol));
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}
return frg_col;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(
cute::move(tCgCol),
cute::move(tCrCol),
cute::move(tCcCol),
args.problem_shape_mnkl,
params
);
}
};
}  // namespace cutlass::epilogue::fusion
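// A minimal sketch, not from this file: per-channel scales pass a device
// pointer to a length-N row with row_broadcast = true, while a per-tensor
// scale passes a pointer to a single device scalar with row_broadcast = false.
// The tile shape and names below are placeholders.
template <class CtaTileShapeMNK>
auto example_row_or_scalar_args(float const *d_scales, bool per_tensor) {
  using Broadcast = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      0, CtaTileShapeMNK, float>;
  typename Broadcast::Arguments args{d_scales, !per_tensor, {}};
  return args;
}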

View File

@@ -0,0 +1,327 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
#pragma once
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace fastdeploy::c2x {
using namespace cute;
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
template <typename T>
using ColOrScalarLoad =
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad =
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
template <typename T>
using RowOrZeroLoad =
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(paddle::Tensor const &tensor) {
using Arguments = typename Descriptor::Arguments;
auto *data_ptr = static_cast<T *>(const_cast<void *>(
tensor.data()));
if constexpr (std::is_same_v<Descriptor,
ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor,
RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
}
else {
// This would technically work, but there is no use case since data_ptr is never nullptr here.
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(paddle::optional<paddle::Tensor> const &tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto *data_ptr =
tensor ? static_cast<T *>(const_cast<void *>(tensor->data())) : nullptr;
return Arguments{data_ptr};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
paddle._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) @ (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be per-tensor, or
per-row for A and per-column for B.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, {}};
}
};
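// A minimal sketch of the contract described at the top of this file: the
// GEMM-side code aliases an epilogue, builds its EVTCompute arguments via
// prepare_args, and forwards them as the EVT node arguments. ElementD and
// OutputThreadMapT here are placeholders supplied by the surrounding kernel.
template <typename ElementD, typename OutputThreadMapT>
auto example_prepare_scaled_args(paddle::Tensor const &a_scales,
                                 paddle::Tensor const &b_scales) {
  using Epilogue = ScaledEpilogue<ElementD, OutputThreadMapT>;
  typename Epilogue::ArgumentType evt_args =
      Epilogue::prepare_args(a_scales, b_scales);
  return evt_args;
}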
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
protected:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::Tensor const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args, {}};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
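// Elementwise reference (a sketch; azp_adj already folds in the scalar azp,
// and the RowOrZeroLoad bias contributes 0 when no bias tensor is given):
//   acc     = float(Accum[m][n] - azp_adj[n]);
//   D[m][n] = ElementD(a_scales[m] * (b_scales[n] * acc) + bias[n]);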
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
paddle::Tensor const &azp_adj,
paddle::optional<paddle::Tensor> const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_azp_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
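// Elementwise reference (a sketch; the rank-1 term azp[m] * azp_adj[n] is
// applied per element and never materialized as an (m,n) matrix):
//   acc     = float(Accum[m][n] - azp[m] * azp_adj[n]);
//   D[m][n] = ElementD(a_scales[m] * (b_scales[n] * acc) + bias[n]);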
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
paddle::optional<paddle::Tensor> const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_acc_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
}; // namespace fastdeploy::c2x

View File

@@ -0,0 +1,453 @@
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
#pragma once
// clang-format will break include orders
// clang-format off
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
// clang-format on
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
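/*
Every epilogue below therefore follows the same minimal shape (a sketch of the
contract stated above, not an additional epilogue):
  struct SomeEpilogue {
    using EVTCompute   = cutlass::epilogue::fusion::Sm90EVT< ... >;
    using ArgumentType = typename EVTCompute::Arguments;
    static ArgumentType prepare_args( ... );  // packs the visitor arguments
  };
*/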
namespace fastdeploy::c3x {
using namespace cute;
template <typename T> struct identity {
CUTLASS_HOST_DEVICE
T operator()(T lhs) const { return lhs; }
};
template <typename ElementAcc, typename ElementD, typename TileShape>
struct TrivialEpilogue {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using Compute = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
using ArgumentType = typename EVTCompute::Arguments;
template <typename... Args> static ArgumentType prepare_args(Args... args) {
return {};
}
};
/*
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
*/
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
template <typename T>
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
128 / sizeof_bits_v<T>, EnableNullPtr>;
template <typename T>
using ColOrScalarLoadArray =
cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
template <typename T>
using RowOrScalarLoadArray =
cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(paddle::Tensor const &tensor) {
using Arguments = typename Descriptor::Arguments;
auto *data_ptr = static_cast<T *>(const_cast<void *>(tensor.data()));
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
return Arguments{data_ptr, tensor.numel() != 1};
} else {
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
!std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
}
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(paddle::optional<paddle::Tensor> const &tensor) {
using Arguments = typename Descriptor::Arguments;
auto *data_ptr =
tensor ? static_cast<T *>(const_cast<void *>(tensor->data())) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
std::is_same_v<Descriptor, RowLoad<T, true>>);
return Arguments{data_ptr};
}
template <typename Descriptor, typename T>
static auto args_from_tensor(const T *const *data_ptr, bool do_broadcast) {
using Arguments = typename Descriptor::Arguments;
static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
return Arguments{data_ptr, do_broadcast};
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
paddle.scaled_mm_.
A and B may both be either int8 or fp8_e4m3. A can be
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be per-tensor, per-row,
or per-column.
*/
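// Argument nesting mirrors the EVT tree built below (a sketch):
//   EVTCompute = Compute1(ScaleA, Compute0(ScaleB, Accum))
//   ArgumentType{a_args, /*EVTCompute0*/ {b_args, /*Accum*/ {}, {}}, {}}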
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, {}};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
* This bias can also be used in the per-tensor azp case, where the activation
* zero point (azp) is used to compute an azp correction term,
* which is folded into the bias.
*
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
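// Same tree as ScaledEpilogue, with Bias as the third child of the
// multiply_add node (a sketch): Compute1(ScaleA, EVTCompute0, Bias)
// evaluates to ScaleA * (ScaleB * Accum) + Bias per output element.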
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::Tensor const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args, {}};
}
};
/*
* This epilogue performs the same operation as ScaledEpilogueBias, but the
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
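// Note: identical to ScaledEpilogueBias except that Bias uses ColLoad, so it
// is indexed by the row (m) of D rather than the column (n), matching the
// transposed C^T += B^T A^T view described above.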
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueColumnBias
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template ColLoad<ElementD>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
paddle::Tensor const &b_scales,
paddle::Tensor const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, bias_args, {}};
}
};
/*
* This epilogue directly supports per-tensor azp in int32 form.
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
* term, which should already be multiplied with the scalar azp.
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
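// EVT tree for the per-tensor azp case (a sketch):
//   ComputeScaleBiasA(ScaleA,
//                     ComputeScaleB(ScaleB, ComputeAzp(Accum, AzpWithAdj)),
//                     Bias)
// i.e. D = ScaleA * (ScaleB * float(Accum - azp_adj)) + Bias, with Bias = 0
// when no bias tensor is given.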
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzp
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// This is the full AZP term, azp * J @ B, shape (1,n)
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
// Compute float(accum - azp_adj), both operands are int32_t
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
paddle::Tensor const &azp_adj,
paddle::optional<paddle::Tensor> const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_azp_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
/*
* This epilogue supports per-token azp by computing and applying
* the correction term using a rank-1 update. If the term were materialized,
* it would require O(m*n) space, and this way it only requires O(m+n) space.
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
* point for each row of A.
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
*
* This epilogue also supports bias, which remains per-channel.
*/
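// EVT tree for the per-token azp case (a sketch):
//   ComputeScaleBiasA(ScaleA,
//                     ComputeScaleB(ScaleB,
//                                   ComputeAcc(Accum, ComputeAzp(Azp, AzpAdj))),
//                     Bias)
// i.e. D = ScaleA * (ScaleB * float(Accum - azp[m] * azp_adj[n])) + Bias.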
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzpToken
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
using Bias = typename SUPER::template RowLoad<ElementD, true>;
// Per-token azp term, shape (m,1)
using Azp = typename SUPER::template ColLoad<int32_t>;
// This is the AZP adjustment term, J @ B, shape (1,n)
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
// Compute azp * azp_adj
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, int32_t, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAzp =
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
// Compute float(accum - azp*azp_adj), all operands are int32_t
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
cutlass::minus, float, int32_t,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeAcc =
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTComputeScaleB =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
paddle::optional<paddle::Tensor> const &bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
auto azp_adj_args =
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
typename EVTComputeScaleB::Arguments evt_scale_b_args{
b_args, evt_acc_args, {}};
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
}
};
/*
This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are arrays of
pointers to the per-group scales used in grouped GEMM. The number of pointers
in ScaleA and in ScaleB equals the group count.
*/
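// Usage sketch (a_scales_ptrs / b_scales_ptrs are hypothetical device arrays
// holding one float* per group):
//   using Epilogue = ScaledEpilogueArray<ElementAcc, ElementD, EpilogueDescriptor>;
//   auto args = Epilogue::prepare_args(a_scales_ptrs, b_scales_ptrs,
//                                      /*a_col_broadcast=*/true,
//                                      /*b_row_broadcast=*/true);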
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueArray
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;
static ArgumentType prepare_args(float const *const *a_scales_ptr,
float const *const *b_scales_ptr,
bool a_col_broadcast, bool b_row_broadcast) {
auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
a_scales_ptr, a_col_broadcast);
auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
b_scales_ptr, b_row_broadcast);
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
return ArgumentType{a_args, evt0_args, {}};
}
};
}; // namespace fastdeploy::c3x

View File

@@ -0,0 +1,284 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Returns the maximum number of smem tiles that can be used with a given smem
// capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK,
bool SwapAB, int carveout_bytes>
constexpr int compute_stage_count_or_override_gated(
StageCountAutoCarveout<carveout_bytes> stage_count) {
// 32 bytes to account for barriers etc.
constexpr int stage_barrier_bytes = 32;
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
constexpr int stage_bytes = [&]() -> int {
if constexpr (SwapAB) {
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) /
8 +
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
stage_barrier_bytes;
} else {
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) /
8 +
stage_barrier_bytes;
}
}();
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
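// Worked example (a sketch with hypothetical numbers): for 16-bit A/B, a
// (128,128,64) TileShapeMNK, SwapAB = false, and no carveout,
//   stage_bytes = 16*128*64/8 + 16*128*64*2/8 + 32 = 16384 + 32768 + 32 = 49184
// so an illustrative 224 KiB (229376 B) capacity yields 229376 / 49184 = 4 stages.
// The factor of 2 on the non-swapped operand accounts for the two gated
// operands (B0 and B1) that must both be staged; with SwapAB it moves to A.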
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator,
class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType,
template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<
arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA,
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB,
cute::enable_if_t<
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative>) &&
not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB,
GmemLayoutB>()>> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>,
"Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm =
(cute::is_same_v<KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only "
"compatible with FP8 FastAccum version right now\n");
// For fp32 types, map to tf32 MMA value type
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>,
tfloat32_t, ElementA>;
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>,
tfloat32_t, ElementB>;
static constexpr cute::GMMA::Major GmmaMajorA =
detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB =
detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative> ||
IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<MmaElementA, MmaElementB, ElementAccumulator,
TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(
shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(
shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA =
decltype(detail::ss_smem_selector<
GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB =
decltype(detail::ss_smem_selector<
GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages =
detail::compute_stage_count_or_override_gated<
detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<
IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
KernelScheduleType>,
/* For FP8 use a separate mainloop compared to other datatypes */
cute::conditional_t<
IsFP8Input,
MainloopSm90TmaGmmaWarpSpecializedFP8<
PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
KernelScheduleType>>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<
DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
class GmemLayoutB, int AlignmentB, class ElementAccumulator,
class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType,
template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<
arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA,
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB,
cute::enable_if_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedFP8FastAccum> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
cute::is_same_v<
KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>> {
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
detail::tma_alignment_bytes>(),
"Not meet TMA alignment requirement yet\n");
static_assert(
detail::is_input_fp8<ElementA, ElementB>(),
"Only FP8 datatypes are compatible with these kernel schedules\n");
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
static_assert(
!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
"Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>,
"Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA =
detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB =
detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
static constexpr bool IsArrayOfPointersGemm =
(cute::is_same_v<
KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator,
TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(
shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(
shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA =
decltype(detail::ss_smem_selector<
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB =
decltype(detail::ss_smem_selector<
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages =
detail::compute_stage_count_or_override_gated<
detail::sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK,
SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<
IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
KernelScheduleType>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<
DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,60 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType,
template <class /* ElementCompute */> class Activation,
bool SwapAB = false, class Enable = void>
struct CollectiveBuilderGated {
static_assert(sizeof(ElementA) == 0,
"Could not build a collective for given parameters.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,62 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <class DispatchPolicy, class TileShape, class ElementA, class StrideA,
class ElementB, class StrideB, class TiledMma, class GmemTiledCopyA,
class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB,
class TransformB,
template <class /* ElementCompute */> class Activation,
bool SwapAB = false>
struct CollectiveMmaGated {
static_assert(cutlass::detail::dependent_false<ElementA>,
"Could not find a mainloop specialization.");
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,713 @@
/***************************************************************************************************
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule,
class TileShape_, class ElementA_, class StrideA_, class ElementB_,
class StrideB_, class TiledMma_, class GmemTiledCopyA_,
class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_,
template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<
MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>,
TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_,
SwapAB_> {
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy =
MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA,
typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2,
"Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for "
"this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any
// rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA =
cute::conditional_t<ConvertF32toTF32A, tfloat32_t,
uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB =
cute::conditional_t<ConvertF32toTF32B, tfloat32_t,
uint_bit_t<sizeof_bits_v<ElementB>>>;
using InternalElementAux =
cute::conditional_t<SwapAB, InternalElementA, InternalElementB>;
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
cute::array_aligned<typename TiledMma::ValTypeA,
cute::cosize_v<SmemLayoutA>>
smem_A;
cute::array_aligned<typename TiledMma::ValTypeB,
cute::cosize_v<SmemLayoutB>>
smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments {
ElementA const *ptr_A;
StrideA dA;
ElementB const *ptr_B0;
ElementB const *ptr_B1;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
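// Host-side usage sketch (pointer, stride, and scale values are placeholders
// owned by the caller):
//   Arguments args{ptr_A, stride_A, ptr_B0, ptr_B1, stride_B,
//                  /*scale_d0=*/1.0f, /*scale_d1=*/1.0f,
//                  /*mma_promotion_interval=*/4};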
// Device side kernel params
struct Params {
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
GmemTiledCopyA{},
make_tensor(static_cast<InternalElementA const *>(nullptr),
repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(static_cast<InternalElementB const *>(nullptr),
repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const &problem_shape,
Arguments const &args, void *workspace) {
(void)workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
// only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<InternalElementA const *>(args.ptr_A);
auto ptr_B0 = reinterpret_cast<InternalElementB const *>(args.ptr_B0);
Tensor tensor_a =
make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b =
make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
if constexpr (SwapAB) {
auto ptr_Aux = reinterpret_cast<InternalElementA const *>(args.ptr_B1);
Tensor tensor_aux =
make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(
ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0,
args.scale_d1};
} else {
auto ptr_Aux = reinterpret_cast<InternalElementB const *>(args.ptr_B1);
Tensor tensor_aux =
make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(
ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0,
args.scale_d1};
}
}
template <class ProblemShape>
static bool can_implement(ProblemShape const &problem_shape,
[[maybe_unused]] Arguments const &args) {
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A =
tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable =
implementable &&
cutlass::detail::check_alignment<min_tma_aligned_elements_A>(
cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B =
tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable =
implementable &&
cutlass::detail::check_alignment<min_tma_aligned_elements_B>(
cute::make_shape(N, K, L), StrideB{});
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the "
"minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes =
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) *
static_cast<uint32_t>(sizeof_bits<ElementA>::value)) /
8 +
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) *
static_cast<uint32_t>(sizeof_bits<ElementB>::value)) /
8 +
(size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) *
static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) /
8;
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
/// performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const &mainloop_params) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the
/// contract that the returned tuple must contain at least two elements, the
/// leading ones being:
///   gA_mkl   - The tma tensor, A after a local tile; shape (BLK_M,BLK_K,m,k,l)
///   gB_nkl   - The tma tensor, B after a local tile; shape (BLK_N,BLK_K,n,k,l)
///   gAux_xkl - The tma tensor, A/B after a local tile; shape (BLK_M/N,BLK_K,m/n,k,l)
/// The rest of the tensors can be specified as needed by this collective.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL,
Params const &mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain
// mapping. Represent the full tensors -- get these from TMA.
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(
make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(
make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB) {
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
} else {
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator,
class BlockCoord>
CUTLASS_DEVICE void
load(Params const &mainloop_params, MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const &load_inputs,
BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count,
int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage &shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x =
get<0>(typename DispatchPolicy::ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x,
block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a =
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b =
mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux =
SwapAB
? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(
cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord)
: gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout =
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
// block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,
n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout =
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
// block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(
m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB) {
mcast_mask_aux = mcast_mask_a;
} else {
mcast_mask_aux = mcast_mask_b;
}
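// Illustrative note: for a hypothetical 2x2 cluster, the block at cluster
// coordinate (x=0, y=1) sets mcast_mask_a to cover the ranks of blocks (0,0)
// and (0,1), which share the same A tile along the N mode; mcast_mask_b is
// built the same way along the M mode.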
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType *tma_barrier =
pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a),
tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b),
tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux),
tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline, PipelineState smem_pipe_read,
FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx,
TensorStorage &shared_tensors, Params const &mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value,
"C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutAux{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
auto tCsAux = [&]() -> auto {
if constexpr (SwapAB) {
return thread_mma.partition_A(sAux);
} else {
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto {
if constexpr (SwapAB) {
return thread_mma.make_fragment_A(tCsAux);
} else {
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB) {
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
} else {
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} ==
size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
"ERROR : Incorrect number of MMAs in flight");
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0;
--k_tile_prologue) {
// WAIT on smem_pipe_read until its data are available (phase bit flips
// from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB) {
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accum1);
} else {
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
++smem_pipe_read;
}
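// Note (illustrative): the prologue above issues GMMAs for up to K_PIPE_MMAS
// k-tiles without waiting on their completion; the main loop below only
// releases a stage back to the producer after warpgroup_wait<K_PIPE_MMAS>(),
// so tensor-core work overlaps the producer's next TMA loads.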
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// WAIT on smem_pipe_read until its data are available (phase bit flips
// from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accum0);
if constexpr (SwapAB) {
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accum1);
} else {
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrAux(_, _, k_block, read_stage), accum1);
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to
/// ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
// UNLOCK smem_pipe_release, done _computing_ on it
pipeline.consumer_release(smem_pipe_release);
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
warpgroup_fence_operand(accum0);
warpgroup_fence_operand(accum1);
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_release,
int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
// done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,724 @@
/***************************************************************************************************
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule,
class TileShape_, class ElementA_, class StrideA_, class ElementB_,
class StrideB_, class TiledMma_, class GmemTiledCopyA_,
class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
class TransformB_,
template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<
MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>,
TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_,
SwapAB_> {
static constexpr bool isGated = true;
static constexpr bool SwapAB = SwapAB_;
//
// Type Aliases
//
using DispatchPolicy =
MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape,
KernelSchedule>;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;
using Activation = Activation_<ElementAccumulator>;
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA,
typename TiledMma::ValTypeB>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
static_assert(cute::rank(SmemLayoutAtomA{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
static_assert(DispatchPolicy::Stages >= 2,
"Specialization requires Stages set to value 2 or more.");
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeB>::value,
"MMA atom must source both A and B operand from smem_desc for "
"this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
cute::array_aligned<typename TiledMma::ValTypeA,
cute::cosize_v<SmemLayoutA>>
smem_A;
cute::array_aligned<typename TiledMma::ValTypeB,
cute::cosize_v<SmemLayoutB>>
smem_B;
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments {
ElementA const *ptr_A;
StrideA dA;
ElementB const *ptr_B0;
ElementB const *ptr_B1;
StrideB dB;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
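// Illustrative host-side construction (pointer/stride names are hypothetical,
// not from the original source):
//   Arguments args{ptr_A, dA, ptr_B0_gate, ptr_B1_linear, dB,
//                  /*scale_d0=*/1.0f, /*scale_d1=*/1.0f,
//                  /*mma_promotion_interval=*/4};
// ptr_B0 and ptr_B1 are the two gated operands; they share dB (or dA when
// SwapAB is set), as used in to_underlying_arguments() below.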
// Device side kernel params
struct Params {
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy(
GmemTiledCopyA{},
make_tensor(static_cast<ElementA const *>(nullptr),
repeat_like(StrideA{}, int32_t(0)), StrideA{}),
SmemLayoutA{}(_, _, 0),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using TMA_B = decltype(make_tma_copy(
GmemTiledCopyB{},
make_tensor(static_cast<ElementB const *>(nullptr),
repeat_like(StrideB{}, int32_t(0)), StrideB{}),
SmemLayoutB{}(_, _, 0),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Aux tma_load_aux;
float scale_d0 = 1.0f;
float scale_d1 = 1.0f;
uint32_t mma_promotion_interval = 4;
};
//
// Methods
//
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const &problem_shape,
Arguments const &args, void *workspace) {
(void)workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
// only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<ElementA const *>(args.ptr_A);
auto ptr_B0 = reinterpret_cast<ElementB const *>(args.ptr_B0);
Tensor tensor_a =
make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
Tensor tensor_b =
make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_A tma_load_a = make_tma_copy(
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
typename Params::TMA_B tma_load_b = make_tma_copy(
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
if constexpr (SwapAB) {
auto ptr_Aux = reinterpret_cast<ElementA const *>(args.ptr_B1);
Tensor tensor_aux =
make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}),
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
size<1>(
ClusterShape{})); // mcast along N mode for this M load, if any
return {tma_load_a, tma_load_b, tma_load_aux,
args.scale_d0, args.scale_d1, args.mma_promotion_interval};
} else {
auto ptr_Aux = reinterpret_cast<ElementB const *>(args.ptr_B1);
Tensor tensor_aux =
make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(
ClusterShape{})); // mcast along M mode for this N load, if any
return {tma_load_a, tma_load_b, tma_load_aux,
args.scale_d0, args.scale_d1, args.mma_promotion_interval};
}
}
template <class ProblemShape>
static bool can_implement(ProblemShape const &problem_shape,
[[maybe_unused]] Arguments const &args) {
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A =
tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable =
implementable &&
cutlass::detail::check_alignment<min_tma_aligned_elements_A>(
cute::make_shape(M, K, L), StrideA{});
constexpr int min_tma_aligned_elements_B =
tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable =
implementable &&
cutlass::detail::check_alignment<min_tma_aligned_elements_B>(
cute::make_shape(N, K, L), StrideB{});
/* MMA promotion interval should be a multiple of 4, since each mainloop
* iteration would issue 4 MMA instructions. */
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the "
"minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr int K_PIPE_MMAS = 1;
static constexpr uint32_t TmaTransactionBytes =
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) *
static_cast<uint32_t>(sizeof_bits<ElementA>::value)) /
8 +
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) *
static_cast<uint32_t>(sizeof_bits<ElementB>::value)) /
8 +
(size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) *
static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) /
8;
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
/// performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const &mainloop_params) {
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_b.get_tma_descriptor());
cute::prefetch_tma_descriptor(
mainloop_params.tma_load_aux.get_tma_descriptor());
}
/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The contract between the collective and the
/// kernel layer is that the returned tuple contains at least these elements:
///   gA_mkl   - TMA tensor A after a local tile, shape (BLK_M,BLK_K,m,k,l)
///   gB_nkl   - TMA tensor B after a local tile, shape (BLK_N,BLK_K,n,k,l)
///   gAux_xkl - TMA tensor A/B after a local tile, shape (BLK_M/N,BLK_K,m/n,k,l)
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL,
Params const &mainloop_params) const {
using X = Underscore;
// Separate out problem shape for convenience
auto [M, N, K, L] = problem_shape_MNKL;
// TMA requires special handling of strides to deal with coord codomain
// mapping. Represent the full tensors -- get these from TMA.
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(
make_shape(M, K, L)); // (m,k,l)
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(
make_shape(N, K, L)); // (n,k,l)
// Make tiled views, defer the slice
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
if constexpr (SwapAB) {
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
make_shape(M, K, L)); // (m,k,l)
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
} else {
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
make_shape(N, K, L)); // (n,k,l)
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
template <class TensorA, class TensorB, class TensorAux, class KTileIterator,
class BlockCoord>
CUTLASS_DEVICE void
load(Params const &mainloop_params, MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<TensorA, TensorB, TensorAux> const &load_inputs,
BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count,
int thread_idx, uint32_t block_rank_in_cluster,
TensorStorage &shared_tensors) {
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
SmemLayoutAux{});
//
// Prepare the TMA loads for A and B
//
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x,
block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
auto block_tma_a =
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b =
mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
auto block_tma_aux =
SwapAB
? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
: mainloop_params.tma_load_aux.get_slice(
cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord)
: gAux_xkl(_, _, n_coord, _, l_coord);
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_aux = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout =
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
// block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,
n, Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout =
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
// block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(
m, cluster_local_block_id.y, Int<0>{}));
}
}
if constexpr (SwapAB) {
mcast_mask_aux = mcast_mask_a;
} else {
mcast_mask_aux = mcast_mask_b;
}
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType *tma_barrier =
pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a),
tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b),
tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage));
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux),
tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage));
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void
mma(MainloopPipeline pipeline, PipelineState smem_pipe_read,
FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx,
TensorStorage &shared_tensors, Params const &mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value,
"C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutA{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutB{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
SmemLayoutAux{});
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate "fragments/descriptors"
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
auto tCsAux = [&]() -> auto {
if constexpr (SwapAB) {
return thread_mma.partition_A(sAux);
} else {
return thread_mma.partition_B(sAux);
}
}();
auto tCrAux = [&]() -> auto {
if constexpr (SwapAB) {
return thread_mma.make_fragment_A(tCsAux);
} else {
return thread_mma.make_fragment_B(tCsAux);
}
}();
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
if constexpr (SwapAB) {
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
} else {
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
}
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} ==
size<2>(sAux)); // PIPE
//
// PIPELINED MAIN LOOP
//
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
"ERROR : Incorrect number of MMAs in flight");
// We release buffers to producer warps (DMA load) with some MMAs in flight
PipelineState smem_pipe_release = smem_pipe_read;
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
GmmaFP8Accumulation accumulation0(
accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
GmmaFP8Accumulation accumulation1(
accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
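// Illustrative note: GmmaFP8Accumulation keeps a temporary accumulator that is
// promoted (added) into accum0/accum1 every mma_promotion_interval k-blocks
// via promote_if_needed(), which is intended to bound error growth of FP8
// tensor-core accumulation (see cutlass/gemm/collective/fp8_accumulation.hpp).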
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
CUTLASS_PRAGMA_UNROLL
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0;
--k_tile_prologue) {
// WAIT on smem_pipe_read until its data are available (phase bit flips
// from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
if (accumulation0.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
int read_stage = smem_pipe_read.index();
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB) {
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accumulation1());
} else {
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
++smem_pipe_read;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
// Mainloop GMMAs
k_tile_count -= prologue_mma_count;
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 0; --k_tile_count) {
// WAIT on smem_pipe_read until its data are available (phase bit flips
// from rdPhaseBit value)
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
if (accumulation0.prepare_if_needed()) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
warpgroup_arrive();
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// (V,M,K) x (V,N,K) => (V,M,N)
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accumulation0());
if constexpr (SwapAB) {
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
tCrB(_, _, k_block, read_stage), accumulation1());
} else {
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
tCrAux(_, _, k_block, read_stage), accumulation1());
}
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to
/// ensure smem_pipe_write is consumed
warpgroup_wait<K_PIPE_MMAS>();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
accumulation0.promote_if_needed();
accumulation1.promote_if_needed();
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
// done _computing_ on it
// Advance smem_pipe_read and smem_pipe_release
++smem_pipe_read;
++smem_pipe_release;
}
accumulation0.promote_residue_if_needed();
accumulation1.promote_residue_if_needed();
warpgroup_fence_operand(accumulation0());
warpgroup_fence_operand(accumulation1());
}
/// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_release,
int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
// done _computing_ on it
++smem_pipe_release;
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,71 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
////////////////////////////////////////////////////////////////////////////////
/*
* Stateless universal device GEMM kernel type that treats GEMM as
* a composition of a collective mainloop and a collective epilogue.
*
* Supports both the 2.x and 3.x APIs based on whether the first type is
* a cute::tuple<> or not.
* 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
* 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
*
* In the following declaration, the name preceding the 'Or' refers to
* 3.x API type argument order, and the name succeeding the 'Or' refers to
* 2.x API type argument order. Template arguments without two names
* belong to the 3.x API only.
**/
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
class CollectiveMainloopOrEpilogue_,
class CollectiveEpilogueOrThreadblockSwizzle_,
class TileScheduler_ = void, class Enable = void>
class GemmUniversalGated;
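// Illustrative 3.x-style instantiation (the collective type names are
// placeholders, not from the original source):
//   using GemmKernel = GemmUniversalGated<
//       cute::Shape<int, int, int, int>,  // ProblemShape (M, N, K, L)
//       SomeGatedCollectiveMainloop,      // CollectiveMainloop
//       SomeCollectiveEpilogue>;          // CollectiveEpilogue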
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
////////////////////////////////////////////////////////////////////////////////

View File

@@ -130,6 +130,15 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
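// Worked example (illustrative): for TypeA = half_t (16 bits), this
// specialization yields ThreadblockK = 128 * 8 / 16 = 64 and
// ElementsPerAccess = 128 / 16 = 8, i.e. 128-bit (16-byte) accesses on the
// row-major B operand.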
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{

View File

@@ -0,0 +1,705 @@
/***************************************************************************************************
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/tensor.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_, class CollectiveMainloop_,
class CollectiveEpilogue_, class TileScheduler_>
class GemmUniversalGated<
ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelTmaWarpSpecializedCooperative,
typename CollectiveMainloop_::
DispatchPolicy::Schedule> &&
CollectiveMainloop_::isGated>> {
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or
cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using Activation = typename CollectiveMainloop::Activation;
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(ArchTag::kMinComputeCapability >= 90);
using TileSchedulerTag = TileScheduler_;
using TileScheduler =
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape,
ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups =
CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
static constexpr uint32_t MaxThreadsPerBlock =
CUTE_STATIC_V(size(TiledMma{})) +
(NumLoadWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
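// Illustrative note: these budgets are used with warp-group register
// reconfiguration (cutlass/arch/reg_reconfig.h): the producer warp group
// releases registers down to roughly 40 per thread while the MMA warp groups
// grow toward 232 per thread, keeping the 384-thread block within the SM's
// register file.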
// 1 stage ordered sequence between mainloop and epilogue producer load
// threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
// Kernel level shared memory storage
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16> {
using MainloopPipelineStorage =
typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage =
typename CollectiveEpilogue::PipelineStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
} pipelines;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
void *workspace{nullptr};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the
// aliased type.
static Params to_underlying_arguments(Arguments const &args,
void *workspace) {
CUTLASS_TRACE_HOST("to_underlying_arguments():");
auto problem_shape = args.problem_shape;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto problem_shape_MNKL = append<4>(problem_shape, 1);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(
" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments "
"KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(
args.hw_info.device_id);
}
CUTLASS_TRACE_HOST(
"to_underlying_arguments(): Setting persistent grid SM count to "
<< sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
size_t workspace_offset = 0;
void *scheduler_workspace = workspace_ptr;
workspace_offset +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void *epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(
args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void *mainloop_workspace = nullptr;
// Precompute the number of epilogue subtiles and pass it to the tile
// scheduler; it is used by the separate-reduction scheme in the stream-K
// case. NumEpilogueSubTiles defaults to 1, meaning subtiles are not used and
// separate reduction is not enabled.
constexpr uint32_t NumEpilogueSubTiles =
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
return {args.mode,
problem_shape,
CollectiveMainloop::to_underlying_arguments(
args.problem_shape, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(
args.problem_shape, args.epilogue, epilogue_workspace),
hw_info,
scheduler,
workspace};
}
static bool can_implement(Arguments const &args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched &&
cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't "
"meet the requirements.\n");
return implementable;
}
implementable &=
CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &=
CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static size_t get_workspace_size(Arguments const &args) {
size_t workspace_size = 0;
constexpr uint32_t NumEpilogueSubTiles =
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
workspace_size +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups,
NumEpilogueSubTiles);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape,
args.epilogue);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
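// Illustrative note: the workspace is laid out as [tile-scheduler workspace |
// epilogue workspace], each region rounded up to MinWorkspaceAlignment,
// matching the offsets computed in to_underlying_arguments() and
// initialize_workspace().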
static cutlass::Status
initialize_workspace(Arguments const &args, void *workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
Status status = Status::kSuccess;
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
size_t workspace_offset = 0;
constexpr uint32_t NumEpilogueSubTiles =
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
status = TileScheduler::template initialize_workspace<ProblemShape,
ElementAccumulator>(
args.scheduler, workspace_ptr + workspace_offset, stream,
args.problem_shape, args.hw_info, NumMmaWarpGroups,
NumEpilogueSubTiles);
workspace_offset +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups,
NumEpilogueSubTiles);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset,
stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(
args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3 get_grid_shape(Params const &params) {
// Given device SM count, set grid size s.t. we do not launch more thread
// blocks than we can run concurrently
TileSchedulerArguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
}
args.raster_order =
params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
? TileScheduler::RasterOrderOptions::AlongN
: TileScheduler::RasterOrderOptions::AlongM;
return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape,
TileShape{}, ClusterShape{},
params.hw_info, args);
}
static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); }
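// Illustrative launch sequence (simplified sketch with a hypothetical
// GemmKernel alias; production code normally goes through
// cutlass::gemm::device::GemmUniversalAdapter, and cluster launch details are
// omitted here):
//   auto params = GemmKernel::to_underlying_arguments(args, workspace);
//   dim3 grid = GemmKernel::get_grid_shape(params);
//   dim3 block = GemmKernel::get_block_shape();
//   cutlass::device_kernel<GemmKernel>
//       <<<grid, block, GemmKernel::SharedStorageSize, stream>>>(params);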
CUTLASS_DEVICE
void operator()(Params const &params, char *smem_buf) {
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting "
"sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(
size(TiledMma{}) == 256,
"Cooperative kernel must have TiledMMA operating using 256 threads.");
static_assert(size<0>(TileShape{}) >= 128,
"Cooperative kernel requires Tile Size to be greater than or "
"equal to 128 along the M-dimension.");
static_assert(cute::rank(StrideA{}) == 3,
"StrideA must be rank-3: [M, K, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3,
"StrideB must be rank-3: [N, K, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3,
"StrideC must be rank-3: [M, N, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3,
"StrideD must be rank-3: [M, N, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the
* same tile */
enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 };
enum class ProducerWarpRole {
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
};
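// Illustrative note: with NumLoadWarpGroups = 1 and a 256-thread TiledMma, the
// block runs 384 threads in three warp groups: warp group 0 is the producer
// (warp 0 feeds the mainloop TMA loads, warp 2 feeds the epilogue loads),
// while warp groups 1 and 2 are the consumers cooperating on each output tile.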
// Kernel level shared memory storage
SharedStorage &shared_storage =
*reinterpret_cast<SharedStorage *>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
int mma_thread_idx = thread_idx % size(TiledMma{});
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer &&
producer_warp_role == ProducerWarpRole::Mainloop) {
mainloop_pipeline_params.role =
MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
mainloop_pipeline_params.role =
MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = size(TiledMma{});
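// In the cooperative schedule both consumer warp groups work on the same tile,
// so all size(TiledMma{}) == 256 MMA threads participate as mainloop consumers.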
mainloop_pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop,
mainloop_pipeline_params,
ClusterShape{});
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer &&
producer_warp_role == ProducerWarpRole::Epilogue) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = size(TiledMma{});
epi_load_pipeline_params.transaction_bytes =
CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load,
epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
params_load_order_barrier.group_id =
producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
params_load_order_barrier.group_size = NumThreadsPerWarp;
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order,
params_load_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via
// scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state =
cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state =
cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state =
cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = []() {
// We need this to guarantee that pipeline initialization is visible
// to all producer and consumer thread blocks in the cluster

if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
return []() { cute::cluster_wait(); };
} else {
__syncthreads();
return []() {}; // do nothing
}
}();
// Optionally append 1s until problem shape is rank-4 in case it is only
// rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
// Get the appropriate blocks for this thread block -- potential for thread
// block locality
TiledMma tiled_mma;
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
TileScheduler scheduler{params.scheduler};
auto work_tile_info = scheduler.get_current_work();
// In a warp specialized kernel, collectives expose data movement and
// compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue,
shared_storage.tensors.epilogue);
// Prepare and partition the input tensors. Expects a tuple of tensors
// where: get<0>(load_inputs) is the tma tensor A after local tiling so that
// it has shape (BLK_M,BLK_K,m,k,l) get<1>(load_inputs) is the tma tensor B
// after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto load_inputs =
collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(
cute::tuple_size_v<decltype(load_inputs)> >= 3,
"Output of load_init must have at least three elements (A, B, Aux)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
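// gAux_xkl is the third tensor returned by load_init; for this gated kernel it
// presumably carries the second (gate) B operand that feeds accumulators1.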
// Get pipeline stage increments from tensor shapes
auto k_tile_count = size<3>(gA_mkl);
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer) {
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
if (producer_warp_role == ProducerWarpRole::Mainloop) {
bool do_load_order_arrive = true;
while (work_tile_info.is_valid()) {
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
work_tile_info = fetch_next_work(work_tile_info, scheduler);
continue;
}
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Get the number of K tiles to compute for this work as well as the
// starting K tile offset of the work.
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(
work_tile_info, problem_shape_MNKL, blk_shape);
auto work_k_tile_start =
TileScheduler::get_work_k_tile_start(work_tile_info);
auto k_tile_iter = cute::make_coord_iterator(
idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
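// work_k_tile_start can be nonzero when the scheduler splits work along K
// (e.g. stream-K), so the K-tile iterator may begin partway through the K extent.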
collective_mainloop.load(
params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx,
block_rank_in_cluster, shared_storage.tensors.mainloop);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(work_k_tile_count);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive) {
load_order_barrier.arrive();
do_load_order_arrive = false;
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline,
mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue &&
collective_epilogue.is_producer_load_needed()) {
while (work_tile_info.is_valid()) {
if (!TileScheduler::requires_separate_reduction(params.scheduler)) {
load_order_barrier.wait();
}
if (TileScheduler::compute_epilogue(work_tile_info,
params.scheduler)) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
epi_load_pipe_producer_state = collective_epilogue.load(
epi_load_pipeline, epi_load_pipe_producer_state,
problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
shared_storage.tensors.epilogue,
work_tile_info.reduction_subtile_idx());
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline,
epi_load_pipe_producer_state);
} // Epilogue Producer Warp End
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
// Whether to issue tail arrives for TMA stores, in case the epilogue load
// is waiting on them
bool do_store_tail = false;
float scale_d0 = params.mainloop.scale_d0;
float scale_d1 = params.mainloop.scale_d1;
while (work_tile_info.is_valid()) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(
work_tile_info, problem_shape_MNKL, blk_shape);
// Allocate the accumulators for the (M,N) blk_shape
//
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
auto accumulators0 = partition_fragment_C(
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
auto accumulators1 = partition_fragment_C(
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
collective_mainloop.mma(
mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
accumulators1, work_k_tile_count, mma_thread_idx,
shared_storage.tensors.mainloop, params.mainloop);
// Make sure the math instructions are done and free buffers before
// entering the epilogue
collective_mainloop.mma_tail(mainloop_pipeline,
mainloop_pipe_consumer_state,
work_k_tile_count);
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(work_k_tile_count);
}
// Index of warp group within consumer warp groups
int consumer_warp_group_idx =
canonical_warp_group_idx() - NumLoadWarpGroups;
// Perform reduction across splits, if needed
TileScheduler::fixup(params.scheduler, work_tile_info, accumulators0,
NumMmaWarpGroups, consumer_warp_group_idx);
TileScheduler::fixup(params.scheduler, work_tile_info, accumulators1,
NumMmaWarpGroups, consumer_warp_group_idx);
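// Gated epilogue fusion (sketch of the intent, assuming a GLU-style gate):
//   D[i] = Activation(acc0[i] * scale_d0) * (scale_d1 * acc1[i])
// where accumulators0/accumulators1 come from the two GEMMs of the gated mainloop.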
Activation elt_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators0); i++) {
accumulators0[i] = elt_op(accumulators0[i] * scale_d0) *
(scale_d1 * accumulators1[i]);
}
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next,
epi_store_pipe_producer_state_next] =
collective_epilogue.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue,
work_tile_info.reduction_subtile_idx());
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
do_store_tail = true;
}
// Get next work tile
work_tile_info = fetch_next_work(work_tile_info, scheduler);
} // Scheduler work fetch loop
if (do_store_tail) {
collective_epilogue.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
epi_store_pipe_producer_state);
}
} // Consumer Warp Groups End
#endif
}
private:
// Kernel helper function to get next work unit
CUTLASS_DEVICE
typename TileScheduler::WorkTileInfo
fetch_next_work(typename TileScheduler::WorkTileInfo &work_tile_info,
TileScheduler &scheduler) const {
// Check whether we should continue on with the current work unit. If this
// is the case, the work unit will have been updated in
// continue_current_work to reflect the new tile to be computed.
if (scheduler.continue_current_work(work_tile_info)) {
return work_tile_info;
}
// Get next work tile
scheduler.advance_to_next_work();
return scheduler.get_current_work();
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel

View File

@@ -0,0 +1,680 @@
/***************************************************************************************************
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/arch/mma_sm90.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
#include "cutlass/workspace.h"
#include "cute/tensor.hpp"
#include "cute/util/debug.hpp"
///////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
template <class ProblemShape_, class CollectiveMainloop_,
class CollectiveEpilogue_, class TileScheduler_>
class GemmUniversalGated<
ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
cute::enable_if_t<cute::is_base_of_v<KernelTmaWarpSpecializedPingpong,
typename CollectiveMainloop_::
DispatchPolicy::Schedule> &&
CollectiveMainloop_::isGated>> {
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
static_assert(cute::rank(ProblemShape{}) == 3 or
cute::rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using ClusterShape = typename DispatchPolicy::ClusterShape;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using Activation = typename CollectiveMainloop::Activation;
static_assert(ArchTag::kMinComputeCapability >= 90);
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(
!cute::is_same_v<TileScheduler_, StreamKScheduler>,
"Ping-pong kernel does not currently support stream-K scheduler.");
using TileSchedulerTag = TileScheduler_;
using TileScheduler =
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape,
ClusterShape>::Scheduler;
using TileSchedulerArguments = typename TileScheduler::Arguments;
using TileSchedulerParams = typename TileScheduler::Params;
static constexpr uint32_t NumLoadWarpGroups = 1;
static constexpr uint32_t NumMmaWarpGroups = 2;
static constexpr uint32_t MaxThreadsPerBlock =
CUTE_STATIC_V(size(TiledMma{})) +
(NumMmaWarpGroups * NumThreadsPerWarpGroup);
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
/// Register requirement for Load and Math WGs
static constexpr uint32_t LoadRegisterRequirement = 40;
static constexpr uint32_t MmaRegisterRequirement = 232;
// 1 stage ordered sequence between mainloop and epilogue producer load
// threads
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
// Order Sequence barrier with two stages: one for Mainloop and one for
// Epilogue
static constexpr uint32_t StagesPerMathWarpGroup = 2;
using MathWarpGroupOrderBarrier =
cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
// Kernel level shared memory storage
struct SharedStorage {
struct TensorStorage : cute::aligned_struct<128> {
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
MainloopTensorStorage mainloop;
EpilogueTensorStorage epilogue;
} tensors;
struct PipelineStorage : cute::aligned_struct<16> {
using MainloopPipelineStorage =
typename CollectiveMainloop::PipelineStorage;
using EpiLoadPipelineStorage =
typename CollectiveEpilogue::PipelineStorage;
using MathWarpGroupOrderBarrierStorage =
typename MathWarpGroupOrderBarrier::SharedStorage;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) EpiLoadPipelineStorage epi_load;
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
} pipelines;
};
static constexpr int SharedStorageSize = sizeof(SharedStorage);
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerArguments scheduler{};
};
// Kernel entry point API
struct Params {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopParams mainloop{};
EpilogueParams epilogue{};
KernelHardwareInfo hw_info{};
TileSchedulerParams scheduler{};
};
//
// Methods
//
// Convert to underlying arguments. In this case, a simple copy for the
// aliased type.
static Params to_underlying_arguments(Arguments const &args,
void *workspace) {
CUTLASS_TRACE_HOST("to_underlying_arguments():");
(void)workspace;
auto problem_shape = args.problem_shape;
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
// // swap M/N
// get<0>(problem_shape) = get<1>(args.problem_shape);
// get<1>(problem_shape) = get<0>(args.problem_shape);
// }
auto problem_shape_MNKL = append<4>(problem_shape, 1);
// Get SM count if needed, otherwise use user supplied SM count
int sm_count = args.hw_info.sm_count;
if (sm_count <= 0) {
CUTLASS_TRACE_HOST(
" WARNING: Arguments do not include a valid SM count.\n"
" For optimal performance, populate the arguments "
"KernelHardwareInfo struct with the SM count.");
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(
args.hw_info.device_id);
}
CUTLASS_TRACE_HOST(
"to_underlying_arguments(): Setting persistent grid SM count to "
<< sm_count);
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
// Calculate workspace pointers
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
size_t workspace_offset = 0;
void *scheduler_workspace = workspace_ptr;
workspace_offset +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void *epilogue_workspace = workspace_ptr + workspace_offset;
workspace_offset += CollectiveEpilogue::get_workspace_size(
args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
void *mainloop_workspace = nullptr;
return {args.mode,
problem_shape,
CollectiveMainloop::to_underlying_arguments(
args.problem_shape, args.mainloop, mainloop_workspace),
CollectiveEpilogue::to_underlying_arguments(
args.problem_shape, args.epilogue, epilogue_workspace),
hw_info,
TileScheduler::to_underlying_arguments(
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
args.scheduler, scheduler_workspace)};
}
static bool can_implement(Arguments const &args) {
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
(args.mode == GemmUniversalMode::kBatched &&
cute::rank(ProblemShape{}) == 4);
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't "
"meet the requirements.\n");
return implementable;
}
implementable &=
CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &=
CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
return implementable;
}
static size_t get_workspace_size(Arguments const &args) {
size_t workspace_size = 0;
workspace_size +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape,
args.epilogue);
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
return workspace_size;
}
static cutlass::Status
initialize_workspace(Arguments const &args, void *workspace = nullptr,
cudaStream_t stream = nullptr,
CudaHostAdapter *cuda_adapter = nullptr) {
Status status = Status::kSuccess;
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
size_t workspace_offset = 0;
status = TileScheduler::template initialize_workspace<ProblemShape,
ElementAccumulator>(
args.scheduler, workspace_ptr + workspace_offset, stream,
args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset +=
TileScheduler::template get_workspace_size<ProblemShape,
ElementAccumulator>(
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
status = CollectiveEpilogue::initialize_workspace(
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset,
stream, cuda_adapter);
workspace_offset += CollectiveEpilogue::get_workspace_size(
args.problem_shape, args.epilogue);
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
if (status != Status::kSuccess) {
return status;
}
return status;
}
// Computes the kernel launch grid shape based on runtime parameters
static dim3 get_grid_shape(Params const &params) {
// Given device SM count, set grid size s.t. we do not launch more thread
// blocks than we can run concurrently
TileSchedulerArguments args{};
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
}
args.raster_order =
params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
? TileScheduler::RasterOrderOptions::AlongN
: TileScheduler::RasterOrderOptions::AlongM;
return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape,
TileShape{}, ClusterShape{},
params.hw_info, args);
}
static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); }
CUTLASS_DEVICE
void operator()(Params const &params, char *smem_buf) {
using namespace cute;
using X = Underscore;
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
printf("ERROR : Arch conditional MMA instruction used without targeting "
"sm90a compute capability. Aborting.\n");
#else
// Preconditions
static_assert(cute::rank(StrideA{}) == 3,
"StrideA must be rank-3: [M, K, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideB{}) == 3,
"StrideB must be rank-3: [N, K, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideC{}) == 3,
"StrideC must be rank-3: [M, N, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
static_assert(cute::rank(StrideD{}) == 3,
"StrideD must be rank-3: [M, N, L]. If batch mode is not "
"needed, set L stride to Int<0>.");
enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 };
enum class ProducerWarpRole {
Mainloop = 0,
Warp1 = 1,
Epilogue = 2,
Warp3 = 3
};
// Kernel level shared memory storage
SharedStorage &shared_storage =
*reinterpret_cast<SharedStorage *>(smem_buf);
int thread_idx = int(threadIdx.x);
int lane_idx = canonical_lane_idx();
int warp_idx = canonical_warp_idx_sync();
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
int lane_predicate = cute::elect_one_sync();
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
// Issue Tma Descriptor Prefetch from a single thread
if ((warp_idx == 0) && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
}
// Mainloop Load pipeline
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
typename MainloopPipeline::Params mainloop_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer &&
producer_warp_role == ProducerWarpRole::Mainloop) {
mainloop_pipeline_params.role =
MainloopPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
mainloop_pipeline_params.role =
MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
mainloop_pipeline_params.transaction_bytes =
CollectiveMainloop::TmaTransactionBytes;
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop,
mainloop_pipeline_params,
ClusterShape{});
// Epilogue Load pipeline
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
typename EpiLoadPipeline::Params epi_load_pipeline_params;
if (warp_group_role == WarpGroupRole::Producer &&
producer_warp_role == ProducerWarpRole::Epilogue) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
}
if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
}
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
epi_load_pipeline_params.transaction_bytes =
CollectiveEpilogue::TmaTransactionBytes;
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load,
epi_load_pipeline_params);
// Epilogue Store pipeline
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
typename EpiStorePipeline::Params epi_store_pipeline_params;
epi_store_pipeline_params.always_wait = true;
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
params_load_order_barrier.group_id =
producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
params_load_order_barrier.group_size = NumThreadsPerWarp;
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order,
params_load_order_barrier);
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
// DMA Load WG will not participate in these Ordered Barrier syncs
params_math_wg_order_barrier.group_id =
canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
params_math_wg_order_barrier.group_size =
NumThreadsPerWarpGroup; // Number of threads / participants in a group
MathWarpGroupOrderBarrier math_wg_order_barrier(
shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
// Initialize starting pipeline states for the collectives
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via
// scoreboarding)
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
// For the DMA Load (producer) we start with an opposite phase
// i.e., we skip all waits since we know that the buffer is indeed empty
PipelineState mainloop_pipe_producer_state =
cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState epi_load_pipe_producer_state =
cutlass::make_producer_start_state<EpiLoadPipeline>();
PipelineState epi_store_pipe_producer_state =
cutlass::make_producer_start_state<EpiStorePipeline>();
auto cluster_wait_fn = [&]() {
// We need this to guarantee that pipeline initialization is visible
// to all producer and consumer thread blocks in the cluster
if constexpr (size(ClusterShape{}) > 1) {
cute::cluster_arrive_relaxed();
return []() { cute::cluster_wait(); };
} else {
__syncthreads();
return []() {}; // do nothing
}
}();
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4 in case it is only
// rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
// Get the appropriate blocks for this thread block -- potential for thread
// block locality
TiledMma tiled_mma;
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
// In a warp specialized kernel, collectives expose data movement and
// compute operations separately
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue(params.epilogue,
shared_storage.tensors.epilogue);
// Prepare and partition the input tensors. Expects a tuple of tensors
// where: get<0>(load_inputs) is the tma tensor A after local tiling so that
// it has shape (BLK_M,BLK_K,m,k,l) get<1>(load_inputs) is the tma tensor B
// after local tiling so that it has shape (BLK_N,BLK_K,n,k,l)
auto load_inputs =
collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
static_assert(
cute::tuple_size_v<decltype(load_inputs)> >= 3,
"Output of load_init must have at least three elements (A, B, Aux)");
// Extract out partitioned A and B.
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
Tensor gAux_xkl = get<2>(load_inputs);
// Get pipeline stage increments from tensor shapes
auto k_tile_count = size<3>(gA_mkl);
auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
TileScheduler scheduler{params.scheduler};
if (warp_group_role == WarpGroupRole::Consumer1) {
// Advance 2nd Math WG to the next work tile for the startup
scheduler.advance_to_next_work();
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
mainloop_pipe_consumer_state.advance(k_tile_count);
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
}
auto work_tile_info = scheduler.get_current_work();
// Wait for all thread blocks in the Cluster
cluster_wait_fn();
if (warp_group_role == WarpGroupRole::Producer) {
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
// Mainloop Producer Warp
if (producer_warp_role == ProducerWarpRole::Mainloop) {
bool do_load_order_arrive = true;
while (work_tile_info.is_valid()) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
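// The ping-pong scheduler does not split the K dimension (stream-K is
// unsupported here), so each work tile always starts at K tile 0.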
collective_mainloop.load(
params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx,
block_rank_in_cluster, shared_storage.tensors.mainloop);
// Update starting pipeline state for the next tile
mainloop_pipe_producer_state.advance(k_tile_count);
// Signal for the epilogue load warp to begin
if (do_load_order_arrive) {
load_order_barrier.arrive();
do_load_order_arrive = false;
}
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_mainloop.load_tail(mainloop_pipeline,
mainloop_pipe_producer_state);
} // Mainloop Producer Warp End
// Epilogue Producer Warp
else if (producer_warp_role == ProducerWarpRole::Epilogue &&
collective_epilogue.is_producer_load_needed()) {
load_order_barrier.wait();
while (work_tile_info.is_valid()) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
epi_load_pipe_producer_state = collective_epilogue.load(
epi_load_pipeline, epi_load_pipe_producer_state,
problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
shared_storage.tensors.epilogue);
// Get next work tile
scheduler.advance_to_next_work();
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
// Make sure all Consumer Warp Groups have been waited upon
collective_epilogue.load_tail(epi_load_pipeline,
epi_load_pipe_producer_state);
} // Epilogue Producer Warp End
} // Producer Warp Group End
else if (warp_group_role == WarpGroupRole::Consumer0 ||
warp_group_role == WarpGroupRole::Consumer1) {
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
float scale_d0 = params.mainloop.scale_d0;
float scale_d1 = params.mainloop.scale_d1;
while (work_tile_info.is_valid()) {
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
// n-shape
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
// Allocate the accumulators for the (M,N) blk_shape
Tensor accumulators0 = partition_fragment_C(
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
Tensor accumulators1 = partition_fragment_C(
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
// Order two Math WG's MMA one after the other, helps hide Epilogue
math_wg_order_barrier.wait();
collective_mainloop.mma(
mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
accumulators1, k_tile_count, warp_group_thread_idx,
shared_storage.tensors.mainloop, params.mainloop);
// Cue for next Math WG's MMA to start
math_wg_order_barrier.arrive();
// Make sure the math instructions are done and free buffers before
// entering the epilogue
collective_mainloop.mma_tail(
mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count);
// Update starting mainloop pipeline state for the next tile
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
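// Skip past the K tiles consumed by the other math warp group as well, since
// the two MMA warp groups alternate work tiles in the ping-pong schedule.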
Activation elt_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accumulators0); i++) {
accumulators0[i] = elt_op(accumulators0[i] * scale_d0) *
(scale_d1 * accumulators1[i]);
}
// Order two Math WG's Epilogue one after the other
math_wg_order_barrier.wait();
// Epilogue and write to gD
auto [epi_load_pipe_consumer_state_next,
epi_store_pipe_producer_state_next] =
collective_epilogue.store(
epi_load_pipeline, epi_load_pipe_consumer_state,
epi_store_pipeline, epi_store_pipe_producer_state,
problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
tiled_mma, warp_group_thread_idx,
shared_storage.tensors.epilogue);
// The TMA store pipeline wait is only visible to the TMA-issuing warp, so for
// multiple-consumer kernels we must wait for all TMA stores to complete before
// issuing the consumer order barrier arrive; otherwise the next math consumer
// could overwrite smem still referenced by this consumer's in-flight TMA stores.
auto [epi_load_pipe_consumer_state_next_,
epi_store_pipe_producer_state_next_] =
collective_epilogue.store_tail(
epi_load_pipeline, epi_load_pipe_consumer_state_next,
epi_store_pipeline, epi_store_pipe_producer_state_next);
// Update starting load/store pipeline states for the next tile
// state has already been incremented by 1 tile in collective calls,
// advance once again for ping pong
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
epi_load_pipe_consumer_state.advance(c_tile_count);
epi_store_pipe_producer_state.advance(d_tile_count);
// Cue for next Math WG's Epilogue to start
math_wg_order_barrier.arrive();
// Get next work tile
scheduler.advance_to_next_work(NumMmaWarpGroups);
work_tile_info = scheduler.get_current_work();
} // Scheduler work fetch loop
} // Consumer Warp Groups End
#endif
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel

View File

@@ -77,6 +77,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
@@ -125,6 +126,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
template <
@@ -148,7 +150,7 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
@@ -179,6 +181,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
@@ -234,6 +237,7 @@ public:
#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
template <
@@ -346,6 +350,131 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, El
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int2 weight, mma multistage
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
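// Assumption behind the cache-op choice: fully 128-bit (16-byte) accesses can use
// cp.async with the .cg qualifier (CacheOperation::Global, bypassing L1), while
// narrower accesses fall back to .ca (CacheOperation::Always).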
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -19,13 +19,11 @@
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
@@ -197,6 +195,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template <
/// Layout type for A matrix operand
@@ -244,6 +243,9 @@ public:
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -265,7 +267,7 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
@@ -296,6 +298,7 @@ public:
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template <
/// Layout type for A matrix operand
@@ -318,11 +321,11 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
@@ -348,6 +351,131 @@ public:
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int2 weight, mma multistage
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of stages used in the multistage mainloop
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,237 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/threadblock/mma_base.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class Wint2xMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount =
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
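// For illustration (hypothetical shapes, not fixed by this header): with
// WarpGemm::kK = 32 and Operator::Policy::MmaShape::kK = 8, each warp performs
// kWarpGemmIterations = 32 / 8 = 4 warp-level MMAs per threadblock K-tile.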
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA =
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB =
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
// using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;
static_assert(kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
static_assert((kWarpGemmIterations % 2) == 0,
"Inner loop iteration must be an even number.");
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA =
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
// w uint8; local_scale uint8;
constexpr static int kZippedRowsPerStages =
Shape::kK / 4 + (Shape::kK + 127) / 128;
// code_scale float; code_zp float; super_scale ElementB
constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
sizeof_bits<typename Operator::ElementB>::value / 8;
using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
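// For illustration (hypothetical configuration): with Shape::kK = 64,
// kStages = 3 and a 16-bit ElementB,
//   kZippedRowsPerStages  = 64 / 4 + (64 + 127) / 128 = 17,
//   kColumnWiseParamsRows = 2 * 4 + 16 / 8 = 10,
// so ZippedShapeB = MatrixShape<10 + 17 * 3, Shape::kN>.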
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer for quantized (zipped) B operand
AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
/// Buffer for unzipped (dequantized) B operand
AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
operand_unzip_B;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() {
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
CUTLASS_HOST_DEVICE
uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
CUTLASS_HOST_DEVICE
typename Operator::ElementB *operand_unzip_B_ptr() {
return operand_unzip_B.data();
}
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2xMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,807 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass_extensions/arch/memory_copy_sm80.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = bool>
class Wint2xMmaMultistage :
public Wint2xMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB =
IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA =
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB =
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
// Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical
// accuracy, where each mainloop iteration first accumulates into a temporary
// set of freshly-cleared accumulators, which are subsequently added to the
// final accumulator set.
static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation<Operator>::value;
};
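// For illustration (hypothetical values): with AsyncCopyIterationsPerStageA = 8
// and Base::kWarpGemmIterations = 4, each warp-tile group issues
// kAccessesPerGroupA = (8 + 4 - 1) / 4 = 2 cp.async copies of operand A.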
private:
// Structure encapsulating pipeline state live from one iteration to the next
struct PipeState {
using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
/// Temporary accumulator to facilitate staged-accumulation
FragmentC tmp_accum_;
/// Pair of A fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentA warp_loaded_frag_A_[2];
WarpTransformedFragmentA warp_transformed_frag_A_[2];
/// Pair of B fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentB warp_loaded_frag_B_[2];
WarpTransformedFragmentB warp_transformed_frag_B_[2];
};
private:
//
// Data members
//
/// Warp-level MMA operator
Operator warp_mma_;
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Shared memory write stage index
int smem_write_stage_idx_;
/// Shared memory read stage index
int smem_read_stage_idx_;
uint8_t* column_wise_smem_ptr_B_;
uint8_t* smem_zipped_ptr_B_;
int smem_zipped_bytes_per_stage_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2xMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage &shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();
smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
}
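// Layout of the zipped B shared-memory buffer implied by the pointers above:
// the first kColumnWiseParamsRows rows hold the per-column parameters
// (code_scale, code_zp, super_scale), followed by kStages regions of
// kZippedRowsPerStages rows, each holding one stage of zipped weights and
// local_scale values.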
/// Advance shared memory read-iterators to the next stage
CUTLASS_DEVICE
void advance_smem_read_stage()
{
++smem_read_stage_idx_;
if (smem_read_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
// this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
smem_read_stage_idx_ = 0;
}
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
}
/// Advance global memory read-iterators and shared memory write-iterators to the stage
template <typename TileDequanterB>
CUTLASS_DEVICE
void advance_smem_write_stage(
IteratorA &iterator_A,
IteratorB &iterator_B,
TileDequanterB &tile_dequanter_B)
{
// Advance global iterators
iterator_A.add_tile_offset({0, 1});
//iterator_B.add_tile_offset({1, 0});
tile_dequanter_B.AddTileOffset({1, 0});
// Advance shared iterators
smem_iterator_A_.add_tile_offset({0, 1});
//smem_iterator_B_.add_tile_offset({1, 0});
// Increment shared memory write stage index
++smem_write_stage_idx_;
if (smem_write_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
smem_iterator_A_.add_tile_offset({0, -Base::kStages});
//smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx_ = 0;
}
}
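// Note: the B iterator advances above are left commented out because the
// zipped B tile is staged and advanced through tile_dequanter_B rather than
// through iterator_B / smem_iterator_B.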
CUTLASS_DEVICE
void copy_tiles_and_advance_A(IteratorA &iterator_A, int group_start_A = 0) {
iterator_A.set_iteration_index(group_start_A *
IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
}
template <bool GlobalToSharedB>
CUTLASS_DEVICE
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
iterator_B.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
__syncthreads();
}
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) {
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
}
template <bool GlobalToSharedB, bool InitStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
if (InitStage) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
} else {
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
}
++iterator_B;
}
++this->smem_iterator_B_;
}
__syncthreads();
}
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
template <typename TileDequanterB>
CUTLASS_DEVICE
void prologue(
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
TileDequanterB &tile_dequanter_B,
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
{
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
// Disable global fetching if done with global fetch iterations
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
// Async copy A to shared memory.
copy_tiles_and_advance_per_stage_A(iterator_A);
// Async copy zipped B to shared memory.
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
// Move to the next write stage
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Optionally clear the remaining stages of SMEM. This is a functional requirement for
// some kernels so that all accumulator elements outside the GEMM footprint are zero.
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
}
/// Wait until we have at least one completed global fetch stage
CUTLASS_DEVICE
void gmem_wait()
{
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
}
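// For illustration: with Base::kStages = 3 this is cp_async_wait<1>, i.e. the
// warp blocks until at most one committed cp.async stage is still in flight,
// so at least one prologue stage has fully landed in shared memory before the
// mainloop consumes it.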
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
template <typename TileDequanterB>
CUTLASS_DEVICE
void mac_loop_iter(
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
FragmentC &accum, ///< [in|out] destination accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
int stage)
{
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
// Load the next warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
// Unpack and dequantize the next stage of B.
int unpack_stage = stage - Base::kStages + 2;
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, unpack_stage);
// Copy dequantized data to shared memory used by mma core.
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
}
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B_;
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
if (warp_mma_k > 0) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
}
// Execute the current warp-tile of MMA operations
if (Detail::kStagedAccumulation) {
warp_mma_(
pipe_state.tmp_accum_,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.tmp_accum_
);
if (warp_mma_k == 0) {
plus<FragmentC> plus_accum;
accum = plus_accum(accum, pipe_state.tmp_accum_);
pipe_state.tmp_accum_.clear();
}
} else {
warp_mma_(
accum,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
accum
);
}
// Except for the last warp-tile, all warp-tiles issue their share of
// global->shared fragment copies
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
if (warp_mma_k == 0) {
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
}
}
// The second-to-last warp-tile also:
// - performs the last warp-tile's share of global->shared fragment copies
// - moves to the next global fetch stage
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
// Performs the last warp-tile's share of global->shared fragment copies
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Wait until we have at least one completed global fetch stage
gmem_wait();
// Move to the next global fetch stage
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
advance_smem_read_stage();
// Disable global fetching when done with global fetch iterations
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
}
// The last warp-tile also converts the shared memory fragments used by
// the first warp-tile of the next iteration, if necessary (so we can
// immediately start issuing MMA instructions at the top of the loop )
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
}
}
}
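// Pipeline sketch of the loop above: every warp-tile k pre-loads the (k+1)-th
// A/B fragments from shared memory; warp-tiles 0..K-2 issue their share of
// global->shared copies of A (warp-tile 0 also starts the async load of the
// next zipped B stage); warp-tile K-2 fences, waits on gmem and advances the
// write/read stages; warp-tile K-1 dequantizes the next B stage into shared
// memory and transforms the fragments for the next iteration
// (K = Base::kWarpGemmIterations).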
/// Perform the specified number of threadblock mainloop iterations of matrix
/// multiply-accumulate. Assumes prologue has been initiated.
template <typename TileDequanterB>
CUTLASS_DEVICE
void gemm_iters(
int gemm_k_iterations, ///< number of threadblock mainloop iterations
FragmentC &accum, ///< [in|out] accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B,
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
{
PipeState pipe_state;
// Unpack and dequant the first stage of B.
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
// Disable global fetching if done with global fetch iterations
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
// Load first warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
++this->warp_tile_iterator_A_;
// Copy dequantized data to shared memory used by mma core.
copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
// Load first warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
++this->warp_tile_iterator_B_;
// Transform, if necessary, the first warp-tile's shared memory fragments
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[0],
pipe_state.warp_transformed_frag_B_[0],
pipe_state.warp_loaded_frag_A_[0],
pipe_state.warp_loaded_frag_B_[0]);
if (Detail::kStagedAccumulation) {
pipe_state.tmp_accum_.clear();
}
int stage = Base::kStages - 1;
// Mainloop
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);) {
mac_loop_iter(
pipe_state,
accum,
iterator_A,
iterator_B,
tile_dequanter_B,
gemm_k_iterations,
stage);
stage += 1;
}
if (Detail::kStagedAccumulation) {
plus<FragmentC> plus_accum;
accum = plus_accum(accum, pipe_state.tmp_accum_);
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
/// Prepares the class for another prologue.
CUTLASS_DEVICE
void wind_down()
{
// Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue)
// First, increment remaining warp tiles to get to the next full stage. (Ideally we would
// just decrement one tile, but not all iterators implement --() decrement.)
#pragma unroll
for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k);
this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
}
smem_read_stage_idx_++;
// Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators)
static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations;
if (smem_read_stage_idx_ > 1)
{
this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
}
smem_read_stage_idx_ = smem_write_stage_idx_;
}
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
template <typename TileDequanterB>
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC &accum,
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< pre-load and dequantize B to shared memory
TileDequanterB tile_dequanter_B,
///< initial value of accumulator
FragmentC const &src_accum) {
// Prologue (start fetching iterations of global fragments into shared memory)
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
// Wait until we have at least one completed global fetch stage
gmem_wait();
// Initialize destination accumulators with source accumulators
accum = src_accum;
// Perform the MAC-iterations
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,130 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/gemm_coord.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
int Stages, int NumThreads, WintQuantMethod Method>
struct TileDequanter {
using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
using MmaElementT = typename WeightQuantTraits::MmaWeightType;
using QuantArguments = typename WeightQuantTraits::Arguments;
using UnzipAndDequantFunctor =
UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;
static constexpr bool kUseSharedMemory = true;
static constexpr int kRows = Rows;
static constexpr int kColumns = Columns;
static constexpr int kStages = Stages;
MmaElementT *out_smem_ptr{nullptr};
char *pointer{nullptr};
int64_t ldm{0};
cutlass::MatrixCoord tb_offset;
cutlass::MatrixCoord extent;
ScaleElementT *super_scale_ptr{nullptr};
cutlass::MatrixCoord tb_offset_scale;
QuantArguments quant_args;
int64_t block_start_rows[kStages];
bool need_preload{true};
UnzipAndDequantFunctor unzip_functor;
CUTLASS_DEVICE
TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
const cutlass::MatrixCoord &extent,
const cutlass::MatrixCoord &tb_offset,
ScaleElementT *super_scale_ptr,
const cutlass::MatrixCoord &tb_offset_scale,
const QuantArguments &quant_args)
: out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}
CUTLASS_DEVICE
MmaElementT *GetOutPtr() { return out_smem_ptr; }
CUTLASS_DEVICE
void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
tb_offset.row() += tile_offset.row() * kRows;
tb_offset.column() += tile_offset.column() * kColumns;
tb_offset_scale.column() += tile_offset.column() * kColumns;
}
CUTLASS_DEVICE
void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
if (tb_offset.row() >= extent.row() ||
tb_offset.column() >= extent.column()) {
return;
}
block_start_rows[stage % kStages] = tb_offset.row();
using ZippedT = typename WeightQuantTraits::WeightType;
ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
tb_offset.column();
ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();
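// In the 2-bit path below, local_scale stores one 4-bit scale per 64-row group
// with two groups packed per byte, so a single local_scale row covers 128
// weight rows; hence the (tb_offset.row() / 128) row index.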
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
(tb_offset.row() / 128) * ldm +
tb_offset_scale.column();
const float *code_scale_ptr =
quant_args.code_scale_ptr + tb_offset_scale.column();
const float *code_zp_ptr =
quant_args.code_zp_ptr + tb_offset_scale.column();
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
scale_ptr, &args, ldm, need_preload);
need_preload = false;
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
CUTLASS_DEVICE
void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int64_t block_start_row = block_start_rows[stage % kStages];
if (block_start_row >= extent.row()) {
return;
}
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,447 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <cuda_runtime.h>
#include "cutlass/arch/memory.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/wint_type_traits.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <typename T, int N>
using UnzipArray = cutlass::AlignedArray<T, N, (N * cutlass::sizeof_bits<T>::value / 8)>;
template <typename T, WintQuantMethod QuantMethod, int TileRows,
int TileColumns, int NumThreads = 128>
struct UnzipAndDequantFunctor {
__device__ void operator()(const T *in_ptr, const T *super_scale_ptr,
T *out_ptr, const int64_t in_stride) {}
};
template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
TileColumns, NumThreads> {
using ZippedT = uint16_t;
using ScaleComputeT = float;
static constexpr int32_t kGroupSize = 64;
static constexpr int32_t kZippedGroupSize = 10;
static constexpr int32_t kNumPackedValues = 7;
static constexpr int32_t kWeightMask = 0x7;
static constexpr int32_t kLocalScaleMask = 0x1FFF;
static constexpr int32_t kBZP = 4;
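// Packing layout handled below: each 64-row group occupies kZippedGroupSize =
// 10 packed uint16 rows. The first 9 rows each pack kNumPackedValues = 7
// weights; the 10th row packs the group's last weight in its top bits and the
// 13-bit local scale (kLocalScaleMask) in its low bits.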
__device__ inline T Compute(int32_t zipped_value, int32_t shift_bit,
ScaleComputeT scale) {
int32_t shifted_value = (zipped_value >> shift_bit) & kWeightMask;
int32_t value = shifted_value - kBZP;
ScaleComputeT scaled_value = static_cast<ScaleComputeT>(value) * scale;
return static_cast<T>(scaled_value);
}
__device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr,
T *out_ptr, const int64_t in_stride) {
int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0};
int tid = threadIdx.x;
#pragma unroll
for (int col = tid; col < TileColumns; col += NumThreads) {
ScaleComputeT super_scale =
static_cast<ScaleComputeT>(super_scale_ptr[col]);
#pragma unroll
for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
// the last row in group
int zipped_row_last = group_id * 10 + 9;
int zipped_offset_last = zipped_row_last * in_stride + col;
int32_t zipped_value_last =
static_cast<int32_t>(in_ptr[zipped_offset_last]);
ScaleComputeT local_scale =
static_cast<ScaleComputeT>(zipped_value_last & kLocalScaleMask);
ScaleComputeT scale = local_scale * super_scale;
#pragma unroll
for (int zipped_row_in_group = 0; zipped_row_in_group < 9;
++zipped_row_in_group) {
int zipped_row = group_id * 10 + zipped_row_in_group;
int zipped_offset = zipped_row * in_stride + col;
int32_t zipped_value = static_cast<int32_t>(in_ptr[zipped_offset]);
int row_in_group = group_id * 64 + zipped_row_in_group * 7;
#pragma unroll
for (int shift_bit_id = 0; shift_bit_id < 7; ++shift_bit_id) {
int32_t shift_bit = shift_bits[shift_bit_id];
T value = Compute(zipped_value, shift_bit, scale);
out_ptr[(row_in_group + shift_bit_id) * TileColumns + col] = value;
}
}
int row_in_group_last = group_id * 64 + 63;
T value_last = Compute(zipped_value_last, shift_bits[0], scale);
out_ptr[row_in_group_last * TileColumns + col] = value_last;
}
}
__syncthreads();
}
};
template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
TileColumns, NumThreads> {
using ZippedT = uint8_t;
using ScaleComputeT = float;
static constexpr int32_t kGroupSize = 64;
static constexpr int32_t kPackNum = 4;
static constexpr int32_t kWeightMask = 0x3F;
static constexpr int32_t kLocalScaleMask = 0xF;
static constexpr int32_t kBZP = 32;
// weight [16, N] uint8_t
// local_scale [1, N] uint8_t
// code_scale [N] float
// code_zp [N] float
// super_scale [N] T
// code_scale, code_zp and super_scale
static constexpr int32_t kColumnWiseSmemBytes = (2 * sizeof(float) + sizeof(T)) * TileColumns;
// zipped weights and local_scale
static constexpr int32_t kZippedSmemBytes = (TileRows / 4 + (TileRows + 127) / 128) * TileColumns;
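// For illustration (hypothetical tile): with TileRows = 64, TileColumns = 256
// and a 16-bit T,
//   kColumnWiseSmemBytes = (2 * 4 + 2) * 256 = 2560 bytes,
//   kZippedSmemBytes     = (64 / 4 + 1) * 256 = 4352 bytes.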
struct Arguments {
uint8_t *weight_ptr;
uint8_t *local_scale_ptr;
float *code_scale_ptr;
float *code_zp_ptr;
T *super_scale_ptr;
__device__ Arguments() : weight_ptr(nullptr), local_scale_ptr(nullptr), code_scale_ptr(nullptr), code_zp_ptr(nullptr), super_scale_ptr(nullptr) {}
__device__ explicit Arguments(uint8_t *smem_ptr) {
SetZippedPtrs(smem_ptr);
SetColumnWisePtrs(smem_ptr + kZippedSmemBytes);
}
__device__ Arguments(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr) {
SetZippedPtrs(zipped_smem_ptr);
SetColumnWisePtrs(column_wise_smem_ptr);
}
__device__ void SetZippedPtrs(uint8_t *zipped_smem_ptr) {
weight_ptr = zipped_smem_ptr;
local_scale_ptr = zipped_smem_ptr + (TileRows / 4) * TileColumns;
}
__device__ void SetColumnWisePtrs(uint8_t *column_wise_smem_ptr) {
code_scale_ptr = reinterpret_cast<float *>(column_wise_smem_ptr);
code_zp_ptr = reinterpret_cast<float *>(column_wise_smem_ptr + sizeof(float) * TileColumns);
super_scale_ptr = reinterpret_cast<T *>(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns);
}
};
__device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr,
const float *g_code_scale_ptr, const float *g_code_zp_ptr,
const T *g_super_scale_ptr,
Arguments *args, const int64_t in_stride, bool need_preload) {
int tid = threadIdx.x;
#pragma unroll
for (int col = tid; col < TileColumns; col += NumThreads) {
if (need_preload) {
if (g_super_scale_ptr) {
args->super_scale_ptr[col] = g_super_scale_ptr[col];
} else {
args->super_scale_ptr[col] = static_cast<T>(1);
}
args->code_scale_ptr[col] = g_code_scale_ptr[col];
args->code_zp_ptr[col] = g_code_zp_ptr[col];
}
#pragma unroll
for (int ls_row_id = 0; ls_row_id < TileRows / 128; ++ls_row_id) {
int local_scale_offset = ls_row_id * in_stride + col;
args->local_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset];
}
#pragma unroll
for (int zipped_row = 0; zipped_row < TileRows / 4; ++zipped_row) {
int s_zipped_offset = zipped_row * TileColumns + col;
int g_zipped_offset = zipped_row * 4 * in_stride + col;
args->weight_ptr[s_zipped_offset] = g_weight_ptr[g_zipped_offset];
}
}
__syncthreads();
}
__device__ void LoadAsync(const uint8_t *g_weight_ptr,
const uint8_t *g_local_scale_ptr,
const float *g_code_scale_ptr,
const float *g_code_zp_ptr,
const T *g_super_scale_ptr,
Arguments *args, const int64_t in_stride, bool need_preload) {
int tid = threadIdx.x;
constexpr int kBytesPerThread = 16; // 16B per thread
constexpr int weight_size = TileRows / 4 * TileColumns;
constexpr int local_scale_size = (TileRows + 127) / 128 * TileColumns;
constexpr int code_scale_size = sizeof(float) * TileColumns;
constexpr int code_zp_size = sizeof(float) * TileColumns;
constexpr int super_scale_size = sizeof(T) * TileColumns;
constexpr int total_size = weight_size + local_scale_size + code_scale_size + code_zp_size + super_scale_size;
constexpr int total_tasks = total_size / kBytesPerThread;
constexpr int cur_num_threads = total_tasks / ((total_tasks + NumThreads - 1) / NumThreads);
constexpr int weight_threads = weight_size * cur_num_threads / total_size;
constexpr int local_scale_threads = local_scale_size * cur_num_threads / total_size;
constexpr int code_scale_threads = code_scale_size * cur_num_threads / total_size;
constexpr int code_zp_threads = code_zp_size * cur_num_threads / total_size;
constexpr int super_scale_threads = super_scale_size * cur_num_threads / total_size;
static_assert(TileColumns % weight_threads == 0,
"TileColumns must be divisible by weight_threads to ensure correct thread mapping.");
static_assert(TileColumns % local_scale_threads == 0,
"TileColumns must be divisible by local_scale_threads to ensure correct thread mapping.");
if (tid < weight_threads) {
constexpr int weight_per_thread_size = weight_size / weight_threads;
constexpr int kIterations = (weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread);
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->weight_ptr + z_offset, g_weight_ptr + g_offset, true);
}
} else if (tid < weight_threads + local_scale_threads) {
constexpr int start_thread_id = weight_threads;
constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads;
constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread;
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true);
}
} else if (need_preload) {
if (tid < weight_threads + local_scale_threads + code_scale_threads) {
constexpr int start_thread_id = weight_threads + local_scale_threads;
constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads;
constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->code_scale_ptr + offset, g_code_scale_ptr + offset, true);
}
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) {
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads;
constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads;
constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->code_zp_ptr + offset, g_code_zp_ptr + offset, true);
}
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) {
if (g_super_scale_ptr) {
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads;
constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads;
constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->super_scale_ptr + offset, g_super_scale_ptr + offset, true);
}
}
}
}
}
__device__ void Compute(const Arguments &args, T *out_ptr,
const int64_t block_start_row) {
int32_t shift_bits[4] = {9, 6, 3, 0};
int tid = threadIdx.x;
#pragma unroll
for (int col = tid; col < TileColumns; col += NumThreads) {
ScaleComputeT super_scale =
static_cast<ScaleComputeT>(args.super_scale_ptr[col]);
ScaleComputeT code_scale =
static_cast<ScaleComputeT>(args.code_scale_ptr[col]);
ScaleComputeT code_zp = static_cast<ScaleComputeT>(args.code_zp_ptr[col]);
#pragma unroll
for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
int local_scale_offset = (group_id / 2) * TileColumns + col;
int32_t local_scale =
static_cast<int32_t>(args.local_scale_ptr[local_scale_offset]);
ScaleComputeT zipped_value[16];
#pragma unroll
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
int zipped_offset = (group_id * 16 + zipped_row) * TileColumns + col;
zipped_value[zipped_row] =
static_cast<ScaleComputeT>(args.weight_ptr[zipped_offset]);
}
int local_scale_shift = ((block_start_row / 64 + group_id + 1) & 1) * 4;
int32_t shifted_local_scale =
(local_scale >> local_scale_shift) & kLocalScaleMask;
ScaleComputeT scale =
static_cast<ScaleComputeT>(shifted_local_scale) * super_scale;
#pragma unroll
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
int32_t decode_value =
static_cast<int32_t>(floor(zipped_value[zipped_row] * code_scale + code_zp +
static_cast<ScaleComputeT>(0.5)));
int row = group_id * 64 + zipped_row * 4;
#pragma unroll
for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
int32_t shift_bit = shift_bits[shift_bit_id];
int32_t shifted_value = (decode_value >> shift_bit) & kWeightMask;
ScaleComputeT value =
static_cast<ScaleComputeT>(shifted_value - kBZP);
out_ptr[(row + shift_bit_id) * TileColumns + col] =
static_cast<T>(scale * value);
}
}
}
}
__syncthreads();
}
__device__ void ComputeVectorized(const Arguments &args, T *out_ptr,
const int64_t block_start_row) {
constexpr int kNumWeightsPerThread = TileRows * TileColumns / (4 * NumThreads);
constexpr int N = (kNumWeightsPerThread >= 32) ? 4 : 2;
constexpr int RowStride = NumThreads * N / TileColumns;
constexpr int kNumIters = kNumWeightsPerThread / N;
static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns.");
constexpr ScaleComputeT decode_value_zp = static_cast<ScaleComputeT>(0.5);
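// For illustration (hypothetical configuration): with TileRows = 64,
// TileColumns = 256 and NumThreads = 128, kNumWeightsPerThread = 32, so N = 4,
// RowStride = 2 and kNumIters = 8; each thread expands 32 packed bytes into
// 128 dequantized weights.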
int tid = threadIdx.x;
int begin_col_id = (tid * N) % TileColumns;
int begin_row_id = (tid * N) / TileColumns;
static_assert(TileRows <= 128, "TileRows is expected to no more than 128.");
UnzipArray<uint8_t, N> local_scales =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.local_scale_ptr + begin_col_id);
UnzipArray<uint8_t, N> zipped_values[2];
int zipped_offset = begin_row_id * TileColumns + begin_col_id;
zipped_values[0] =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
UnzipArray<T, N> super_scales =
*reinterpret_cast<const UnzipArray<T, N> *>(args.super_scale_ptr + begin_col_id);
UnzipArray<float, N> code_scales =
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_scale_ptr + begin_col_id);
UnzipArray<float, N> code_zps =
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_zp_ptr + begin_col_id);
// Special case for TileRows = 64: one 64-row group per tile, so a single local_scale shift suffices.
int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4;
UnzipArray<ScaleComputeT, N> scales;
#pragma unroll
for (int i = 0; i < N; ++i) {
int32_t shifted_local_scale =
(static_cast<int32_t>(local_scales[i]) >> local_scale_shift) & kLocalScaleMask;
scales[i] =
static_cast<ScaleComputeT>(shifted_local_scale) * static_cast<ScaleComputeT>(super_scales[i]);
}
#pragma unroll
for (int iter_id = 0; iter_id < kNumIters; ++iter_id) {
int zipped_row = begin_row_id + iter_id * RowStride;
int row = zipped_row * 4;
if (iter_id < kNumIters - 1) {
int zipped_offset = (zipped_row + RowStride) * TileColumns + begin_col_id;
zipped_values[(iter_id + 1) & 1] =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
}
UnzipArray<T, N> outs[4];
#pragma unroll
for (int i = 0; i < N; ++i) {
int32_t decode_value =
static_cast<int32_t>(floor(static_cast<ScaleComputeT>(zipped_values[iter_id & 1][i]) * code_scales[i]
+ code_zps[i] + decode_value_zp));
ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
outs[0][i] = static_cast<T>(scales[i] * value_0);
outs[1][i] = static_cast<T>(scales[i] * value_1);
outs[2][i] = static_cast<T>(scales[i] * value_2);
outs[3][i] = static_cast<T>(scales[i] * value_3);
}
#pragma unroll
for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
UnzipArray<T, N> *tmp_out_ptr = reinterpret_cast<UnzipArray<T, N> *>(
out_ptr + (row + shift_bit_id) * TileColumns + begin_col_id);
*tmp_out_ptr = outs[shift_bit_id];
}
}
__syncthreads();
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,140 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
namespace cutlass {
enum WintQuantMethod {
kNone = 0,
kWeightOnlyInt8 = 1,
kWeightOnlyInt4 = 2,
kWeightOnlyInt25 = 3,
kWeightOnlyInt2 = 4
};
// Convert CUDA data type to cutlass data type
template <typename T> struct CutlassDataType {
using Type = T;
};
template <> struct CutlassDataType<half> {
using Type = cutlass::half_t;
};
template <> struct CutlassDataType<__nv_bfloat16> {
using Type = cutlass::bfloat16_t;
};
template <typename ElementT, WintQuantMethod Method> struct WintQuantTraits;
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kNone> {
using WeightType = ElementT;
using MmaKernelType = typename CutlassDataType<ElementT>::Type;
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
static constexpr WintQuantMethod kQuantMethod = WintQuantMethod::kNone;
struct Arguments {};
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) { return dim; }
};
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt8> {
using WeightType = uint8_t;
using MmaKernelType = uint8_t;
using MmaWeightType = uint8_t;
static constexpr WintQuantMethod kQuantMethod =
WintQuantMethod::kWeightOnlyInt8;
struct Arguments {};
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) { return dim; }
};
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt4> {
using WeightType = cutlass::uint4b_t;
using MmaKernelType = cutlass::uint4b_t;
using MmaWeightType = cutlass::uint4b_t;
static constexpr WintQuantMethod kQuantMethod =
WintQuantMethod::kWeightOnlyInt4;
struct Arguments {};
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) { return dim; }
};
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt25> {
using WeightType = uint16_t;
using MmaKernelType = typename CutlassDataType<ElementT>::Type;
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
static constexpr WintQuantMethod kQuantMethod =
WintQuantMethod::kWeightOnlyInt25;
static constexpr int32_t kGroupSize = 64;
static constexpr int32_t kNumPackedValues = 7;
static constexpr int32_t kPackedSize = 10;
struct Arguments {};
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) {
return dim * kPackedSize / kGroupSize;
}
};
template <typename ElementT>
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt2> {
using WeightType = uint8_t;
using MmaKernelType = cutlass::uint2b_t;
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
static constexpr WintQuantMethod kQuantMethod =
WintQuantMethod::kWeightOnlyInt2;
static constexpr int32_t kGroupSize = 64;
static constexpr int32_t kNumPackedValues = 4;
static constexpr int32_t kPackedSize = 16;
struct Arguments {
const uint8_t *local_scale_ptr; // quantized 4-bit values
const float *code_scale_ptr;
const float *code_zp_ptr;
};
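// Each group of kGroupSize = 64 2-bit weights is stored as kPackedSize = 16
// bytes (4 weights per byte), so CaclPackedDim maps a K extent of, e.g., 4096
// rows to 4096 * 16 / 64 = 1024 packed rows.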
CUTLASS_DEVICE
static int64_t CaclPackedDim(int64_t dim) {
return dim * kPackedSize / kGroupSize;
}
};
} // namespace cutlass