Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-30 11:26:39 +08:00)

Commit: Sync v2.0 version of code to github repo

New file (250 lines): custom_ops/gpu_ops/cutlass_extensions/arch/memory_copy_sm80.h
@@ -0,0 +1,250 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Architecture-specific operators on memory added for SM80
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/arch/cache_operation.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace arch {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Initiates an asynchronous copy from global memory to shared memory.
|
||||
///
|
||||
/// cp.async
|
||||
///
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes,
|
||||
/// Cache operation
|
||||
CacheOperation::Kind cache_op = CacheOperation::Always,
|
||||
bool GlobalToShared = true>
|
||||
struct copy;
|
||||
|
||||
/// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate
|
||||
/// the entire transfer, zeros are written to SMEM if the guard predicate is false.
|
||||
///
|
||||
/// cp.async
|
||||
///
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes,
|
||||
/// Cache operation
|
||||
CacheOperation::Kind cache_op = CacheOperation::Always,
|
||||
bool GlobalToShared = true>
|
||||
struct copy_zfill;
|
||||
|
||||
/// Blocks until all but <N> previous cp.async.commit_group operations have committed.
|
||||
///
|
||||
/// cp.async
|
||||
///
|
||||
template <int N, bool GlobalToShared = true>
|
||||
struct copy_wait;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Always, true> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
cp_async<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Always, false> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Always, true> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
|
||||
cp_async_zfill<SizeInBytes, CacheOperation::Always>(smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Always, false> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Global, true> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
cp_async<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy<SizeInBytes, CacheOperation::Global, false> {
|
||||
|
||||
/// Copy
|
||||
CUTLASS_DEVICE
|
||||
copy(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Global, true> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
cp_async_zfill<SizeInBytes, CacheOperation::Global>(smem_ptr, global_ptr, pred_guard);
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <
|
||||
/// Size of the access in bytes
|
||||
int SizeInBytes>
|
||||
struct copy_zfill<SizeInBytes, CacheOperation::Global, false> {
|
||||
|
||||
/// Copy with zero fill
|
||||
CUTLASS_DEVICE
|
||||
copy_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
using AccessType = Array<uint8_t, SizeInBytes>;
|
||||
|
||||
if (pred_guard) {
|
||||
*static_cast<AccessType *>(smem_ptr) = *static_cast<AccessType const *>(global_ptr);
|
||||
}
|
||||
else {
|
||||
AccessType zeros;
|
||||
zeros.clear();
|
||||
*static_cast<AccessType *>(smem_ptr) = zeros;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
|
||||
template <bool GlobalToShared>
|
||||
CUTLASS_DEVICE
|
||||
void copy_fence() {}
|
||||
|
||||
template <>
|
||||
CUTLASS_DEVICE
|
||||
void copy_fence<true>() {
|
||||
cp_async_fence();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization
|
||||
template <int N>
|
||||
struct copy_wait<N, false> {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
copy_wait() {}
|
||||
};
|
||||
|
||||
/// Partial specialization
|
||||
template <int N>
|
||||
struct copy_wait<N, true> {
|
||||
|
||||
CUTLASS_DEVICE
|
||||
copy_wait() { cp_async_wait<N>(); }
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace arch
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
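The header above exposes cp.async behind a compile-time GlobalToShared switch: copy and copy_zfill issue cp.async (or fall back to a plain guarded copy when the switch is false), copy_fence commits the current group, and copy_wait blocks on outstanding groups. A minimal usage sketch follows; the function name load_tile_async, the include path, and the 16-byte access size are illustrative assumptions rather than part of this file.

#include "cutlass_extensions/arch/memory_copy_sm80.h"  // assumed include path for the header above

// Minimal sketch: each thread moves one 16-byte chunk from global to shared
// memory, optionally through cp.async, then waits for the transfer to land.
template <bool kGlobalToShared>
__device__ void load_tile_async(void* smem_ptr, void const* gmem_ptr, bool valid) {
  constexpr int kAccessBytes = 16;  // one 128-bit access per thread (example value)

  // Issues cp.async when kGlobalToShared is true; otherwise falls back to the
  // plain predicated Array<uint8_t, N> copy specialization defined above.
  cutlass::arch::copy<kAccessBytes, cutlass::arch::CacheOperation::Always,
                      kGlobalToShared>(smem_ptr, gmem_ptr, valid);

  cutlass::arch::copy_fence<kGlobalToShared>();    // cp.async.commit_group (no-op when false)
  cutlass::arch::copy_wait<0, kGlobalToShared>();  // block until all committed groups complete (no-op when false)
  __syncthreads();
}

copy_zfill is used the same way when out-of-bounds lanes should write zeros into shared memory instead of skipping the store.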
@@ -0,0 +1,460 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//

// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
// Turn off clang-format for the entire file to keep it close to upstream
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||
|
||||
namespace cutlass::epilogue::fusion {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
// Row vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_0,_1,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90RowOrScalarBroadcastArray {
|
||||
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
|
||||
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
|
||||
|
||||
struct SharedStorage {
|
||||
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
|
||||
};
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_row is null.
|
||||
struct Arguments {
|
||||
const Element* const* ptr_row_array = nullptr;
|
||||
bool row_broadcast = true;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcastArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params)
|
||||
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
|
||||
|
||||
Params params;
|
||||
Element *smem = nullptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
|
||||
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
|
||||
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
|
||||
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
|
||||
int group, Params const& params_)
|
||||
: tGS_gRow(tGS_gRow_)
|
||||
, tGS_sRow(tGS_sRow_)
|
||||
, tGS_cRow(tGS_cRow_)
|
||||
, tiled_G2S(tiled_g2s_)
|
||||
, tSR_sRow(tSR_sRow_)
|
||||
, tSR_rRow(tSR_rRow_)
|
||||
, tCcRow(tCcRow_)
|
||||
, residue_tCcRow(residue_tCcRow_)
|
||||
, group(group)
|
||||
, params(params_) {}
|
||||
|
||||
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
|
||||
Tiled_G2S tiled_G2S;
|
||||
|
||||
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
ThrResidue residue_tCcRow; // (m, n)
|
||||
ThrNum thr_num;
|
||||
int group;
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
if (!params.row_broadcast) {
|
||||
fill(tSR_rRow, *(params.ptr_row_array[group]));
|
||||
return;
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
continue; // OOB of SMEM
|
||||
}
|
||||
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
|
||||
tGS_sRow_flt(i) = tGS_gRow_flt(i);
|
||||
}
|
||||
else {
|
||||
tGS_sRow_flt(i) = Element(0); // Set to zero when OOB so the LDS can be issued without any predicates.
|
||||
}
|
||||
}
|
||||
synchronize();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_row;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_row;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
Layout< Shape<_1, ThreadCount>,
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
tGS_sRow,
|
||||
tGS_cRow, tiled_g2s,
|
||||
tSR_sRow,
|
||||
tSR_rRow,
|
||||
args.tCcD,
|
||||
args.residue_cD,
|
||||
ThreadCount{},
|
||||
l,
|
||||
params);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90ColOrScalarBroadcastArray {
|
||||
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
|
||||
|
||||
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
|
||||
struct SharedStorage { };
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_col is null.
|
||||
struct Arguments {
|
||||
const Element* const* ptr_col_array = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcastArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params) { }
|
||||
|
||||
Params params;
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GTensor&& tCgCol,
|
||||
RTensor&& tCrCol,
|
||||
CTensor&& tCcCol,
|
||||
ProblemShape problem_shape,
|
||||
int group,
|
||||
Params const& params
|
||||
):
|
||||
tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||
tCcCol(cute::forward<CTensor>(tCcCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
group(group),
|
||||
params(params) {}
|
||||
|
||||
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
RTensor tCrCol;
|
||||
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Params const& params;
|
||||
int m;
|
||||
int group;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
}
|
||||
|
||||
if (!params.col_broadcast) {
|
||||
fill(tCrCol, *(params.ptr_col_array[group]));
|
||||
return;
|
||||
}
|
||||
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
copy_if(pred, filter(tCgCol), filter(tCrCol));
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_col;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
cute::move(tCgCol),
|
||||
cute::move(tCrCol),
|
||||
cute::move(tCcCol),
|
||||
args.problem_shape_mnkl,
|
||||
l,
|
||||
params
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
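As the header comment above explains, one compiled kernel serves per-tensor and per-channel/per-token quantization because the scale is always read through a device pointer and a boolean selects vector versus scalar broadcast. The host-side sketch below shows how the row-broadcast Arguments for the grouped (array) variant might be filled; the CTA tile shape, the scale_ptrs array, and all variable names are assumptions for illustration, not code from this commit.

// Illustrative only: populating Sm90RowOrScalarBroadcastArray::Arguments.
using CtaTile = cute::Shape<cute::_128, cute::_128, cute::_64>;  // example CTA tile; the real one comes from the kernel config
using RowBroadcast =
    cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<0 /*Stages*/, CtaTile, float>;

// Hypothetical device-resident array holding one scale pointer per GEMM group.
float const* const* scale_ptrs = nullptr;  // set to the real device array in practice

// Per-channel scales: each scale_ptrs[g] points at a length-N vector on device.
RowBroadcast::Arguments per_channel{scale_ptrs, /*row_broadcast=*/true, /*dRow=*/{}};

// Per-tensor scales: each scale_ptrs[g] points at a single device scalar that is broadcast.
RowBroadcast::Arguments per_tensor{scale_ptrs, /*row_broadcast=*/false, /*dRow=*/{}};

Because both cases read through a device pointer, the scales never have to be moved to the host, which is what avoids the torch.compile graph breaks mentioned in the comment above.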
@@ -0,0 +1,500 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//

// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
// Turn off clang-format for the entire file to keep it close to upstream
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::epilogue::threadblock {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL
|
||||
>
|
||||
struct VisitorRowOrScalarBroadcast {
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
// scalar that must be broadcast.
|
||||
struct Arguments {
|
||||
Element const* ptr_row = nullptr;
|
||||
bool row_broadcast = true;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage {};
|
||||
|
||||
// Global load type
|
||||
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gRow,
|
||||
RTensor&& tC_rRow,
|
||||
CTensor&& tC_cRow,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gRow(cute::forward<GTensor>(tC_gRow)),
|
||||
tC_rRow(cute::forward<RTensor>(tC_rRow)),
|
||||
tC_cRow(cute::forward<CTensor>(tC_cRow)),
|
||||
n(get<1>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gRow;
|
||||
RTensor tC_rRow;
|
||||
CTensor tC_cRow;
|
||||
Params const* params_ptr;
|
||||
int n;
|
||||
|
||||
// This function is modified from VisitorRowBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rRow);
|
||||
auto src_v = filter(tC_gRow);
|
||||
auto coord_v = filter(tC_cRow);
|
||||
auto dst_v = filter(tC_rRow);
|
||||
|
||||
if (params_ptr->row_broadcast) {
|
||||
// In this case we are loading from a row vector and broadcasting
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
bool guard = get<1>(coord_v(i)) < n;
|
||||
cutlass::arch::global_load<VecType, sizeof(VecType)>(
|
||||
dst_v(i), (void const*)&src_v(i), guard);
|
||||
}
|
||||
} else {
|
||||
// In this case we are loading from a scalar and broadcasting
|
||||
VecType filled_vec;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < VecLength; i++) {
|
||||
reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
if (get<1>(coord_v(i)) < n) {
|
||||
dst_v(i) = filled_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
|
||||
return rRow_frg(column_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mRow = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_row),
|
||||
problem_shape,
|
||||
params_ptr->dRow);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN
|
||||
Tensor tC_gRow = recast<VecType>(
|
||||
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
|
||||
)(_,_,_0{},_0{},_0{},_0{});
|
||||
Tensor tC_rRow = make_tensor_like(tC_gRow);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cRow = make_identity_tensor(mRow.shape());
|
||||
Tensor tC_cRow = outer_partition(
|
||||
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
|
||||
Shape<Int<VecLength>>{},
|
||||
(_0{})
|
||||
);
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gRow), decltype(tC_rRow),
|
||||
decltype(tC_cRow), ProblemShape>(
|
||||
cute::move(tC_gRow),
|
||||
cute::move(tC_rRow),
|
||||
cute::move(tC_cRow),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL
|
||||
>
|
||||
struct VisitorRowOrZeroBroadcast {
|
||||
|
||||
// This struct has been modified to remove null_default (because it's always 0)
|
||||
struct Arguments {
|
||||
Element const* ptr_row = nullptr;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage {};
|
||||
|
||||
// Global load type
|
||||
static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
|
||||
using VecType = uint_bit_t<cute::min(128, vec_bits)>;
|
||||
static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrZeroBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gRow,
|
||||
RTensor&& tC_rRow,
|
||||
CTensor&& tC_cRow,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gRow(cute::forward<GTensor>(tC_gRow)),
|
||||
tC_rRow(cute::forward<RTensor>(tC_rRow)),
|
||||
tC_cRow(cute::forward<CTensor>(tC_cRow)),
|
||||
n(get<1>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gRow;
|
||||
RTensor tC_rRow;
|
||||
CTensor tC_cRow;
|
||||
Params const* params_ptr;
|
||||
int n;
|
||||
|
||||
// This function is modified from VisitorRowBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rRow);
|
||||
auto src_v = filter(tC_gRow);
|
||||
auto coord_v = filter(tC_cRow);
|
||||
auto dst_v = filter(tC_rRow);
|
||||
|
||||
if (params_ptr->ptr_row != nullptr) {
|
||||
// In this case we are loading from a row vector and broadcasting
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
bool guard = get<1>(coord_v(i)) < n;
|
||||
cutlass::arch::global_load<VecType, sizeof(VecType)>(
|
||||
dst_v(i), (void const*)&src_v(i), guard);
|
||||
}
|
||||
} else {
|
||||
// In this case we are broadcasting 0
|
||||
VecType filled_vec;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < VecLength; i++) {
|
||||
reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(src_v); ++i) {
|
||||
if (get<1>(coord_v(i)) < n) {
|
||||
dst_v(i) = filled_vec;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
|
||||
return rRow_frg(column_idx);
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mRow = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_row),
|
||||
problem_shape,
|
||||
params_ptr->dRow);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN
|
||||
Tensor tC_gRow = recast<VecType>(
|
||||
ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
|
||||
)(_,_,_0{},_0{},_0{},_0{});
|
||||
Tensor tC_rRow = make_tensor_like(tC_gRow);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cRow = make_identity_tensor(mRow.shape());
|
||||
Tensor tC_cRow = outer_partition(
|
||||
ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
|
||||
Shape<Int<VecLength>>{},
|
||||
(_0{})
|
||||
);
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gRow), decltype(tC_rRow),
|
||||
decltype(tC_cRow), ProblemShape>(
|
||||
cute::move(tC_gRow),
|
||||
cute::move(tC_rRow),
|
||||
cute::move(tC_cRow),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
class ThreadMap,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>
|
||||
>
|
||||
struct VisitorColOrScalarBroadcast {
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast.
|
||||
struct Arguments {
|
||||
Element const* ptr_col = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct SharedStorage { };
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorColOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params_ptr(¶ms) { }
|
||||
|
||||
Params const* params_ptr;
|
||||
|
||||
template <class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct Callbacks : EmptyCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
Callbacks(
|
||||
GTensor&& tC_gCol,
|
||||
RTensor&& tC_rCol,
|
||||
CTensor&& tC_cCol,
|
||||
ProblemShape problem_shape,
|
||||
Params const* params_ptr
|
||||
):
|
||||
tC_gCol(cute::forward<GTensor>(tC_gCol)),
|
||||
tC_rCol(cute::forward<RTensor>(tC_rCol)),
|
||||
tC_cCol(cute::forward<CTensor>(tC_cCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
params_ptr(params_ptr) { }
|
||||
|
||||
GTensor tC_gCol;
|
||||
RTensor tC_rCol;
|
||||
CTensor tC_cCol;
|
||||
Params const* params_ptr;
|
||||
int m;
|
||||
|
||||
// This function is modified from VisitorColBroadcast
|
||||
CUTLASS_DEVICE void
|
||||
begin_epilogue() {
|
||||
clear(tC_rCol);
|
||||
|
||||
Tensor pred = make_tensor<bool>(shape(tC_gCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tC_cCol(i)) < m;
|
||||
}
|
||||
|
||||
if (params_ptr->col_broadcast) {
|
||||
// In this case we are loading from a column vector and broadcasting
|
||||
copy_if(pred, tC_gCol, tC_rCol);
|
||||
} else {
|
||||
// In this case we are loading from a scalar and broadcasting
|
||||
auto dst_v = filter(tC_rCol);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(dst_v); ++i) {
|
||||
if (pred(i)) {
|
||||
dst_v(i) = *(params_ptr->ptr_col);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE auto // returns an Array
|
||||
visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
|
||||
Array<ElementAccumulator, FragmentSize> const& frg_acc) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
frg_col.fill(tC_rCol(row_idx,iter_idx));
|
||||
return frg_col;
|
||||
}
|
||||
};
|
||||
|
||||
template <class ProblemShape>
|
||||
CUTLASS_DEVICE auto
|
||||
get_callbacks(
|
||||
gemm::GemmCoord threadblock_tile_offset,
|
||||
int thread_idx,
|
||||
ProblemShape problem_shape
|
||||
) {
|
||||
Tensor mCol = make_tensor(
|
||||
make_gmem_ptr(params_ptr->ptr_col),
|
||||
problem_shape,
|
||||
params_ptr->dCol);
|
||||
|
||||
// VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
|
||||
Tensor tC_gCol = group_modes<1,4>(
|
||||
ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
|
||||
Tensor tC_rCol = make_tensor_like(tC_gCol);
|
||||
|
||||
// Generate the pred tensor
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tC_cCol = group_modes<1,4>(
|
||||
ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
|
||||
|
||||
return Callbacks<
|
||||
decltype(tC_gCol), decltype(tC_rCol),
|
||||
decltype(tC_cCol), ProblemShape>(
|
||||
cute::move(tC_gCol),
|
||||
cute::move(tC_rCol),
|
||||
cute::move(tC_cCol),
|
||||
problem_shape,
|
||||
params_ptr
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
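The SM80-era (2.x epilogue visitor) classes above follow the same convention: the scale is always read through a device pointer, and row_broadcast/col_broadcast chooses between a vector load and a scalar broadcast (with VisitorRowOrZeroBroadcast additionally broadcasting zero when ptr_row is null). A sketch of the Arguments for the column visitor follows; OutputTileThreadMap, d_scales, and d_scalar are placeholders for names supplied by the surrounding kernel code, not identifiers from this commit.

// Illustrative only: Arguments for the column (per-token) visitor defined above.
// OutputTileThreadMap is assumed to be the epilogue's output-tile thread map type.
using ColBroadcast = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
    OutputTileThreadMap, float, cute::Stride<cute::_1, cute::_0, cute::_0>>;

float const* d_scales = nullptr;  // hypothetical device vector: one scale per output row (per-token)
float const* d_scalar = nullptr;  // hypothetical single device scalar (per-tensor)

ColBroadcast::Arguments per_token {d_scales, /*col_broadcast=*/true,  /*dCol=*/{}};
ColBroadcast::Arguments per_tensor{d_scalar, /*col_broadcast=*/false, /*dCol=*/{}};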
@@ -0,0 +1,450 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//

// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
// Turn off clang-format for the entire file to keep it close to upstream
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||
|
||||
namespace cutlass::epilogue::fusion {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
// Row vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_0,_1,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90RowOrScalarBroadcast {
|
||||
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
|
||||
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
|
||||
|
||||
struct SharedStorage {
|
||||
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
|
||||
};
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_row is null.
|
||||
struct Arguments {
|
||||
Element const* ptr_row = nullptr;
|
||||
bool row_broadcast = true;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params)
|
||||
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
|
||||
|
||||
Params params;
|
||||
Element *smem = nullptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
|
||||
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
|
||||
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
|
||||
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
|
||||
: tGS_gRow(tGS_gRow_)
|
||||
, tGS_sRow(tGS_sRow_)
|
||||
, tGS_cRow(tGS_cRow_)
|
||||
, tiled_G2S(tiled_g2s_)
|
||||
, tSR_sRow(tSR_sRow_)
|
||||
, tSR_rRow(tSR_rRow_)
|
||||
, tCcRow(tCcRow_)
|
||||
, residue_tCcRow(residue_tCcRow_)
|
||||
, params(params_) {}
|
||||
|
||||
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
|
||||
Tiled_G2S tiled_G2S;
|
||||
|
||||
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
ThrResidue residue_tCcRow; // (m, n)
|
||||
ThrNum thr_num;
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
if (!params.row_broadcast) {
|
||||
fill(tSR_rRow, *(params.ptr_row));
|
||||
return;
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
continue; // OOB of SMEM
|
||||
}
|
||||
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
|
||||
tGS_sRow_flt(i) = tGS_gRow_flt(i);
|
||||
}
|
||||
else {
|
||||
tGS_sRow_flt(i) = Element(0); // Set to zero when OOB so the LDS can be issued without any predicates.
|
||||
}
|
||||
}
|
||||
synchronize();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_row;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_row;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
Layout< Shape<_1, ThreadCount>,
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
tGS_sRow,
|
||||
tGS_cRow, tiled_g2s,
|
||||
tSR_sRow,
|
||||
tSR_rRow,
|
||||
args.tCcD,
|
||||
args.residue_cD,
|
||||
ThreadCount{},
|
||||
params);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90ColOrScalarBroadcast {
|
||||
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
|
||||
|
||||
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
|
||||
struct SharedStorage { };
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_col is null.
|
||||
struct Arguments {
|
||||
Element const* ptr_col = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.col_broadcast && *(params.ptr_col) == Element(0));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcast() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params) { }
|
||||
|
||||
Params params;
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GTensor&& tCgCol,
|
||||
RTensor&& tCrCol,
|
||||
CTensor&& tCcCol,
|
||||
ProblemShape problem_shape,
|
||||
Params const& params
|
||||
):
|
||||
tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||
tCcCol(cute::forward<CTensor>(tCcCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
params(params) {}
|
||||
|
||||
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
RTensor tCrCol;
|
||||
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Params const& params;
|
||||
int m;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
}
|
||||
|
||||
if (!params.col_broadcast) {
|
||||
fill(tCrCol, *(params.ptr_col));
|
||||
return;
|
||||
}
|
||||
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
copy_if(pred, filter(tCgCol), filter(tCrCol));
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_col;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
cute::move(tCgCol),
|
||||
cute::move(tCrCol),
|
||||
cute::move(tCcCol),
|
||||
args.problem_shape_mnkl,
|
||||
params
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
@@ -0,0 +1,327 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
bias, and activation zero-points onto a GEMM operation using the
|
||||
CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
|
||||
|
||||
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace fastdeploy::c2x {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrZeroLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(paddle::Tensor const &tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto *data_ptr = static_cast<T *>(const_cast<void *>(
|
||||
tensor.data()));
|
||||
if constexpr (std::is_same_v<Descriptor,
|
||||
ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor,
|
||||
RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
}
|
||||
else {
|
||||
// it would technically work but no use case as data_ptr is never nullptr
|
||||
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(paddle::optional<paddle::Tensor> const &tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto *data_ptr =
|
||||
tensor ? static_cast<T *>(const_cast<void *>(tensor->data())) : nullptr;
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
paddle._scaled_mm.
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||
per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||
* This bias can also be used in the per-tensor azp case, where the activation
|
||||
* zero point (azp) is used to compute an azp correction term,
|
||||
* which is folded into the bias.
|
||||
*
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBias
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
protected:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
* term, which should already be multiplied with the scalar azp.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType
|
||||
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{
|
||||
b_args, evt_azp_args, {}};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue supports per-token azp by computing and applying
|
||||
* the correction term using a rank-1 update. If the term were materialized,
|
||||
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||
* point for each row of A.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
// Per-token azp term, shape (m,1)
|
||||
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||
|
||||
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType
|
||||
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{
|
||||
b_args, evt_acc_args, {}};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace fastdeploy::c2x
|
||||
@@ -0,0 +1,453 @@
|
||||
// adapted from: https://github.com/vllm-project/vllm/blob/118ff921118cc81061a2af865a1e13840ceb6792/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||
// clang-format on
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
bias, and activation zero-points onto a GEMM operation using the
|
||||
CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
|
||||
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace fastdeploy::c3x {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T> struct identity {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T lhs) const { return lhs; }
|
||||
};
|
||||
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct TrivialEpilogue {
|
||||
private:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
using Compute = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::epilogue::thread::Identity, ElementD, ElementAcc,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::fusion::Sm90EVT<Compute, Accum>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
template <typename... Args> static ArgumentType prepare_args(Args... args) {
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoadArray =
|
||||
cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoadArray =
|
||||
cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(paddle::Tensor const &tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto *data_ptr = static_cast<T *>(const_cast<void *>(tensor.data()));
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
} else {
|
||||
static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
|
||||
!std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(paddle::optional<paddle::Tensor> const &tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto *data_ptr =
|
||||
tensor ? static_cast<T *>(const_cast<void *>(tensor->data())) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(const T *const *data_ptr, bool do_broadcast) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
|
||||
return Arguments{data_ptr, do_broadcast};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
paddle.scaled_mm_.
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be
|
||||
quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||
* This bias can also be used in the per-tensor azp case, where the activation
|
||||
* zero point (azp) is used to compute an azp correction term,
|
||||
* which is folded into the bias.
|
||||
*
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogueBias, but the
|
||||
* bias is a column vector instead of a row vector. Useful e.g. if we are
|
||||
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueColumnBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template ColLoad<ElementD>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(paddle::Tensor const &a_scales,
|
||||
paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
* term, which should already be multiplied with the scalar azp.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType
|
||||
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{
|
||||
b_args, evt_azp_args, {}};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue supports per-token azp by computing and applying
|
||||
* the correction term using a rank-1 update. If the term were materialized,
|
||||
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||
* point for each row of A.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD, true>;
|
||||
|
||||
// Per-token azp term, shape (m,1)
|
||||
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||
|
||||
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType
|
||||
prepare_args(paddle::Tensor const &a_scales, paddle::Tensor const &b_scales,
|
||||
paddle::Tensor const &azp_adj, paddle::Tensor const &azp,
|
||||
paddle::optional<paddle::Tensor> const &bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{
|
||||
b_args, evt_acc_args, {}};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
|
||||
to arrays containing different scales used in group gemm. The number of
|
||||
pointers in ScaleA and the number of pointers in ScaleB are equal to the
|
||||
group size.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueArray
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
|
||||
using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;
|
||||
|
||||
static ArgumentType prepare_args(float const *const *a_scales_ptr,
|
||||
float const *const *b_scales_ptr,
|
||||
bool a_col_broadcast, bool b_row_broadcast) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
|
||||
a_scales_ptr, a_col_broadcast);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
|
||||
b_scales_ptr, b_row_broadcast);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace fastdeploy::c3x
|
||||
@@ -0,0 +1,284 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
||||
|
||||
// SM90 Collective Builders should be used only starting CUDA 12.0
|
||||
#if (__CUDACC_VER_MAJOR__ >= 12)
|
||||
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Returns the maximum number of smem tiles that can be used with a given smem
|
||||
// capacity, or overrides with manual count.
|
||||
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK,
|
||||
bool SwapAB, int carveout_bytes>
|
||||
constexpr int compute_stage_count_or_override_gated(
|
||||
StageCountAutoCarveout<carveout_bytes> stage_count) {
|
||||
// 32 bytes to account for barriers etc.
|
||||
constexpr int stage_barrier_bytes = 32;
|
||||
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
|
||||
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
|
||||
constexpr int stage_bytes = [&]() -> int {
|
||||
if constexpr (SwapAB) {
|
||||
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) /
|
||||
8 +
|
||||
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
|
||||
stage_barrier_bytes;
|
||||
} else {
|
||||
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 +
|
||||
(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) /
|
||||
8 +
|
||||
stage_barrier_bytes;
|
||||
}
|
||||
}();
|
||||
|
||||
return (CapacityBytes - carveout_bytes) / stage_bytes;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_SS
|
||||
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
|
||||
class GmemLayoutB, int AlignmentB, class ElementAccumulator,
|
||||
class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType,
|
||||
template <class /* ElementCompute */> class Activation, bool SwapAB>
|
||||
struct CollectiveBuilderGated<
|
||||
arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA,
|
||||
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB,
|
||||
cute::enable_if_t<
|
||||
(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative>) &&
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB,
|
||||
GmemLayoutB>()>> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>,
|
||||
"Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
|
||||
detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm =
|
||||
(cute::is_same_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
|
||||
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only "
|
||||
"compatible with FP8 FastAccum version right now\n");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>,
|
||||
tfloat32_t, ElementA>;
|
||||
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>,
|
||||
tfloat32_t, ElementB>;
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA =
|
||||
detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB =
|
||||
detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
|
||||
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative> ||
|
||||
IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<MmaElementA, MmaElementB, ElementAccumulator,
|
||||
TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(
|
||||
shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(
|
||||
shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA =
|
||||
decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, MmaElementA, decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB =
|
||||
decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, MmaElementB, decltype(cute::get<1>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages =
|
||||
detail::compute_stage_count_or_override_gated<
|
||||
detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
|
||||
TileShape_MNK, SwapAB>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<
|
||||
IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
|
||||
KernelScheduleType>,
|
||||
/* For FP8 use a separate mainloop compared to other datatypes */
|
||||
cute::conditional_t<
|
||||
IsFP8Input,
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<
|
||||
PipelineStages, ClusterShape_MNK, KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
|
||||
KernelScheduleType>>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMmaGated<
|
||||
DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
|
||||
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
|
||||
SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
|
||||
SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
|
||||
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
|
||||
class GmemLayoutB, int AlignmentB, class ElementAccumulator,
|
||||
class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType,
|
||||
template <class /* ElementCompute */> class Activation, bool SwapAB>
|
||||
struct CollectiveBuilderGated<
|
||||
arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA,
|
||||
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType, Activation, SwapAB,
|
||||
cute::enable_if_t<
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedPingpongFP8FastAccum> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
cute::is_same_v<
|
||||
KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>> {
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
|
||||
detail::tma_alignment_bytes>(),
|
||||
"Not meet TMA alignment requirement yet\n");
|
||||
static_assert(
|
||||
detail::is_input_fp8<ElementA, ElementB>(),
|
||||
"Only FP8 datatypes are compatible with these kernel schedules\n");
|
||||
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
|
||||
static_assert(
|
||||
!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
|
||||
"Not supported for fp8 non-TN warp specialized kernels yet\n");
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>,
|
||||
"Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA =
|
||||
detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB =
|
||||
detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm =
|
||||
(cute::is_same_v<
|
||||
KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8FastAccum> ||
|
||||
IsArrayOfPointersGemm,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator,
|
||||
TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
|
||||
AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(
|
||||
shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(
|
||||
shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA =
|
||||
decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB =
|
||||
decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})),
|
||||
decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr int PipelineStages =
|
||||
detail::compute_stage_count_or_override_gated<
|
||||
detail::sm90_smem_capacity_bytes, ElementA, ElementB, TileShape_MNK,
|
||||
SwapAB>(StageCountType{});
|
||||
using DispatchPolicy = cute::conditional_t<
|
||||
IsArrayOfPointersGemm,
|
||||
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
|
||||
KernelScheduleType>,
|
||||
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK,
|
||||
KernelScheduleType>>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMmaGated<
|
||||
DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
|
||||
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA,
|
||||
SmemLayoutAtomA, SmemCopyAtomA, cute::identity, GmemTiledCopyB,
|
||||
SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,60 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
          int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
          class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
          class StageCountType, class KernelScheduleType,
          template <class /* ElementCompute */> class Activation,
          bool SwapAB = false, class Enable = void>
struct CollectiveBuilderGated {
  static_assert(sizeof(ElementA) == 0,
                "Could not build a collective for given parameters.");
};
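
// Illustrative usage sketch: the element types, layouts, tile/cluster shapes,
// schedule, and activation below are placeholder assumptions chosen only to
// show the parameter order of this builder; a matching partial specialization
// from sm90_gmma_builder_gated.inl (included at the end of this header) is
// what actually builds the mainloop and exposes the nested CollectiveOp.
//
//   using CollectiveMainloop = typename cutlass::gemm::collective::
//       CollectiveBuilderGated<
//           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
//           cutlass::half_t, cutlass::layout::RowMajor, 8,     // A
//           cutlass::half_t, cutlass::layout::ColumnMajor, 8,  // B0 / B1
//           float,                                             // accumulator
//           cute::Shape<cute::_128, cute::_128, cute::_64>,    // TileShape_MNK
//           cute::Shape<cute::_1, cute::_1, cute::_1>,         // ClusterShape_MNK
//           cutlass::gemm::collective::StageCountAuto,
//           cutlass::gemm::KernelTmaWarpSpecializedCooperative,
//           cutlass::epilogue::thread::SiLu,
//           /*SwapAB=*/false>::CollectiveOp;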

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,62 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 * reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/detail/dependent_false.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class DispatchPolicy, class TileShape, class ElementA, class StrideA,
          class ElementB, class StrideB, class TiledMma, class GmemTiledCopyA,
          class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
          class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB,
          class TransformB,
          template <class /* ElementCompute */> class Activation,
          bool SwapAB = false>
struct CollectiveMmaGated {
  static_assert(cutlass::detail::dependent_false<ElementA>,
                "Could not find a mainloop specialization.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,713 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <int Stages, class ClusterShape, class KernelSchedule,
|
||||
class TileShape_, class ElementA_, class StrideA_, class ElementB_,
|
||||
class StrideB_, class TiledMma_, class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
|
||||
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
|
||||
class TransformB_,
|
||||
template <class /* ElementCompute */> class Activation_, bool SwapAB_>
|
||||
struct CollectiveMmaGated<
|
||||
MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>,
|
||||
TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
|
||||
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_,
|
||||
SwapAB_> {
|
||||
static constexpr bool isGated = true;
|
||||
static constexpr bool SwapAB = SwapAB_;
|
||||
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy =
|
||||
MainloopSm90TmaGmmaWarpSpecialized<Stages, ClusterShape, KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using Activation = Activation_<ElementAccumulator>;
|
||||
|
||||
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
|
||||
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA,
|
||||
typename TiledMma::ValTypeB>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2,
|
||||
"SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
|
||||
"SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2,
|
||||
"Specialization requires Stages set to value 2 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator,
|
||||
typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator,
|
||||
typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for "
|
||||
"this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
|
||||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
|
||||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
|
||||
// For all other types, cast to size equivalent uint type to avoid any
|
||||
// rounding by TMA.
|
||||
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
|
||||
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
|
||||
using InternalElementA =
|
||||
cute::conditional_t<ConvertF32toTF32A, tfloat32_t,
|
||||
uint_bit_t<sizeof_bits_v<ElementA>>>;
|
||||
using InternalElementB =
|
||||
cute::conditional_t<ConvertF32toTF32B, tfloat32_t,
|
||||
uint_bit_t<sizeof_bits_v<ElementB>>>;
|
||||
using InternalElementAux =
|
||||
cute::conditional_t<SwapAB, InternalElementA, InternalElementB>;
|
||||
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA,
|
||||
cute::cosize_v<SmemLayoutA>>
|
||||
smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB,
|
||||
cute::cosize_v<SmemLayoutB>>
|
||||
smem_B;
|
||||
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const *ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const *ptr_B0;
|
||||
ElementB const *ptr_B1;
|
||||
StrideB dB;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
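
  // Note: this mainloop consumes one shared operand and two gated operands.
  // As used by mma() below, accum0 accumulates the product with ptr_B0, and
  // accum1 accumulates the product with ptr_B1 (the auxiliary operand staged
  // in smem_Aux). When SwapAB is true, to_underlying_arguments() instead
  // treats ptr_B1 as an A-shaped (M,K,L) tensor with stride dA. scale_d0 and
  // scale_d1 are only carried through to Params; they are not consumed inside
  // this collective.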
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<InternalElementA const *>(nullptr),
|
||||
repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<InternalElementB const *>(nullptr),
|
||||
repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
TMA_Aux tma_load_aux;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const &problem_shape,
|
||||
Arguments const &args, void *workspace) {
|
||||
(void)workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
|
||||
// only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<InternalElementA const *>(args.ptr_A);
|
||||
auto ptr_B0 = reinterpret_cast<InternalElementB const *>(args.ptr_B0);
|
||||
|
||||
Tensor tensor_a =
|
||||
make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
|
||||
Tensor tensor_b =
|
||||
make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy(
|
||||
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy(
|
||||
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
auto ptr_Aux = reinterpret_cast<InternalElementA const *>(args.ptr_B1);
|
||||
Tensor tensor_aux =
|
||||
make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
|
||||
GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(
|
||||
ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0,
|
||||
args.scale_d1};
|
||||
} else {
|
||||
auto ptr_Aux = reinterpret_cast<InternalElementB const *>(args.ptr_B1);
|
||||
Tensor tensor_aux =
|
||||
make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
|
||||
GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(
|
||||
ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0,
|
||||
args.scale_d1};
|
||||
}
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool can_implement(ProblemShape const &problem_shape,
|
||||
[[maybe_unused]] Arguments const &args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A =
|
||||
tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable =
|
||||
implementable &&
|
||||
cutlass::detail::check_alignment<min_tma_aligned_elements_A>(
|
||||
cute::make_shape(M, K, L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B =
|
||||
tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable =
|
||||
implementable &&
|
||||
cutlass::detail::check_alignment<min_tma_aligned_elements_B>(
|
||||
cute::make_shape(N, K, L), StrideB{});
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the "
|
||||
"minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytes =
|
||||
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) *
|
||||
static_cast<uint32_t>(sizeof_bits<ElementA>::value)) /
|
||||
8 +
|
||||
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) *
|
||||
static_cast<uint32_t>(sizeof_bits<ElementB>::value)) /
|
||||
8 +
|
||||
(size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) *
|
||||
static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) /
|
||||
8;
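
  // Descriptive note: TmaTransactionBytes covers a single pipeline stage of
  // all three operands (A, B, and the auxiliary A/B tile), matching the three
  // TMA copies issued against one producer barrier per k-tile in load() below.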
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
|
||||
/// performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const &mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(
|
||||
mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(
|
||||
mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(
|
||||
mainloop_params.tma_load_aux.get_tma_descriptor());
|
||||
}
|
||||
|
||||
  /// Set up the data needed by this collective for load and mma.
  /// Returns a tuple of tensors. The collective and the kernel layer have the
  /// contract that the returned tuple must contain at least two elements:
  ///   gA_mkl   - The TMA tensor, A after a local tile, shape (BLK_M,BLK_K,m,k,l)
  ///   gB_nkl   - The TMA tensor, B after a local tile, shape (BLK_N,BLK_K,n,k,l)
  ///   gAux_xkl - The TMA tensor, A/B after a local tile, shape (BLK_M/N,BLK_K,m/n,k,l)
  /// The rest of the tensors can be specified as needed by this collective.
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL,
|
||||
Params const &mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain
|
||||
// mapping Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(
|
||||
make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(
|
||||
make_shape(N, K, L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
|
||||
make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
} else {
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
|
||||
make_shape(N, K, L)); // (n,k,l)
|
||||
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <class TensorA, class TensorB, class TensorAux, class KTileIterator,
|
||||
class BlockCoord>
|
||||
CUTLASS_DEVICE void
|
||||
load(Params const &mainloop_params, MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorAux> const &load_inputs,
|
||||
BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx, uint32_t block_rank_in_cluster,
|
||||
TensorStorage &shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
|
||||
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
|
||||
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
|
||||
SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x =
|
||||
get<0>(typename DispatchPolicy::ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x,
|
||||
block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
auto block_tma_a =
|
||||
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b =
|
||||
mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
auto block_tma_aux =
|
||||
SwapAB
|
||||
? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
|
||||
: mainloop_params.tma_load_aux.get_slice(
|
||||
cluster_local_block_id.x);
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord)
|
||||
: gAux_xkl(_, _, n_coord, _, l_coord);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
|
||||
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
uint16_t mcast_mask_aux = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout =
|
||||
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
|
||||
// block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,
|
||||
n, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout =
|
||||
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
|
||||
// block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(
|
||||
m, cluster_local_block_id.y, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
mcast_mask_aux = mcast_mask_a;
|
||||
} else {
|
||||
mcast_mask_aux = mcast_mask_b;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType *tma_barrier =
|
||||
pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a),
|
||||
tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b),
|
||||
tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux),
|
||||
tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline, PipelineState smem_pipe_read,
|
||||
FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx,
|
||||
TensorStorage &shared_tensors, Params const &mainloop_params) {
|
||||
static_assert(is_rmem<FrgTensorC>::value,
|
||||
"C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutAux{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for "
|
||||
"smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for "
|
||||
"smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
|
||||
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
|
||||
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
|
||||
SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
auto tCsAux = [&]() -> auto {
|
||||
if constexpr (SwapAB) {
|
||||
return thread_mma.partition_A(sAux);
|
||||
} else {
|
||||
return thread_mma.partition_B(sAux);
|
||||
}
|
||||
}();
|
||||
auto tCrAux = [&]() -> auto {
|
||||
if constexpr (SwapAB) {
|
||||
return thread_mma.make_fragment_A(tCsAux);
|
||||
} else {
|
||||
return thread_mma.make_fragment_B(tCsAux);
|
||||
}
|
||||
}();
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
if constexpr (SwapAB) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
|
||||
} else {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} ==
|
||||
size<2>(sAux)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0;
|
||||
--k_tile_prologue) {
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips
|
||||
// from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accum0);
|
||||
if constexpr (SwapAB) {
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accum1);
|
||||
} else {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrAux(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
|
||||
warpgroup_commit_batch();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips
|
||||
// from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accum0);
|
||||
if constexpr (SwapAB) {
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accum1);
|
||||
} else {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrAux(_, _, k_block, read_stage), accum1);
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to
|
||||
/// ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
|
||||
// UNLOCK smem_pipe_release, done _computing_ on it
|
||||
pipeline.consumer_release(smem_pipe_release);
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accum0);
|
||||
warpgroup_fence_operand(accum1);
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_release,
|
||||
int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
|
||||
// done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,724 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <int Stages, class ClusterShape, class KernelSchedule,
|
||||
class TileShape_, class ElementA_, class StrideA_, class ElementB_,
|
||||
class StrideB_, class TiledMma_, class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_, class SmemCopyAtomA_, class TransformA_,
|
||||
class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
|
||||
class TransformB_,
|
||||
template <class /* ElementCompute */> class Activation_, bool SwapAB_>
|
||||
struct CollectiveMmaGated<
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>,
|
||||
TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_,
|
||||
GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
|
||||
GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_,
|
||||
SwapAB_> {
|
||||
static constexpr bool isGated = true;
|
||||
static constexpr bool SwapAB = SwapAB_;
|
||||
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy =
|
||||
MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape,
|
||||
KernelSchedule>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
using Activation = Activation_<ElementAccumulator>;
|
||||
|
||||
using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
|
||||
using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA,
|
||||
typename TiledMma::ValTypeB>;
|
||||
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2,
|
||||
"SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
|
||||
"SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
|
||||
"SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
|
||||
Int<DispatchPolicy::Stages>{}),
|
||||
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
|
||||
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
|
||||
using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;
|
||||
|
||||
  static_assert(DispatchPolicy::Stages >= 2,
                "Specialization requires Stages set to value 2 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator,
|
||||
typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator,
|
||||
typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for "
|
||||
"this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
|
||||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
|
||||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA,
|
||||
cute::cosize_v<SmemLayoutA>>
|
||||
smem_A;
|
||||
cute::array_aligned<typename TiledMma::ValTypeB,
|
||||
cute::cosize_v<SmemLayoutB>>
|
||||
smem_B;
|
||||
cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const *ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const *ptr_B0;
|
||||
ElementB const *ptr_B1;
|
||||
StrideB dB;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<ElementA const *>(nullptr),
|
||||
repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_, _, 0),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<ElementB const *>(nullptr),
|
||||
repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_, _, 0),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
|
||||
using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
TMA_Aux tma_load_aux;
|
||||
float scale_d0 = 1.0f;
|
||||
float scale_d1 = 1.0f;
|
||||
uint32_t mma_promotion_interval = 4;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const &problem_shape,
|
||||
Arguments const &args, void *workspace) {
|
||||
(void)workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
|
||||
// only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<ElementA const *>(args.ptr_A);
|
||||
auto ptr_B0 = reinterpret_cast<ElementB const *>(args.ptr_B0);
|
||||
|
||||
Tensor tensor_a =
|
||||
make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
|
||||
Tensor tensor_b =
|
||||
make_tensor(ptr_B0, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy(
|
||||
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy(
|
||||
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
if constexpr (SwapAB) {
|
||||
auto ptr_Aux = reinterpret_cast<ElementA const *>(args.ptr_B1);
|
||||
Tensor tensor_aux =
|
||||
make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
|
||||
GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<1>(
|
||||
ClusterShape{})); // mcast along N mode for this M load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux,
|
||||
args.scale_d0, args.scale_d1, args.mma_promotion_interval};
|
||||
} else {
|
||||
auto ptr_Aux = reinterpret_cast<ElementB const *>(args.ptr_B1);
|
||||
Tensor tensor_aux =
|
||||
make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
|
||||
typename Params::TMA_Aux tma_load_aux = make_tma_copy(
|
||||
GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}),
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
|
||||
size<0>(
|
||||
ClusterShape{})); // mcast along M mode for this N load, if any
|
||||
return {tma_load_a, tma_load_b, tma_load_aux,
|
||||
args.scale_d0, args.scale_d1, args.mma_promotion_interval};
|
||||
}
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool can_implement(ProblemShape const &problem_shape,
|
||||
[[maybe_unused]] Arguments const &args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A =
|
||||
tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable =
|
||||
implementable &&
|
||||
cutlass::detail::check_alignment<min_tma_aligned_elements_A>(
|
||||
cute::make_shape(M, K, L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B =
|
||||
tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable =
|
||||
implementable &&
|
||||
cutlass::detail::check_alignment<min_tma_aligned_elements_B>(
|
||||
cute::make_shape(N, K, L), StrideB{});
|
||||
/* MMA promotion interval should be a multiple of 4, since each mainloop
|
||||
* iteration would issue 4 MMA instructions. */
|
||||
implementable = implementable && (args.mma_promotion_interval % 4 == 0);
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the "
|
||||
"minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytes =
|
||||
(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) *
|
||||
static_cast<uint32_t>(sizeof_bits<ElementA>::value)) /
|
||||
8 +
|
||||
(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) *
|
||||
static_cast<uint32_t>(sizeof_bits<ElementB>::value)) /
|
||||
8 +
|
||||
(size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) *
|
||||
static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) /
|
||||
8;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
|
||||
/// performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const &mainloop_params) {
|
||||
cute::prefetch_tma_descriptor(
|
||||
mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(
|
||||
mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(
|
||||
mainloop_params.tma_load_aux.get_tma_descriptor());
|
||||
}
|
||||
|
||||
  /// Set up the data needed by this collective for load and mma.
  /// Returns a tuple of tensors. The collective and the kernel layer have the
  /// contract that the returned tuple must contain at least two elements:
  ///   gA_mkl   - The TMA tensor, A after a local tile, shape (BLK_M,BLK_K,m,k,l)
  ///   gB_nkl   - The TMA tensor, B after a local tile, shape (BLK_N,BLK_K,n,k,l)
  ///   gAux_xkl - The TMA tensor, A/B after a local tile, shape (BLK_M/N,BLK_K,m/n,k,l)
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const &problem_shape_MNKL,
|
||||
Params const &mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M, N, K, L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain
|
||||
// mapping Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(
|
||||
make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(
|
||||
make_shape(N, K, L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
|
||||
make_shape(M, K, L)); // (m,k,l)
|
||||
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
} else {
|
||||
Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(
|
||||
make_shape(N, K, L)); // (n,k,l)
|
||||
Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _),
|
||||
Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <class TensorA, class TensorB, class TensorAux, class KTileIterator,
|
||||
class BlockCoord>
|
||||
CUTLASS_DEVICE void
|
||||
load(Params const &mainloop_params, MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorAux> const &load_inputs,
|
||||
BlockCoord const &blk_coord, KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx, uint32_t block_rank_in_cluster,
|
||||
TensorStorage &shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
if (lane_predicate) {
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
|
||||
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
|
||||
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
|
||||
SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x,
|
||||
block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
auto block_tma_a =
|
||||
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b =
|
||||
mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
auto block_tma_aux =
|
||||
SwapAB
|
||||
? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
|
||||
: mainloop_params.tma_load_aux.get_slice(
|
||||
cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
|
||||
Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord)
|
||||
: gAux_xkl(_, _, n_coord, _, l_coord);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
|
||||
Tensor tAuxsAux = block_tma_aux.partition_D(sAux);
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
uint16_t mcast_mask_aux = 0;
|
||||
|
||||
// Issue TmaLoads
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout =
|
||||
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
|
||||
// block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,
|
||||
n, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout =
|
||||
Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) ->
|
||||
// block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(
|
||||
m, cluster_local_block_id.y, Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (SwapAB) {
|
||||
mcast_mask_aux = mcast_mask_a;
|
||||
} else {
|
||||
mcast_mask_aux = mcast_mask_b;
|
||||
}
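// Each set bit in the mcast masks above selects a CTA rank within the cluster
// that receives the multicast TMA load; the aux operand reuses the A mask or
// the B mask depending on which operand slot it occupies after the optional
// A/B swap.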
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType *tma_barrier =
|
||||
pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
int write_stage = smem_pipe_write.index();
|
||||
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a),
|
||||
tAgA(_, _, _, *k_tile_iter), tAsA(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b),
|
||||
tBgB(_, _, _, *k_tile_iter), tBsB(_, _, _, write_stage));
|
||||
copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux),
|
||||
tAuxgAux(_, _, _, *k_tile_iter), tAuxsAux(_, _, _, write_stage));
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in the Cluster.
 * It waits for all stages to either be released (all Consumer
 * UNLOCKs) or, if a stage was never used, to simply be acquired,
 * since its phase is still inverted from make_producer_start_state.
 */
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <class FrgTensorC>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline, PipelineState smem_pipe_read,
|
||||
FrgTensorC &accum0, FrgTensorC &accum1, int k_tile_count, int thread_idx,
|
||||
TensorStorage &shared_tensors, Params const &mainloop_params) {
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value,
|
||||
"C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3,
|
||||
"Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for "
|
||||
"smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for "
|
||||
"smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()),
|
||||
SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()),
|
||||
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()),
|
||||
SmemLayoutAux{});
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
auto tCsAux = [&]() -> auto {
|
||||
if constexpr (SwapAB) {
|
||||
return thread_mma.partition_A(sAux);
|
||||
} else {
|
||||
return thread_mma.partition_B(sAux);
|
||||
}
|
||||
}();
|
||||
auto tCrAux = [&]() -> auto {
|
||||
if constexpr (SwapAB) {
|
||||
return thread_mma.make_fragment_A(tCsAux);
|
||||
} else {
|
||||
return thread_mma.make_fragment_B(tCsAux);
|
||||
}
|
||||
}();
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
if constexpr (SwapAB) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE
|
||||
} else {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE
|
||||
}
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} ==
|
||||
size<2>(sAux)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to the producer warps (DMA load) with some MMAs in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
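// Up to K_PIPE_MMAS k-tiles of GMMAs are kept in flight before the first
// consumer_release, so the prologue below issues that many tiles without
// releasing any shared-memory buffers back to the producer.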
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
GmmaFP8Accumulation accumulation0(
|
||||
accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
|
||||
GmmaFP8Accumulation accumulation1(
|
||||
accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
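// GmmaFP8Accumulation promotes the partial WGMMA results into the main
// accumulators every mma_promotion_interval k-blocks (see the
// prepare_if_needed()/promote_if_needed() calls below), which helps bound
// FP8 accumulation error.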
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0;
|
||||
--k_tile_prologue) {
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips
|
||||
// from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
if (accumulation0.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accumulation0());
|
||||
if constexpr (SwapAB) {
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accumulation1());
|
||||
} else {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrAux(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
accumulation0.promote_if_needed();
|
||||
accumulation1.promote_if_needed();
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; k_tile_count > 0; --k_tile_count) {
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips
|
||||
// from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
if (accumulation0.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accumulation0());
|
||||
if constexpr (SwapAB) {
|
||||
cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage),
|
||||
tCrB(_, _, k_block, read_stage), accumulation1());
|
||||
} else {
|
||||
cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage),
|
||||
tCrAux(_, _, k_block, read_stage), accumulation1());
|
||||
}
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to
|
||||
/// ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
|
||||
accumulation0.promote_if_needed();
|
||||
accumulation1.promote_if_needed();
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
|
||||
// done _computing_ on it
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
accumulation0.promote_residue_if_needed();
|
||||
accumulation1.promote_residue_if_needed();
|
||||
|
||||
warpgroup_fence_operand(accumulation0());
|
||||
warpgroup_fence_operand(accumulation1());
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_release,
|
||||
int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release,
|
||||
// done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -0,0 +1,71 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*
|
||||
* Stateless universal device GEMM kernel type that treats GEMM as
|
||||
* a composition of a collective mainloop and a collective epilogue.
|
||||
*
|
||||
* Supports both the 2.x and 3.x APIs based on whether the first type is
|
||||
* a cute::tuple<> or not.
|
||||
* 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h
|
||||
* 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp
|
||||
*
|
||||
* In the following declaration, the name preceding the 'Or' refers to
|
||||
* 3.x API type argument order, and the name succeeding the 'Or' refers to
|
||||
* 2.x API type argument order. Template arguments without two names
|
||||
* belong to the 3.x API only.
|
||||
**/
|
||||
template <class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l)
|
||||
class CollectiveMainloopOrEpilogue_,
|
||||
class CollectiveEpilogueOrThreadblockSwizzle_,
|
||||
class TileScheduler_ = void, class Enable = void>
|
||||
class GemmUniversalGated;
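// A minimal usage sketch (hypothetical collective type names, for
// illustration only): the gated kernel is composed and launched the same way
// as a 3.x-API cutlass::gemm::kernel::GemmUniversal specialization.
//
//   using GemmKernel = cutlass::gemm::kernel::GemmUniversalGated<
//       cute::Shape<int, int, int, int>,       // ProblemShape (M, N, K, L)
//       CollectiveMainloop,                    // a gated SM90 collective mainloop
//       CollectiveEpilogue,                    // a collective epilogue
//       cutlass::gemm::PersistentScheduler>;   // tile scheduler tag
//   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;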

////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::kernel

////////////////////////////////////////////////////////////////////////////////

#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp"
#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp"
////////////////////////////////////////////////////////////////////////////////
@@ -130,6 +130,15 @@ public:
|
||||
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
|
||||
};
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
|
||||
{
|
||||
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Layout = layout::RowMajor;
|
||||
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
|
||||
using Operator = cutlass::arch::OpMultiplyAdd;
|
||||
};
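// For 2-bit (uint2b_t) weights on SM75 and newer, B stays row-major;
// ThreadblockK is the number of TypeA elements that fit in a 128-byte row,
// and ElementsPerAccess corresponds to 128-bit accesses of TypeA.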
|
||||
|
||||
template <typename TypeA, typename Arch>
|
||||
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
|
||||
{
|
||||
|
||||
@@ -0,0 +1,705 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/workspace.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ProblemShape_, class CollectiveMainloop_,
|
||||
class CollectiveEpilogue_, class TileScheduler_>
|
||||
class GemmUniversalGated<
|
||||
ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
|
||||
cute::enable_if_t<cute::is_base_of_v<KernelTmaWarpSpecializedCooperative,
|
||||
typename CollectiveMainloop_::
|
||||
DispatchPolicy::Schedule> &&
|
||||
CollectiveMainloop_::isGated>> {
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or
|
||||
cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
using Activation = typename CollectiveMainloop::Activation;
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler =
|
||||
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape,
|
||||
ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups =
|
||||
CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
|
||||
static constexpr uint32_t MaxThreadsPerBlock =
|
||||
CUTE_STATIC_V(size(TiledMma{})) +
|
||||
(NumLoadWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
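// These register budgets are applied below with warpgroup_reg_dealloc /
// warpgroup_reg_alloc: the producer (load) warp group gives up registers so
// that each math warp group can claim up to 232 registers per thread.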
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load
|
||||
// threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16> {
|
||||
using MainloopPipelineStorage =
|
||||
typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage =
|
||||
typename CollectiveEpilogue::PipelineStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerParams scheduler{};
|
||||
void *workspace{nullptr};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the
|
||||
// aliased type.
|
||||
static Params to_underlying_arguments(Arguments const &args,
|
||||
void *workspace) {
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
||||
|
||||
auto problem_shape = args.problem_shape;
|
||||
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
// // swap M/N
|
||||
// get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
// get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
// }
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments "
|
||||
"KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
"to_underlying_arguments(): Setting persistent grid SM count to "
|
||||
<< sm_count);
|
||||
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
void *scheduler_workspace = workspace_ptr;
|
||||
workspace_offset +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void *epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(
|
||||
args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void *mainloop_workspace = nullptr;
|
||||
// Precompute the number of epilogue subtiles and pass it to the tile
// scheduler, where it is used by the separate-reduction scheme in the
// stream-K case. The default NumEpilogueSubTiles value of 1 means subtiles
// are not used, so separate reduction is not enabled.
|
||||
constexpr uint32_t NumEpilogueSubTiles =
|
||||
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(
|
||||
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
|
||||
args.scheduler, scheduler_workspace, NumEpilogueSubTiles);
|
||||
|
||||
return {args.mode,
|
||||
problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(
|
||||
args.problem_shape, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(
|
||||
args.problem_shape, args.epilogue, epilogue_workspace),
|
||||
hw_info,
|
||||
scheduler,
|
||||
workspace};
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const &args) {
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
|
||||
(args.mode == GemmUniversalMode::kBatched &&
|
||||
cute::rank(ProblemShape{}) == 4);
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't "
|
||||
"meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &=
|
||||
CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &=
|
||||
CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
size_t workspace_size = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles =
|
||||
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
workspace_size +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups,
|
||||
NumEpilogueSubTiles);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape,
|
||||
args.epilogue);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static cutlass::Status
|
||||
initialize_workspace(Arguments const &args, void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
constexpr uint32_t NumEpilogueSubTiles =
|
||||
CollectiveEpilogue::get_store_pipe_increment(TileShape{});
|
||||
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, workspace_ptr + workspace_offset, stream,
|
||||
args.problem_shape, args.hw_info, NumMmaWarpGroups,
|
||||
NumEpilogueSubTiles);
|
||||
workspace_offset +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups,
|
||||
NumEpilogueSubTiles);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset,
|
||||
stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(
|
||||
args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3 get_grid_shape(Params const ¶ms) {
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread
|
||||
// blocks than we can run concurrently
|
||||
TileSchedulerArguments args{};
|
||||
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
|
||||
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
|
||||
}
|
||||
args.raster_order =
|
||||
params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
|
||||
? TileScheduler::RasterOrderOptions::AlongN
|
||||
: TileScheduler::RasterOrderOptions::AlongM;
|
||||
return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape,
|
||||
TileShape{}, ClusterShape{},
|
||||
params.hw_info, args);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, char *smem_buf) {
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting "
|
||||
"sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
// Preconditions
|
||||
static_assert(
|
||||
size(TiledMma{}) == 256,
|
||||
"Cooperative kernel must have TiledMMA operating using 256 threads.");
|
||||
static_assert(size<0>(TileShape{}) >= 128,
|
||||
"Cooperative kernel requires Tile Size to be greater than or "
|
||||
"equal to 128 along the M-dimension.");
|
||||
|
||||
static_assert(cute::rank(StrideA{}) == 3,
|
||||
"StrideA must be rank-3: [M, K, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3,
|
||||
"StrideB must be rank-3: [N, K, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3,
|
||||
"StrideC must be rank-3: [M, N, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3,
|
||||
"StrideD must be rank-3: [M, N, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
|
||||
/* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the
|
||||
* same tile */
|
||||
enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 };
|
||||
enum class ProducerWarpRole {
|
||||
Mainloop = 0,
|
||||
Warp1 = 1,
|
||||
Epilogue = 2,
|
||||
Warp3 = 3
|
||||
};
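// Warp group 0 is the producer: its warp 0 issues the mainloop TMA loads and
// its warp 2 issues the epilogue loads. Warp groups 1 and 2 are the math
// consumers that cooperate on the same output tile.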
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage &shared_storage =
|
||||
*reinterpret_cast<SharedStorage *>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
int mma_thread_idx = thread_idx % size(TiledMma{});
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer &&
|
||||
producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
mainloop_pipeline_params.role =
|
||||
MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
mainloop_pipeline_params.role =
|
||||
MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = size(TiledMma{});
|
||||
mainloop_pipeline_params.transaction_bytes =
|
||||
CollectiveMainloop::TmaTransactionBytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop,
|
||||
mainloop_pipeline_params,
|
||||
ClusterShape{});
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer &&
|
||||
producer_warp_role == ProducerWarpRole::Epilogue) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = size(TiledMma{});
|
||||
epi_load_pipeline_params.transaction_bytes =
|
||||
CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load,
|
||||
epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
|
||||
params_load_order_barrier.group_id =
|
||||
producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
|
||||
params_load_order_barrier.group_size = NumThreadsPerWarp;
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order,
|
||||
params_load_order_barrier);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via
|
||||
// scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = []() {
|
||||
// We need this to guarantee that the pipeline init is visible
// to all producer and consumer thread blocks in the Cluster
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
cute::cluster_arrive_relaxed();
|
||||
return []() { cute::cluster_wait(); };
|
||||
} else {
|
||||
__syncthreads();
|
||||
return []() {}; // do nothing
|
||||
}
|
||||
}();
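// cluster_wait_fn splits the cluster barrier: the arrive (or a plain
// __syncthreads for a single-CTA cluster) happens here, while the returned
// callable performs the matching cluster wait (a no-op for a single CTA)
// later, just before the producer/consumer work loops begin.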
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only
|
||||
// rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread
|
||||
// block locality
|
||||
TiledMma tiled_mma;
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
auto work_tile_info = scheduler.get_current_work();
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and
|
||||
// compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue,
|
||||
shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
//   get<0>(load_inputs) is the TMA tensor A after local tiling, with shape (BLK_M,BLK_K,m,k,l)
//   get<1>(load_inputs) is the TMA tensor B after local tiling, with shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs =
|
||||
collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(
|
||||
cute::tuple_size_v<decltype(load_inputs)> >= 3,
|
||||
"Output of load_init must have at least three elements (A, B, Aux)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
// Get pipeline stage increments from tensor shapes
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
if (producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
bool do_load_order_arrive = true;
|
||||
while (work_tile_info.is_valid()) {
|
||||
if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Get the number of K tiles to compute for this work as well as the
|
||||
// starting K tile offset of the work.
|
||||
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(
|
||||
work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
auto work_k_tile_start =
|
||||
TileScheduler::get_work_k_tile_start(work_tile_info);
|
||||
auto k_tile_iter = cute::make_coord_iterator(
|
||||
idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl));
|
||||
|
||||
collective_mainloop.load(
|
||||
params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
|
||||
load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx,
|
||||
block_rank_in_cluster, shared_storage.tensors.mainloop);
|
||||
// Update starting pipeline state for the next tile
|
||||
mainloop_pipe_producer_state.advance(work_k_tile_count);
|
||||
|
||||
// Signal for the epilogue load warp to begin
|
||||
if (do_load_order_arrive) {
|
||||
load_order_barrier.arrive();
|
||||
do_load_order_arrive = false;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline,
|
||||
mainloop_pipe_producer_state);
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue &&
|
||||
collective_epilogue.is_producer_load_needed()) {
|
||||
while (work_tile_info.is_valid()) {
|
||||
if (!TileScheduler::requires_separate_reduction(params.scheduler)) {
|
||||
load_order_barrier.wait();
|
||||
}
|
||||
if (TileScheduler::compute_epilogue(work_tile_info,
|
||||
params.scheduler)) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline, epi_load_pipe_producer_state,
|
||||
problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
|
||||
shared_storage.tensors.epilogue,
|
||||
work_tile_info.reduction_subtile_idx());
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline,
|
||||
epi_load_pipe_producer_state);
|
||||
} // Epilogue Producer Warp End
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
// Whether we may need to issue tail arrives for TMA stores, in case the
// epilogue load is waiting on them
|
||||
bool do_store_tail = false;
|
||||
float scale_d0 = params.mainloop.scale_d0;
|
||||
float scale_d1 = params.mainloop.scale_d1;
|
||||
while (work_tile_info.is_valid()) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(
|
||||
work_tile_info, problem_shape_MNKL, blk_shape);
|
||||
|
||||
// Allocate the accumulators for the (M,N) blk_shape
|
||||
//
|
||||
// MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead.
|
||||
auto accumulators0 = partition_fragment_C(
|
||||
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
auto accumulators1 = partition_fragment_C(
|
||||
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) {
|
||||
collective_mainloop.mma(
|
||||
mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
|
||||
accumulators1, work_k_tile_count, mma_thread_idx,
|
||||
shared_storage.tensors.mainloop, params.mainloop);
|
||||
|
||||
// Make sure the math instructions are done and free buffers before
|
||||
// entering the epilogue
|
||||
collective_mainloop.mma_tail(mainloop_pipeline,
|
||||
mainloop_pipe_consumer_state,
|
||||
work_k_tile_count);
|
||||
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(work_k_tile_count);
|
||||
}
|
||||
// Index of warp group within consumer warp groups
|
||||
int consumer_warp_group_idx =
|
||||
canonical_warp_group_idx() - NumLoadWarpGroups;
|
||||
|
||||
// Perform reduction across splits, if needed
|
||||
TileScheduler::fixup(params.scheduler, work_tile_info, accumulators0,
|
||||
NumMmaWarpGroups, consumer_warp_group_idx);
|
||||
TileScheduler::fixup(params.scheduler, work_tile_info, accumulators1,
|
||||
NumMmaWarpGroups, consumer_warp_group_idx);
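// Both accumulators must be fully reduced across splits before the gating
// element-wise op below, since the activation is generally non-linear and
// cannot be applied to partial sums.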
|
||||
|
||||
Activation elt_op;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accumulators0); i++) {
|
||||
accumulators0[i] = elt_op(accumulators0[i] * scale_d0) *
|
||||
(scale_d1 * accumulators1[i]);
|
||||
}
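// Gating step: the value sent to the epilogue is
//   elt_op(scale_d0 * acc0) * (scale_d1 * acc1),
// i.e. the activation of the scaled primary GEMM output multiplied
// element-wise by the scaled auxiliary GEMM output, written back into
// accumulators0.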
|
||||
|
||||
if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) {
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipe_producer_state_next] =
|
||||
collective_epilogue.store(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state,
|
||||
problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
|
||||
tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue,
|
||||
work_tile_info.reduction_subtile_idx());
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next;
|
||||
do_store_tail = true;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
work_tile_info = fetch_next_work(work_tile_info, scheduler);
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
if (do_store_tail) {
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline,
|
||||
epi_store_pipe_producer_state);
|
||||
}
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
// Kernel helper function to get next work unit
|
||||
CUTLASS_DEVICE
|
||||
typename TileScheduler::WorkTileInfo
|
||||
fetch_next_work(typename TileScheduler::WorkTileInfo &work_tile_info,
|
||||
TileScheduler &scheduler) const {
|
||||
// Check whether we should continue on with the current work unit. If this
|
||||
// is the case, the work unit will have been updated in
|
||||
// continue_current_work to reflect the new tile to be computed.
|
||||
if (scheduler.continue_current_work(work_tile_info)) {
|
||||
return work_tile_info;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
return scheduler.get_current_work();
|
||||
}
|
||||
};

///////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::kernel

@@ -0,0 +1,680 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
#pragma once
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cutlass/arch/mma_sm90.h"
|
||||
#include "cutlass/arch/reg_reconfig.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/collective/detail.hpp"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp"
|
||||
#include "cutlass/kernel_hardware_info.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/workspace.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cute/util/debug.hpp"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::kernel {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class ProblemShape_, class CollectiveMainloop_,
|
||||
class CollectiveEpilogue_, class TileScheduler_>
|
||||
class GemmUniversalGated<
|
||||
ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_,
|
||||
cute::enable_if_t<cute::is_base_of_v<KernelTmaWarpSpecializedPingpong,
|
||||
typename CollectiveMainloop_::
|
||||
DispatchPolicy::Schedule> &&
|
||||
CollectiveMainloop_::isGated>> {
|
||||
public:
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using ProblemShape = ProblemShape_;
|
||||
static_assert(cute::rank(ProblemShape{}) == 3 or
|
||||
cute::rank(ProblemShape{}) == 4,
|
||||
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
|
||||
// Mainloop derived types
|
||||
using CollectiveMainloop = CollectiveMainloop_;
|
||||
using TileShape = typename CollectiveMainloop::TileShape;
|
||||
using TiledMma = typename CollectiveMainloop::TiledMma;
|
||||
using ArchTag = typename CollectiveMainloop::ArchTag;
|
||||
using ElementA = typename CollectiveMainloop::ElementA;
|
||||
using StrideA = typename CollectiveMainloop::StrideA;
|
||||
using ElementB = typename CollectiveMainloop::ElementB;
|
||||
using StrideB = typename CollectiveMainloop::StrideB;
|
||||
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
|
||||
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
|
||||
using ClusterShape = typename DispatchPolicy::ClusterShape;
|
||||
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
||||
using MainloopParams = typename CollectiveMainloop::Params;
|
||||
using Activation = typename CollectiveMainloop::Activation;
|
||||
static_assert(ArchTag::kMinComputeCapability >= 90);
|
||||
|
||||
// Epilogue derived types
|
||||
using CollectiveEpilogue = CollectiveEpilogue_;
|
||||
using ElementC = typename CollectiveEpilogue::ElementC;
|
||||
using StrideC = typename CollectiveEpilogue::StrideC;
|
||||
using ElementD = typename CollectiveEpilogue::ElementD;
|
||||
using StrideD = typename CollectiveEpilogue::StrideD;
|
||||
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
||||
using EpilogueParams = typename CollectiveEpilogue::Params;
|
||||
|
||||
static_assert(
|
||||
!cute::is_same_v<TileScheduler_, StreamKScheduler>,
|
||||
"Ping-pong kernel does not currently support stream-K scheduler.");
|
||||
using TileSchedulerTag = TileScheduler_;
|
||||
using TileScheduler =
|
||||
typename detail::TileSchedulerSelector<TileScheduler_, ArchTag, TileShape,
|
||||
ClusterShape>::Scheduler;
|
||||
using TileSchedulerArguments = typename TileScheduler::Arguments;
|
||||
using TileSchedulerParams = typename TileScheduler::Params;
|
||||
|
||||
static constexpr uint32_t NumLoadWarpGroups = 1;
|
||||
static constexpr uint32_t NumMmaWarpGroups = 2;
|
||||
static constexpr uint32_t MaxThreadsPerBlock =
|
||||
CUTE_STATIC_V(size(TiledMma{})) +
|
||||
(NumMmaWarpGroups * NumThreadsPerWarpGroup);
|
||||
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
/// Register requirement for Load and Math WGs
|
||||
static constexpr uint32_t LoadRegisterRequirement = 40;
|
||||
static constexpr uint32_t MmaRegisterRequirement = 232;
|
||||
|
||||
// 1 stage ordered sequence between mainloop and epilogue producer load
|
||||
// threads
|
||||
using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>;
|
||||
|
||||
// Order Sequence barrier with two stages: one for Mainloop and one for
|
||||
// Epilogue
|
||||
static constexpr uint32_t StagesPerMathWarpGroup = 2;
|
||||
using MathWarpGroupOrderBarrier =
|
||||
cutlass::OrderedSequenceBarrier<StagesPerMathWarpGroup, NumMmaWarpGroups>;
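// The two math warp groups alternate ("ping-pong") between mainloop and
// epilogue work; this two-stage ordered barrier enforces that alternation.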
|
||||
|
||||
// Kernel level shared memory storage
|
||||
struct SharedStorage {
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage;
|
||||
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
|
||||
|
||||
MainloopTensorStorage mainloop;
|
||||
EpilogueTensorStorage epilogue;
|
||||
} tensors;
|
||||
|
||||
struct PipelineStorage : cute::aligned_struct<16> {
|
||||
using MainloopPipelineStorage =
|
||||
typename CollectiveMainloop::PipelineStorage;
|
||||
using EpiLoadPipelineStorage =
|
||||
typename CollectiveEpilogue::PipelineStorage;
|
||||
using MathWarpGroupOrderBarrierStorage =
|
||||
typename MathWarpGroupOrderBarrier::SharedStorage;
|
||||
|
||||
alignas(16) MainloopPipelineStorage mainloop;
|
||||
alignas(16) EpiLoadPipelineStorage epi_load;
|
||||
alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order;
|
||||
alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order;
|
||||
} pipelines;
|
||||
};
|
||||
|
||||
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
||||
|
||||
// Device side arguments
|
||||
struct Arguments {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopArguments mainloop{};
|
||||
EpilogueArguments epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerArguments scheduler{};
|
||||
};
|
||||
|
||||
// Kernel entry point API
|
||||
struct Params {
|
||||
GemmUniversalMode mode{};
|
||||
ProblemShape problem_shape{};
|
||||
MainloopParams mainloop{};
|
||||
EpilogueParams epilogue{};
|
||||
KernelHardwareInfo hw_info{};
|
||||
TileSchedulerParams scheduler{};
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
// Convert to underlying arguments. In this case, a simple copy for the
|
||||
// aliased type.
|
||||
static Params to_underlying_arguments(Arguments const &args,
|
||||
void *workspace) {
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
||||
|
||||
(void)workspace;
|
||||
auto problem_shape = args.problem_shape;
|
||||
// if constexpr (detail::IF_SWAP_AB<CollectiveMainloop>::value) {
|
||||
// // swap M/N
|
||||
// get<0>(problem_shape) = get<1>(args.problem_shape);
|
||||
// get<1>(problem_shape) = get<0>(args.problem_shape);
|
||||
// }
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
if (sm_count <= 0) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments "
|
||||
"KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
args.hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
"to_underlying_arguments(): Setting persistent grid SM count to "
|
||||
<< sm_count);
|
||||
KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
||||
|
||||
// Calculate workspace pointers
|
||||
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
void *scheduler_workspace = workspace_ptr;
|
||||
workspace_offset +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void *epilogue_workspace = workspace_ptr + workspace_offset;
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(
|
||||
args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
|
||||
void *mainloop_workspace = nullptr;
|
||||
|
||||
return {args.mode,
|
||||
problem_shape,
|
||||
CollectiveMainloop::to_underlying_arguments(
|
||||
args.problem_shape, args.mainloop, mainloop_workspace),
|
||||
CollectiveEpilogue::to_underlying_arguments(
|
||||
args.problem_shape, args.epilogue, epilogue_workspace),
|
||||
hw_info,
|
||||
TileScheduler::to_underlying_arguments(
|
||||
problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info,
|
||||
args.scheduler, scheduler_workspace)};
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const &args) {
|
||||
bool implementable = (args.mode == GemmUniversalMode::kGemm) or
|
||||
(args.mode == GemmUniversalMode::kBatched &&
|
||||
cute::rank(ProblemShape{}) == 4);
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't "
|
||||
"meet the requirements.\n");
|
||||
return implementable;
|
||||
}
|
||||
implementable &=
|
||||
CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
|
||||
implementable &=
|
||||
CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
|
||||
implementable &= TileScheduler::can_implement(args.scheduler);
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const &args) {
|
||||
size_t workspace_size = 0;
|
||||
workspace_size +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape,
|
||||
args.epilogue);
|
||||
workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment);
|
||||
|
||||
return workspace_size;
|
||||
}
|
||||
|
||||
static cutlass::Status
|
||||
initialize_workspace(Arguments const &args, void *workspace = nullptr,
|
||||
cudaStream_t stream = nullptr,
|
||||
CudaHostAdapter *cuda_adapter = nullptr) {
|
||||
Status status = Status::kSuccess;
|
||||
uint8_t *workspace_ptr = reinterpret_cast<uint8_t *>(workspace);
|
||||
size_t workspace_offset = 0;
|
||||
|
||||
status = TileScheduler::template initialize_workspace<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, workspace_ptr + workspace_offset, stream,
|
||||
args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset +=
|
||||
TileScheduler::template get_workspace_size<ProblemShape,
|
||||
ElementAccumulator>(
|
||||
args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
status = CollectiveEpilogue::initialize_workspace(
|
||||
args.problem_shape, args.epilogue, workspace_ptr + workspace_offset,
|
||||
stream, cuda_adapter);
|
||||
workspace_offset += CollectiveEpilogue::get_workspace_size(
|
||||
args.problem_shape, args.epilogue);
|
||||
workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
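// Illustrative host-side flow tying the static API above together (a sketch,
// not part of this file; `GemmKernel`, `args` and `stream` are hypothetical):
//
//   size_t ws_bytes = GemmKernel::get_workspace_size(args);
//   void *workspace = nullptr;
//   cudaMalloc(&workspace, ws_bytes);
//   GemmKernel::initialize_workspace(args, workspace, stream);
//   auto params = GemmKernel::to_underlying_arguments(args, workspace);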
|
||||
|
||||
// Computes the kernel launch grid shape based on runtime parameters
|
||||
static dim3 get_grid_shape(Params const ¶ms) {
|
||||
// Given device SM count, set grid size s.t. we do not launch more thread
|
||||
// blocks than we can run concurrently
|
||||
TileSchedulerArguments args{};
|
||||
if constexpr (!std::is_const_v<decltype(args.max_swizzle_size)>) {
|
||||
args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_;
|
||||
}
|
||||
args.raster_order =
|
||||
params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN
|
||||
? TileScheduler::RasterOrderOptions::AlongN
|
||||
: TileScheduler::RasterOrderOptions::AlongM;
|
||||
return TileScheduler::get_grid_shape(params.scheduler, params.problem_shape,
|
||||
TileShape{}, ClusterShape{},
|
||||
params.hw_info, args);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); }
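// Illustrative launch-shape usage (a sketch; `params` is a hypothetical,
// fully-initialized Params object):
//
//   dim3 grid  = GemmKernel::get_grid_shape(params);   // persistent grid, bounded by SM count
//   dim3 block = GemmKernel::get_block_shape();        // MaxThreadsPerBlock x 1 x 1
//
// Together with SharedStorageSize these define the launch configuration; on
// sm90 the ClusterShape additionally determines the cluster dimensions.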
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const ¶ms, char *smem_buf) {
|
||||
using namespace cute;
|
||||
using X = Underscore;
|
||||
|
||||
// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a.
|
||||
#if !defined(__CUDA_ARCH_FEAT_SM90_ALL)
|
||||
printf("ERROR : Arch conditional MMA instruction used without targeting "
|
||||
"sm90a compute capability. Aborting.\n");
|
||||
#else
|
||||
|
||||
// Preconditions
|
||||
static_assert(cute::rank(StrideA{}) == 3,
|
||||
"StrideA must be rank-3: [M, K, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideB{}) == 3,
|
||||
"StrideB must be rank-3: [N, K, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideC{}) == 3,
|
||||
"StrideC must be rank-3: [M, N, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
static_assert(cute::rank(StrideD{}) == 3,
|
||||
"StrideD must be rank-3: [M, N, L]. If batch mode is not "
|
||||
"needed, set L stride to Int<0>.");
|
||||
|
||||
enum class WarpGroupRole { Producer = 0, Consumer0 = 1, Consumer1 = 2 };
|
||||
enum class ProducerWarpRole {
|
||||
Mainloop = 0,
|
||||
Warp1 = 1,
|
||||
Epilogue = 2,
|
||||
Warp3 = 3
|
||||
};
|
||||
|
||||
// Kernel level shared memory storage
|
||||
SharedStorage &shared_storage =
|
||||
*reinterpret_cast<SharedStorage *>(smem_buf);
|
||||
|
||||
int thread_idx = int(threadIdx.x);
|
||||
int lane_idx = canonical_lane_idx();
|
||||
int warp_idx = canonical_warp_idx_sync();
|
||||
int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup;
|
||||
int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup;
|
||||
auto warp_group_role = WarpGroupRole(canonical_warp_group_idx());
|
||||
auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group);
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
|
||||
|
||||
// Issue Tma Descriptor Prefetch from a single thread
|
||||
if ((warp_idx == 0) && lane_predicate) {
|
||||
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
||||
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
||||
}
|
||||
|
||||
// Mainloop Load pipeline
|
||||
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
||||
typename MainloopPipeline::Params mainloop_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer &&
|
||||
producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
mainloop_pipeline_params.role =
|
||||
MainloopPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
mainloop_pipeline_params.role =
|
||||
MainloopPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0;
|
||||
mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup;
|
||||
mainloop_pipeline_params.transaction_bytes =
|
||||
CollectiveMainloop::TmaTransactionBytes;
|
||||
MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop,
|
||||
mainloop_pipeline_params,
|
||||
ClusterShape{});
|
||||
|
||||
// Epilogue Load pipeline
|
||||
using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline;
|
||||
typename EpiLoadPipeline::Params epi_load_pipeline_params;
|
||||
if (warp_group_role == WarpGroupRole::Producer &&
|
||||
producer_warp_role == ProducerWarpRole::Epilogue) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer;
|
||||
}
|
||||
if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer;
|
||||
}
|
||||
epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster();
|
||||
epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp;
|
||||
epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup;
|
||||
epi_load_pipeline_params.transaction_bytes =
|
||||
CollectiveEpilogue::TmaTransactionBytes;
|
||||
EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load,
|
||||
epi_load_pipeline_params);
|
||||
|
||||
// Epilogue Store pipeline
|
||||
using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline;
|
||||
typename EpiStorePipeline::Params epi_store_pipeline_params;
|
||||
epi_store_pipeline_params.always_wait = true;
|
||||
EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params);
|
||||
|
||||
typename LoadWarpOrderBarrier::Params params_load_order_barrier;
|
||||
params_load_order_barrier.group_id =
|
||||
producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1;
|
||||
params_load_order_barrier.group_size = NumThreadsPerWarp;
|
||||
LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order,
|
||||
params_load_order_barrier);
|
||||
|
||||
typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier;
|
||||
// DMA Load WG will not participate in these Ordered Barrier syncs
|
||||
params_math_wg_order_barrier.group_id =
|
||||
canonical_warp_group_idx() - static_cast<int>(WarpGroupRole::Consumer0);
|
||||
params_math_wg_order_barrier.group_size =
|
||||
NumThreadsPerWarpGroup; // Number of threads / participants in a group
|
||||
MathWarpGroupOrderBarrier math_wg_order_barrier(
|
||||
shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier);
|
||||
|
||||
// Initialize starting pipeline states for the collectives
|
||||
// Epilogue store pipe is producer-only (consumer is TMA unit, waits via
|
||||
// scoreboarding)
|
||||
typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state;
|
||||
typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state;
|
||||
|
||||
// For the DMA Load (producer) we start with an opposite phase
|
||||
// i.e., we skip all waits since we know that the buffer is indeed empty
|
||||
PipelineState mainloop_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<MainloopPipeline>();
|
||||
PipelineState epi_load_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<EpiLoadPipeline>();
|
||||
PipelineState epi_store_pipe_producer_state =
|
||||
cutlass::make_producer_start_state<EpiStorePipeline>();
|
||||
|
||||
auto cluster_wait_fn = [&]() {
|
||||
// We need this to guarantee that the pipeline init is visible
// to all producer and consumer thread blocks in the cluster
|
||||
if constexpr (size(ClusterShape{}) > 1) {
|
||||
cute::cluster_arrive_relaxed();
|
||||
return []() { cute::cluster_wait(); };
|
||||
} else {
|
||||
__syncthreads();
|
||||
return []() {}; // do nothing
|
||||
}
|
||||
}();
|
||||
|
||||
// Separate out problem shape for convenience
|
||||
// Optionally append 1s until problem shape is rank-4 in case it is only
|
||||
// rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
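// e.g. a rank-3 (M,N,K) problem such as (4096, 4096, 4096) becomes
// (4096, 4096, 4096, 1); a rank-4 batched problem shape is left unchanged.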
|
||||
|
||||
// Get the appropriate blocks for this thread block -- potential for thread
|
||||
// block locality
|
||||
TiledMma tiled_mma;
|
||||
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
||||
|
||||
// In a warp specialized kernel, collectives expose data movement and
|
||||
// compute operations separately
|
||||
CollectiveMainloop collective_mainloop;
|
||||
CollectiveEpilogue collective_epilogue(params.epilogue,
|
||||
shared_storage.tensors.epilogue);
|
||||
|
||||
// Prepare and partition the input tensors. Expects a tuple of tensors where:
//   get<0>(load_inputs) is the TMA tensor A after local tiling, with shape (BLK_M,BLK_K,m,k,l)
//   get<1>(load_inputs) is the TMA tensor B after local tiling, with shape (BLK_N,BLK_K,n,k,l)
|
||||
auto load_inputs =
|
||||
collective_mainloop.load_init(problem_shape_MNKL, params.mainloop);
|
||||
static_assert(
|
||||
cute::tuple_size_v<decltype(load_inputs)> >= 3,
|
||||
"Output of load_init must have at least three elements (A, B, Aux)");
|
||||
|
||||
// Extract out partitioned A and B.
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
Tensor gAux_xkl = get<2>(load_inputs);
|
||||
|
||||
// Get pipeline stage increments from tensor shapes
|
||||
auto k_tile_count = size<3>(gA_mkl);
|
||||
auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape);
|
||||
auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape);
|
||||
|
||||
TileScheduler scheduler{params.scheduler};
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Consumer1) {
|
||||
// Advance 2nd Math WG to the next work tile for the startup
|
||||
scheduler.advance_to_next_work();
|
||||
// Advance 2nd Math WG pipeline states to the end of 1st Math WG
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count);
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
}
|
||||
auto work_tile_info = scheduler.get_current_work();
|
||||
|
||||
// Wait for all thread blocks in the Cluster
|
||||
cluster_wait_fn();
|
||||
|
||||
if (warp_group_role == WarpGroupRole::Producer) {
|
||||
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
||||
|
||||
// Mainloop Producer Warp
|
||||
if (producer_warp_role == ProducerWarpRole::Mainloop) {
|
||||
bool do_load_order_arrive = true;
|
||||
while (work_tile_info.is_valid()) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl));
|
||||
|
||||
collective_mainloop.load(
|
||||
params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state,
|
||||
load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx,
|
||||
block_rank_in_cluster, shared_storage.tensors.mainloop);
|
||||
// Update starting pipeline state for the next tile
|
||||
mainloop_pipe_producer_state.advance(k_tile_count);
|
||||
|
||||
// Signal for the epilogue load warp to begin
|
||||
if (do_load_order_arrive) {
|
||||
load_order_barrier.arrive();
|
||||
do_load_order_arrive = false;
|
||||
}
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_mainloop.load_tail(mainloop_pipeline,
|
||||
mainloop_pipe_producer_state);
|
||||
} // Mainloop Producer Warp End
|
||||
|
||||
// Epilogue Producer Warp
|
||||
else if (producer_warp_role == ProducerWarpRole::Epilogue &&
|
||||
collective_epilogue.is_producer_load_needed()) {
|
||||
load_order_barrier.wait();
|
||||
while (work_tile_info.is_valid()) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
epi_load_pipe_producer_state = collective_epilogue.load(
|
||||
epi_load_pipeline, epi_load_pipe_producer_state,
|
||||
problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx,
|
||||
shared_storage.tensors.epilogue);
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work();
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
|
||||
// Make sure all Consumer Warp Groups have been waited upon
|
||||
collective_epilogue.load_tail(epi_load_pipeline,
|
||||
epi_load_pipe_producer_state);
|
||||
} // Epilogue Producer Warp End
|
||||
} // Producer Warp Group End
|
||||
|
||||
else if (warp_group_role == WarpGroupRole::Consumer0 ||
|
||||
warp_group_role == WarpGroupRole::Consumer1) {
|
||||
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
||||
|
||||
float scale_d0 = params.mainloop.scale_d0;
|
||||
float scale_d1 = params.mainloop.scale_d1;
|
||||
while (work_tile_info.is_valid()) {
|
||||
// Compute m_coord, n_coord, l_coord with the post-tiled m-shape and
|
||||
// n-shape
|
||||
auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl));
|
||||
auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl));
|
||||
auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl));
|
||||
auto blk_coord = make_coord(m_coord, n_coord, _, l_coord);
|
||||
|
||||
// Allocate the accumulators for the (M,N) blk_shape
|
||||
Tensor accumulators0 = partition_fragment_C(
|
||||
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
Tensor accumulators1 = partition_fragment_C(
|
||||
tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
||||
|
||||
// Order two Math WG's MMA one after the other, helps hide Epilogue
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
collective_mainloop.mma(
|
||||
mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0,
|
||||
accumulators1, k_tile_count, warp_group_thread_idx,
|
||||
shared_storage.tensors.mainloop, params.mainloop);
|
||||
|
||||
// Cue for next Math WG's MMA to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Make sure the math instructions are done and free buffers before
|
||||
// entering the epilogue
|
||||
collective_mainloop.mma_tail(
|
||||
mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count);
|
||||
// Update starting mainloop pipeline state for the next tile
|
||||
mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups);
|
||||
|
||||
Activation elt_op;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accumulators0); i++) {
|
||||
accumulators0[i] = elt_op(accumulators0[i] * scale_d0) *
|
||||
(scale_d1 * accumulators1[i]);
|
||||
}
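// Element-wise, the fusion above computes (reference formula, with Activation
// and the two scales taken from the mainloop params):
//
//   acc0[i] = Activation(scale_d0 * acc0[i]) * (scale_d1 * acc1[i])
//
// i.e. a gated product of the two GEMM outputs; with a SiLU-style Activation
// this corresponds to the SwiGLU-like gating commonly used in FFN layers.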
|
||||
|
||||
// Order two Math WG's Epilogue one after the other
|
||||
math_wg_order_barrier.wait();
|
||||
|
||||
// Epilogue and write to gD
|
||||
auto [epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipe_producer_state_next] =
|
||||
collective_epilogue.store(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state,
|
||||
problem_shape_MNKL, blk_shape, blk_coord, accumulators0,
|
||||
tiled_mma, warp_group_thread_idx,
|
||||
shared_storage.tensors.epilogue);
|
||||
|
||||
// The TMA store pipeline wait is only visible to the TMA-issuing warp, so for
// multi-consumer kernels we must wait for all TMA stores to complete before
// arriving on the consumer order barrier; otherwise the next math consumer
// could overwrite smem still referenced by this consumer's in-flight TMA
// stores.
|
||||
auto [epi_load_pipe_consumer_state_next_,
|
||||
epi_store_pipe_producer_state_next_] =
|
||||
collective_epilogue.store_tail(
|
||||
epi_load_pipeline, epi_load_pipe_consumer_state_next,
|
||||
epi_store_pipeline, epi_store_pipe_producer_state_next);
|
||||
|
||||
// Update starting load/store pipeline states for the next tile
|
||||
// state has already been incremented by 1 tile in collective calls,
|
||||
// advance once again for ping pong
|
||||
epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_;
|
||||
epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_;
|
||||
epi_load_pipe_consumer_state.advance(c_tile_count);
|
||||
epi_store_pipe_producer_state.advance(d_tile_count);
|
||||
|
||||
// Cue for next Math WG's Epilogue to start
|
||||
math_wg_order_barrier.arrive();
|
||||
|
||||
// Get next work tile
|
||||
scheduler.advance_to_next_work(NumMmaWarpGroups);
|
||||
work_tile_info = scheduler.get_current_work();
|
||||
} // Scheduler work fetch loop
|
||||
} // Consumer Warp Groups End
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::kernel
|
||||
@@ -77,6 +77,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
@@ -125,6 +126,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
@@ -148,7 +150,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
@@ -179,6 +181,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
@@ -234,6 +237,7 @@ public:
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
|
||||
/// (stage>=3)
|
||||
template <
|
||||
@@ -346,6 +350,131 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, El
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int2 weight, mma multistage
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
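// For example, with half_t (16 bits per element) and kAlignmentA == 8, the
// access width is 16 * 8 = 128 bits, so CacheOperation::Global is selected
// (mapping to the 16-byte cp.async.cg path that bypasses L1); narrower
// accesses fall back to CacheOperation::Always (cp.async.ca).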
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
|
||||
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
|
||||
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -19,13 +19,11 @@
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
|
||||
|
||||
namespace cutlass {
namespace gemm {
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -197,6 +195,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
@@ -244,6 +243,9 @@ public:
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
@@ -265,7 +267,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
@@ -296,6 +298,7 @@ public:
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
@@ -318,11 +321,11 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
///
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
@@ -348,6 +351,131 @@ public:
|
||||
using ThreadblockMma = typename Mma::ThreadblockMma;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int2 weight, mma multistage
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator>
|
||||
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
|
||||
{
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
|
||||
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
|
||||
};
|
||||
|
||||
template <
|
||||
/// Layout type for A matrix operand
|
||||
typename LayoutA,
|
||||
/// Access granularity of A matrix in units of elements
|
||||
int kAlignmentA,
|
||||
/// Layout type for B matrix operand
|
||||
typename LayoutB,
|
||||
/// Access granularity of B matrix in units of elements
|
||||
int kAlignmentB,
|
||||
/// Element type for internal accumulation
|
||||
typename ElementAccumulator,
|
||||
/// Tag indicating architecture to tune for
|
||||
typename ArchTag,
|
||||
/// Threadblock-level tile size (concept: GemmShape)
|
||||
typename ThreadblockShape,
|
||||
/// Warp-level tile size (concept: GemmShape)
|
||||
typename WarpShape,
|
||||
/// Instruction-level tile size (concept: GemmShape)
|
||||
typename InstructionShape,
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int kStages,
|
||||
/// Shared memory clear option
|
||||
SharedMemoryClearOption SharedMemoryClear>
|
||||
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
|
||||
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
|
||||
false, SharedMemoryClear>
|
||||
{
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpA =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const CacheOpB =
|
||||
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
|
||||
: cutlass::arch::CacheOperation::Always;
|
||||
|
||||
// Define the MmaCore components
|
||||
using MmaCore =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
|
||||
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
|
||||
false, CacheOpA, CacheOpB>;
|
||||
|
||||
// Define iterators over tiles from the A operand
|
||||
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
|
||||
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
|
||||
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
|
||||
AccessTypeA>;
|
||||
|
||||
// Define iterators over tiles from the B operand
|
||||
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
|
||||
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
|
||||
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
|
||||
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
|
||||
AccessTypeB>;
|
||||
|
||||
// Define the threadblock-scoped multistage matrix multiply
|
||||
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
|
||||
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
|
||||
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
|
||||
#include "cutlass/gemm/threadblock/mma_base.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class Wint2xMmaBase {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Shape describing the overall GEMM computed from shared memory
|
||||
/// by each warp.
|
||||
using WarpGemm = typename Policy::Operator::Shape;
|
||||
|
||||
/// Shape describing the number of warps filling the CTA
|
||||
using WarpCount =
|
||||
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
|
||||
Shape::kK / WarpGemm::kK>;
|
||||
|
||||
/// Number of warp-level GEMM operations
|
||||
static int const kWarpGemmIterations =
|
||||
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA =
|
||||
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB =
|
||||
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
|
||||
|
||||
// using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;
|
||||
|
||||
static_assert(kWarpGemmIterations > 1,
|
||||
"The pipelined structure requires at least two warp-level "
|
||||
"GEMM operations.");
|
||||
|
||||
static_assert((kWarpGemmIterations % 2) == 0,
|
||||
"Inner loop iteration must be an even number.");
|
||||
|
||||
//
|
||||
// Nested structs
|
||||
//
|
||||
|
||||
/// Shared storage object needed by threadblock-scoped GEMM
|
||||
class SharedStorage {
|
||||
public:
|
||||
//
|
||||
// Type definitions
|
||||
//
|
||||
|
||||
/// Shape of the A matrix operand in shared memory
|
||||
using ShapeA =
|
||||
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
|
||||
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
|
||||
|
||||
/// Shape of the B matrix operand in shared memory
|
||||
using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
|
||||
Shape::kN + Policy::SmemPaddingB::kColumn>;
|
||||
|
||||
// w uint8; local_scale uint8;
|
||||
constexpr static int kZippedRowsPerStages =
|
||||
Shape::kK / 4 + (Shape::kK + 127) / 128;
|
||||
|
||||
// code_scale float; code_zp float; super_scale ElementB
|
||||
constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
|
||||
sizeof_bits<typename Operator::ElementB>::value / 8;
|
||||
|
||||
using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
|
||||
|
||||
using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
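// Worked example (hypothetical configuration): with Shape::kK = 64,
// kStages = 3 and ElementB = half_t:
//   kZippedRowsPerStages  = 64 / 4 + (64 + 127) / 128  = 16 + 1 = 17
//   kColumnWiseParamsRows = 2 * sizeof(float) + 16 / 8 = 8 + 2  = 10
//   ZippedShapeB          = MatrixShape<10 + 17 * 3, kN> = MatrixShape<61, kN>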
|
||||
|
||||
public:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Buffer for A operand
|
||||
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
|
||||
|
||||
/// Buffer for B operand
|
||||
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
|
||||
|
||||
/// Buffer for quantized B operand
|
||||
AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
|
||||
|
||||
/// Buffer for unzipped B operand
|
||||
AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
|
||||
operand_unzip_B;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Returns a layout object for the A matrix
|
||||
CUTLASS_DEVICE
|
||||
static typename Operator::LayoutA LayoutA() {
|
||||
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a layout object for the B matrix
|
||||
CUTLASS_HOST_DEVICE
|
||||
static typename Operator::LayoutB LayoutB() {
|
||||
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the A operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefA operand_A_ref() {
|
||||
return TensorRefA{operand_A.data(), LayoutA()};
|
||||
}
|
||||
|
||||
/// Returns a TensorRef to the B operand
|
||||
CUTLASS_HOST_DEVICE
|
||||
TensorRefB operand_B_ref() {
|
||||
return TensorRefB{operand_B.data(), LayoutB()};
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
typename Operator::ElementB *operand_unzip_B_ptr() {
|
||||
return operand_unzip_B.data();
|
||||
}
|
||||
};
|
||||
|
||||
protected:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Iterator to load a warp-scoped tile of A operand from shared memory
|
||||
typename Operator::IteratorA warp_tile_iterator_A_;
|
||||
|
||||
/// Iterator to load a warp-scoped tile of B operand from shared memory
|
||||
typename Operator::IteratorB warp_tile_iterator_B_;
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
Wint2xMmaBase(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
|
||||
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,807 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass_extensions/arch/memory_copy_sm80.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
|
||||
/// instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Cache operation for operand B
|
||||
cutlass::arch::CacheOperation::Kind CacheOpB,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Data type of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class Wint2xMmaMultistage :
|
||||
public Wint2xMmaBase<Shape_, Policy_, Stages> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
///< Iterates over tiles of A operand in global memory
|
||||
using IteratorA = IteratorA_;
|
||||
///< Iterates over tiles of B operand in global memory
|
||||
using IteratorB = IteratorB_;
|
||||
///< Data type of accumulator matrix
|
||||
using ElementC = ElementC_;
|
||||
///< Layout of accumulator matrix
|
||||
using LayoutC = LayoutC_;
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
|
||||
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
/// Minimum architecture is Sm80 to support cp.async
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = Operator::kTransformA;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = Operator::kTransformB;
|
||||
|
||||
/// Internal structure exposed for introspection.
|
||||
struct Detail {
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand A
|
||||
static int const AsyncCopyIterationsPerStageA =
|
||||
IteratorA::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of cp.async instructions to load one stage of operand B
|
||||
static int const AsyncCopyIterationsPerStageB =
|
||||
IteratorB::ThreadMap::Iterations::kCount;
|
||||
|
||||
/// Number of stages
|
||||
static int const kStages = Stages;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand A
|
||||
static int const kAccessesPerGroupA =
|
||||
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
/// Number of cp.async instructions to load on group of operand B
|
||||
static int const kAccessesPerGroupB =
|
||||
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
|
||||
|
||||
// Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical
|
||||
// accuracy, where each mainloop iteration first accumulates into a temporary
|
||||
// set of freshly-cleared accumulators, which are subsequently added to the
|
||||
// final accumulator set.
|
||||
static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation<Operator>::value;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
// Structure encapsulating pipeline state live from one iteration to the next
|
||||
struct PipeState {
|
||||
|
||||
using WarpLoadedFragmentA = typename Operator::FragmentA;
|
||||
using WarpLoadedFragmentB = typename Operator::FragmentB;
|
||||
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
|
||||
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
|
||||
|
||||
/// Temporary accumulator to facilitate staged-accumulation
|
||||
FragmentC tmp_accum_;
|
||||
|
||||
/// Pair of A fragments used to overlap shared memory loads and math instructions
|
||||
WarpLoadedFragmentA warp_loaded_frag_A_[2];
|
||||
WarpTransformedFragmentA warp_transformed_frag_A_[2];
|
||||
|
||||
/// Pair of B fragments used to overlap shared memory loads and math instructions
|
||||
WarpLoadedFragmentB warp_loaded_frag_B_[2];
|
||||
WarpTransformedFragmentB warp_transformed_frag_B_[2];
|
||||
};
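// Note: the two-element fragment arrays above implement software double
// buffering. While fragment [k % 2] feeds the current warp-level MMA, the
// mainloop below loads and transforms fragment [(k + 1) % 2] for the next
// warp-tile.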
|
||||
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Warp-level MMA operator
|
||||
Operator warp_mma_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA smem_iterator_A_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB smem_iterator_B_;
|
||||
|
||||
/// Shared memory write stage index
|
||||
int smem_write_stage_idx_;
|
||||
|
||||
/// Shared memory read stage index
|
||||
int smem_read_stage_idx_;
|
||||
|
||||
uint8_t* column_wise_smem_ptr_B_;
|
||||
|
||||
uint8_t* smem_zipped_ptr_B_;
|
||||
int smem_zipped_bytes_per_stage_B_;
|
||||
|
||||
public:
|
||||
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
Wint2xMmaMultistage(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage &shared_storage,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx
|
||||
):
|
||||
Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
|
||||
smem_write_stage_idx_(0),
|
||||
smem_read_stage_idx_(0)
|
||||
{
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
// _n: the warp's position within the threadblock along the N dimension
|
||||
// _k: the warp's position within the threadblock along the K dimension
|
||||
|
||||
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
|
||||
|
||||
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
|
||||
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
|
||||
|
||||
// Add per-warp offsets in units of warp-level tiles
|
||||
this->warp_tile_iterator_A_.add_tile_offset(
|
||||
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
|
||||
this->warp_tile_iterator_B_.add_tile_offset(
|
||||
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
|
||||
|
||||
column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();
|
||||
|
||||
smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
|
||||
smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
|
||||
}
|
||||
|
||||
/// Advance shared memory read-iterators to the next stage
|
||||
CUTLASS_DEVICE
|
||||
void advance_smem_read_stage()
|
||||
{
|
||||
++smem_read_stage_idx_;
|
||||
|
||||
if (smem_read_stage_idx_ == Base::kStages) {
|
||||
// Wrap back around to the 'start' of the circular buffer in shared memory
|
||||
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
|
||||
// this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
|
||||
smem_read_stage_idx_ = 0;
|
||||
}
|
||||
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
|
||||
}
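// Note: unlike operand A, the B read iterator is stepped back by one stage on
// every call, and its multi-stage wrap is commented out above. This appears
// intentional: dequantized B is rewritten into the same shared-memory tile
// each stage by the TileDequanter path, so the warp iterator for B does not
// advance through a circular buffer the way A does.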
|
||||
|
||||
/// Advance global memory read-iterators and shared memory write-iterators to the stage
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void advance_smem_write_stage(
|
||||
IteratorA &iterator_A,
|
||||
IteratorB &iterator_B,
|
||||
TileDequanterB &tile_dequanter_B)
|
||||
{
|
||||
// Advance global iterators
|
||||
iterator_A.add_tile_offset({0, 1});
|
||||
//iterator_B.add_tile_offset({1, 0});
|
||||
tile_dequanter_B.AddTileOffset({1, 0});
|
||||
|
||||
// Advance shared iterators
|
||||
smem_iterator_A_.add_tile_offset({0, 1});
|
||||
//smem_iterator_B_.add_tile_offset({1, 0});
|
||||
|
||||
// Increment shared memory write stage index
|
||||
++smem_write_stage_idx_;
|
||||
|
||||
if (smem_write_stage_idx_ == Base::kStages) {
|
||||
// Wrap back around to the 'start' of the circular buffer in shared memory
|
||||
smem_iterator_A_.add_tile_offset({0, -Base::kStages});
|
||||
//smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
|
||||
smem_write_stage_idx_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_A(IteratorA &iterator_A, int group_start_A = 0) {
|
||||
iterator_A.set_iteration_index(group_start_A *
|
||||
IteratorA::kAccessesPerVector);
|
||||
this->smem_iterator_A_.set_iteration_index(group_start_A);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
} else {
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool GlobalToSharedB>
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
|
||||
iterator_B.set_iteration_index(group_start_B *
|
||||
IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
} else {
|
||||
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) {
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
this->smem_iterator_A_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorA::Element>::value *
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool GlobalToSharedB, bool InitStage>
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
this->smem_iterator_B_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorB::Element>::value *
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
if (InitStage) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
} else {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
} else {
|
||||
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
|
||||
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void prologue(
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
TileDequanterB &tile_dequanter_B,
|
||||
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
{
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
|
||||
|
||||
// Disable global fetching if done with global fetch iterations
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
// Async copy a stage of operand A to shared memory.
|
||||
copy_tiles_and_advance_per_stage_A(iterator_A);
|
||||
|
||||
// Async copy zipped B to shared memory.
|
||||
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, stage);
|
||||
|
||||
// Move to the next write stage
|
||||
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
}
|
||||
|
||||
// Optionally clear the remaining stages of SMEM. This is a functional requirement for
|
||||
// some kernels so that all accumulator elements outside the GEMM footprint are zero.
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of A operand to shared memory
|
||||
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
|
||||
typename IteratorA::AccessType zero_A;
|
||||
|
||||
zero_A.clear();
|
||||
last_smem_iterator_A.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand A
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
|
||||
|
||||
typename IteratorA::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorA::AccessType *>(
|
||||
last_smem_iterator_A.get());
|
||||
|
||||
*dst_ptr = zero_A;
|
||||
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
||||
typename IteratorB::AccessType zero_B;
|
||||
|
||||
zero_B.clear();
|
||||
last_smem_iterator_B.set_iteration_index(0);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
|
||||
|
||||
typename IteratorB::AccessType *dst_ptr =
|
||||
reinterpret_cast<typename IteratorB::AccessType *>(
|
||||
last_smem_iterator_B.get());
|
||||
|
||||
*dst_ptr = zero_B;
|
||||
|
||||
++last_smem_iterator_B;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait until we have at least one completed global fetch stage
|
||||
CUTLASS_DEVICE
|
||||
void gmem_wait()
|
||||
{
|
||||
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
|
||||
cutlass::arch::cp_async_wait<Base::kStages - 2>();
|
||||
__syncthreads();
|
||||
}
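// Worked example (assuming Base::kStages = 3): cp_async_wait<1> blocks until
// at most one committed cp.async stage group is still in flight, i.e. the
// oldest prefetched stage has fully landed in shared memory before the
// threadblock synchronizes and begins consuming it.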
|
||||
|
||||
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void mac_loop_iter(
|
||||
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
|
||||
FragmentC &accum, ///< [in|out] destination accumulator tile
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
|
||||
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
int stage)
|
||||
{
|
||||
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
|
||||
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
|
||||
|
||||
// Load the next warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
// Unpack and dequantize the next pending stage of B.
|
||||
int unpack_stage = stage - Base::kStages + 2;
|
||||
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, unpack_stage);
|
||||
|
||||
// Copy dequantized data to the shared memory used by the MMA core.
|
||||
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
|
||||
}
|
||||
|
||||
// Load the next warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
|
||||
if (warp_mma_k > 0) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
|
||||
}
|
||||
|
||||
// Execute the current warp-tile of MMA operations
|
||||
if (Detail::kStagedAccumulation) {
|
||||
warp_mma_(
|
||||
pipe_state.tmp_accum_,
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.tmp_accum_
|
||||
);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
plus<FragmentC> plus_accum;
|
||||
accum = plus_accum(accum, pipe_state.tmp_accum_);
|
||||
pipe_state.tmp_accum_.clear();
|
||||
}
|
||||
} else {
|
||||
warp_mma_(
|
||||
accum,
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
}
|
||||
|
||||
// Except for the last warp-tile, all warp-tiles issue their share of
|
||||
// global->shared fragment copies
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
|
||||
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, stage);
|
||||
}
|
||||
}
|
||||
|
||||
// The second-to-last warp-tile also:
|
||||
// - performs the last warp-tile's share of global->shared fragment copies
|
||||
// - moves to the next global fetch stage
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
// Performs the last warp-tile's share of global->shared fragment copies
|
||||
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Wait until we have at least one completed global fetch stage
|
||||
gmem_wait();
|
||||
|
||||
// Move to the next global fetch stage
|
||||
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
|
||||
advance_smem_read_stage();
|
||||
|
||||
// Disable global fetching when done with global fetch iterations
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
|
||||
}
|
||||
|
||||
// The last warp-tile also converts the shared memory fragments used by
|
||||
// the first warp-tile of the next iteration, if necessary (so we can
|
||||
// immediately start issuing MMA instructions at the top of the loop )
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the specified number of threadblock mainloop iterations of matrix
|
||||
/// multiply-accumulate. Assumes prologue has been initiated.
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void gemm_iters(
|
||||
int gemm_k_iterations, ///< number of threadblock mainloop iterations
|
||||
FragmentC &accum, ///< [in|out] accumulator tile
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B,
|
||||
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
|
||||
{
|
||||
PipeState pipe_state;
|
||||
|
||||
// Unpack and dequantize the first stage of B.
|
||||
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
|
||||
|
||||
// Disable global fetching if done with global fetch iterations
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
|
||||
|
||||
// Load first warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
// Copy dequantized data to the shared memory used by the MMA core.
|
||||
copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
|
||||
|
||||
// Load first warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Transform, if necessary, the first warp-tile's shared memory fragments
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[0],
|
||||
pipe_state.warp_transformed_frag_B_[0],
|
||||
pipe_state.warp_loaded_frag_A_[0],
|
||||
pipe_state.warp_loaded_frag_B_[0]);
|
||||
|
||||
if (Detail::kStagedAccumulation) {
|
||||
pipe_state.tmp_accum_.clear();
|
||||
}
|
||||
|
||||
int stage = Base::kStages - 1;
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_GEMM_LOOP
|
||||
for (; gemm_k_iterations > (-Base::kStages + 1);) {
|
||||
mac_loop_iter(
|
||||
pipe_state,
|
||||
accum,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
tile_dequanter_B,
|
||||
gemm_k_iterations,
|
||||
stage);
|
||||
stage += 1;
|
||||
}
|
||||
|
||||
if (Detail::kStagedAccumulation) {
|
||||
plus<FragmentC> plus_accum;
|
||||
accum = plus_accum(accum, pipe_state.tmp_accum_);
|
||||
}
|
||||
|
||||
// Commit and drain all pending and predicated cp.async instructions from the GEMM mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/// Prepares the class for another prologue.
|
||||
CUTLASS_DEVICE
|
||||
void wind_down()
|
||||
{
|
||||
// Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue)
|
||||
|
||||
// First, increment remaining warp tiles to get to the next full stage. (Ideally we would
|
||||
// just decrement one tile, but not all iterators implement --() decrement.)
|
||||
#pragma unroll
|
||||
for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
|
||||
{
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k);
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k);
|
||||
|
||||
++this->warp_tile_iterator_A_;
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
smem_read_stage_idx_++;
|
||||
|
||||
// Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators)
|
||||
static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations;
|
||||
if (smem_read_stage_idx_ > 1)
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
|
||||
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
|
||||
}
|
||||
smem_read_stage_idx_ = smem_write_stage_idx_;
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< number of threadblock mainloop iterations
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC &accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< pre-load and dequantize B to shared memory
|
||||
TileDequanterB tile_dequanter_B,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum) {
|
||||
|
||||
// Prologue (start fetching iterations of global fragments into shared memory)
|
||||
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
|
||||
|
||||
// Wait until we have at least one completed global fetch stage
|
||||
gmem_wait();
|
||||
|
||||
// Initialize destination accumulators with source accumulators
|
||||
accum = src_accum;
|
||||
|
||||
// Perform the MAC-iterations
|
||||
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
|
||||
}
|
||||
};
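#if 0
// Minimal call sketch (illustrative only; `Mma` stands for a concrete
// Wint2xMmaMultistage instantiation, and the thread indices, iterators and
// tile_dequanter_B arguments are assumptions supplied by the surrounding
// kernel):
__shared__ typename Mma::SharedStorage shared_storage;
Mma mma(shared_storage, thread_idx, warp_idx, lane_idx);

typename Mma::FragmentC accum;
accum.clear();

// Runs the prologue, waits for the first stage, then executes the mainloop.
mma(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B, accum);
#endif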
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,130 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
|
||||
int Stages, int NumThreads, WintQuantMethod Method>
|
||||
struct TileDequanter {
|
||||
using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
|
||||
using MmaElementT = typename WeightQuantTraits::MmaWeightType;
|
||||
using QuantArguments = typename WeightQuantTraits::Arguments;
|
||||
|
||||
using UnzipAndDequantFunctor =
|
||||
UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;
|
||||
|
||||
static constexpr bool kUseSharedMemory = true;
|
||||
|
||||
static constexpr int kRows = Rows;
|
||||
static constexpr int kColumns = Columns;
|
||||
static constexpr int kStages = Stages;
|
||||
|
||||
MmaElementT *out_smem_ptr{nullptr};
|
||||
|
||||
char *pointer{nullptr};
|
||||
int64_t ldm{0};
|
||||
cutlass::MatrixCoord tb_offset;
|
||||
cutlass::MatrixCoord extent;
|
||||
|
||||
ScaleElementT *super_scale_ptr{nullptr};
|
||||
cutlass::MatrixCoord tb_offset_scale;
|
||||
|
||||
QuantArguments quant_args;
|
||||
|
||||
int64_t block_start_rows[kStages];
|
||||
bool need_preload{true};
|
||||
UnzipAndDequantFunctor unzip_functor;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
|
||||
const cutlass::MatrixCoord &extent,
|
||||
const cutlass::MatrixCoord &tb_offset,
|
||||
ScaleElementT *super_scale_ptr,
|
||||
const cutlass::MatrixCoord &tb_offset_scale,
|
||||
const QuantArguments &quant_args)
|
||||
: out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
|
||||
tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
|
||||
tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaElementT *GetOutPtr() { return out_smem_ptr; }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
|
||||
tb_offset.row() += tile_offset.row() * kRows;
|
||||
tb_offset.column() += tile_offset.column() * kColumns;
|
||||
tb_offset_scale.column() += tile_offset.column() * kColumns;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
|
||||
int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
|
||||
if (tb_offset.row() >= extent.row() ||
|
||||
tb_offset.column() >= extent.column()) {
|
||||
return;
|
||||
}
|
||||
|
||||
block_start_rows[stage % kStages] = tb_offset.row();
|
||||
|
||||
using ZippedT = typename WeightQuantTraits::WeightType;
|
||||
ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
|
||||
tb_offset.column();
|
||||
ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();
|
||||
|
||||
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
|
||||
const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
|
||||
(tb_offset.row() / 128) * ldm +
|
||||
tb_offset_scale.column();
|
||||
const float *code_scale_ptr =
|
||||
quant_args.code_scale_ptr + tb_offset_scale.column();
|
||||
const float *code_zp_ptr =
|
||||
quant_args.code_zp_ptr + tb_offset_scale.column();
|
||||
|
||||
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
|
||||
unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
|
||||
scale_ptr, &args, ldm, need_preload);
|
||||
need_preload = false;
|
||||
} else {
|
||||
// CUTLASS_TRACE_DEVICE("Not Supported!");
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
|
||||
int64_t block_start_row = block_start_rows[stage % kStages];
|
||||
if (block_start_row >= extent.row()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
|
||||
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
|
||||
unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
|
||||
} else {
|
||||
// CUTLASS_TRACE_DEVICE("Not Supported!");
|
||||
}
|
||||
}
|
||||
};
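#if 0
// Illustrative staging pattern (mirrors the Wint2xMmaMultistage mainloop
// above; `tile_dequanter`, `zipped_smem`, `column_wise_smem` and
// `zipped_bytes_per_stage` are placeholder names): each stage first copies
// the zipped tile and its scales into the circular shared-memory buffer, and
// the tile is unpacked only after the corresponding cp.async group has been
// waited on.
for (int stage = 0; stage < kStages - 1; ++stage) {
  tile_dequanter.Load(zipped_smem + (stage % kStages) * zipped_bytes_per_stage,
                      column_wise_smem, stage);
  cutlass::arch::cp_async_fence();
}
cutlass::arch::cp_async_wait<kStages - 2>();
__syncthreads();
tile_dequanter.UnpackAndDequant(zipped_smem, column_wise_smem, /*stage=*/0);
#endif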
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@@ -0,0 +1,447 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "cutlass/arch/memory.h"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass_extensions/wint_type_traits.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <typename T, int N>
|
||||
using UnzipArray = cutlass::AlignedArray<T, N, (N * cutlass::sizeof_bits<T>::value / 8)>;
|
||||
|
||||
template <typename T, WintQuantMethod QuantMethod, int TileRows,
|
||||
int TileColumns, int NumThreads = 128>
|
||||
struct UnzipAndDequantFunctor {
|
||||
__device__ void operator()(const T *in_ptr, const T *super_scale_ptr,
|
||||
T *out_ptr, const int64_t in_stride) {}
|
||||
};
|
||||
|
||||
template <typename T, int TileRows, int TileColumns, int NumThreads>
|
||||
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
|
||||
TileColumns, NumThreads> {
|
||||
using ZippedT = uint16_t;
|
||||
using ScaleComputeT = float;
|
||||
|
||||
static constexpr int32_t kGroupSize = 64;
|
||||
static constexpr int32_t kZippedGroupSize = 10;
|
||||
static constexpr int32_t kNumPackedValues = 7;
|
||||
|
||||
static constexpr int32_t kWeightMask = 0x7;
|
||||
static constexpr int32_t kLocalScaleMask = 0x1FFF;
|
||||
static constexpr int32_t kBZP = 4;
|
||||
|
||||
__device__ inline T Compute(int32_t zipped_value, int32_t shift_bit,
|
||||
ScaleComputeT scale) {
|
||||
int32_t shifted_value = (zipped_value >> shift_bit) & kWeightMask;
|
||||
int32_t value = shifted_value - kBZP;
|
||||
|
||||
ScaleComputeT scaled_value = static_cast<ScaleComputeT>(value) * scale;
|
||||
return static_cast<T>(scaled_value);
|
||||
}
|
||||
|
||||
__device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr,
|
||||
T *out_ptr, const int64_t in_stride) {
|
||||
int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0};
|
||||
|
||||
int tid = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int col = tid; col < TileColumns; col += NumThreads) {
|
||||
ScaleComputeT super_scale =
|
||||
static_cast<ScaleComputeT>(super_scale_ptr[col]);
|
||||
|
||||
#pragma unroll
|
||||
for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
|
||||
// The last packed row in the group also carries the group's local scale.
|
||||
int zipped_row_last = group_id * 10 + 9;
|
||||
int zipped_offset_last = zipped_row_last * in_stride + col;
|
||||
int32_t zipped_value_last =
|
||||
static_cast<int32_t>(in_ptr[zipped_offset_last]);
|
||||
|
||||
ScaleComputeT local_scale =
|
||||
static_cast<ScaleComputeT>(zipped_value_last & kLocalScaleMask);
|
||||
ScaleComputeT scale = local_scale * super_scale;
|
||||
|
||||
#pragma unroll
|
||||
for (int zipped_row_in_group = 0; zipped_row_in_group < 9;
|
||||
++zipped_row_in_group) {
|
||||
int zipped_row = group_id * 10 + zipped_row_in_group;
|
||||
int zipped_offset = zipped_row * in_stride + col;
|
||||
int32_t zipped_value = static_cast<int32_t>(in_ptr[zipped_offset]);
|
||||
|
||||
int row_in_group = group_id * 64 + zipped_row_in_group * 7;
|
||||
|
||||
#pragma unroll
|
||||
for (int shift_bit_id = 0; shift_bit_id < 7; ++shift_bit_id) {
|
||||
int32_t shift_bit = shift_bits[shift_bit_id];
|
||||
T value = Compute(zipped_value, shift_bit, scale);
|
||||
out_ptr[(row_in_group + shift_bit_id) * TileColumns + col] = value;
|
||||
}
|
||||
}
|
||||
|
||||
int row_in_group_last = group_id * 64 + 63;
|
||||
T value_last = Compute(zipped_value_last, shift_bits[0], scale);
|
||||
out_ptr[row_in_group_last * TileColumns + col] = value_last;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
};
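// Decoding summary for the 2.5-bit scheme above (restating the code): each
// 16-bit packed word holds 3-bit weights at the shift offsets
// {13, 11, 9, 6, 4, 2, 0}, and a weight is recovered as
//   w = (((word >> shift) & 0x7) - 4) * local_scale * super_scale,
// where local_scale is the low 13 bits of the last packed word of each
// 64-row group and super_scale is the per-column scale.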
|
||||
|
||||
template <typename T, int TileRows, int TileColumns, int NumThreads>
|
||||
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
|
||||
TileColumns, NumThreads> {
|
||||
using ZippedT = uint8_t;
|
||||
using ScaleComputeT = float;
|
||||
|
||||
static constexpr int32_t kGroupSize = 64;
|
||||
static constexpr int32_t kPackNum = 4;
|
||||
static constexpr int32_t kWeightMask = 0x3F;
|
||||
static constexpr int32_t kLocalScaleMask = 0xF;
|
||||
static constexpr int32_t kBZP = 32;
|
||||
|
||||
// weight [16, N] uint8_t
|
||||
// local_scale [1, N] uint8_t
|
||||
// code_scale [N] float
|
||||
// code_zp [N] float
|
||||
// super_scale [N] T
|
||||
|
||||
// code_scale, code_zp and super_scale
|
||||
static constexpr int32_t kColumnWiseSmemBytes = (2 * sizeof(float) + sizeof(T)) * TileColumns;
|
||||
// zipped weights and local_scale
|
||||
static constexpr int32_t kZippedSmemBytes = (TileRows / 4 + (TileRows + 127) / 128) * TileColumns;
|
||||
|
||||
struct Arguments {
|
||||
uint8_t *weight_ptr;
|
||||
uint8_t *local_scale_ptr;
|
||||
float *code_scale_ptr;
|
||||
float *code_zp_ptr;
|
||||
T *super_scale_ptr;
|
||||
|
||||
__device__ Arguments() : weight_ptr(nullptr), local_scale_ptr(nullptr), code_scale_ptr(nullptr), code_zp_ptr(nullptr), super_scale_ptr(nullptr) {}
|
||||
|
||||
__device__ explicit Arguments(uint8_t *smem_ptr) {
|
||||
SetZippedPtrs(smem_ptr);
|
||||
SetColumnWisePtrs(smem_ptr + kZippedSmemBytes);
|
||||
}
|
||||
|
||||
__device__ Arguments(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr) {
|
||||
SetZippedPtrs(zipped_smem_ptr);
|
||||
SetColumnWisePtrs(column_wise_smem_ptr);
|
||||
}
|
||||
|
||||
__device__ void SetZippedPtrs(uint8_t *zipped_smem_ptr) {
|
||||
weight_ptr = zipped_smem_ptr;
|
||||
local_scale_ptr = zipped_smem_ptr + (TileRows / 4) * TileColumns;
|
||||
}
|
||||
|
||||
__device__ void SetColumnWisePtrs(uint8_t *column_wise_smem_ptr) {
|
||||
code_scale_ptr = reinterpret_cast<float *>(column_wise_smem_ptr);
|
||||
code_zp_ptr = reinterpret_cast<float *>(column_wise_smem_ptr + sizeof(float) * TileColumns);
|
||||
super_scale_ptr = reinterpret_cast<T *>(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns);
|
||||
}
|
||||
};
|
||||
|
||||
__device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr,
|
||||
const float *g_code_scale_ptr, const float *g_code_zp_ptr,
|
||||
const T *g_super_scale_ptr,
|
||||
Arguments *args, const int64_t in_stride, bool need_preload) {
|
||||
int tid = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int col = tid; col < TileColumns; col += NumThreads) {
|
||||
if (need_preload) {
|
||||
if (g_super_scale_ptr) {
|
||||
args->super_scale_ptr[col] = g_super_scale_ptr[col];
|
||||
} else {
|
||||
args->super_scale_ptr[col] = static_cast<T>(1);
|
||||
}
|
||||
|
||||
args->code_scale_ptr[col] = g_code_scale_ptr[col];
|
||||
args->code_zp_ptr[col] = g_code_zp_ptr[col];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ls_row_id = 0; ls_row_id < TileRows / 128; ++ls_row_id) {
|
||||
int local_scale_offset = ls_row_id * in_stride + col;
|
||||
args->local_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int zipped_row = 0; zipped_row < TileRows / 4; ++zipped_row) {
|
||||
int s_zipped_offset = zipped_row * TileColumns + col;
|
||||
int g_zipped_offset = zipped_row * 4 * in_stride + col;
|
||||
|
||||
args->weight_ptr[s_zipped_offset] = g_weight_ptr[g_zipped_offset];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__device__ void LoadAsync(const uint8_t *g_weight_ptr,
|
||||
const uint8_t *g_local_scale_ptr,
|
||||
const float *g_code_scale_ptr,
|
||||
const float *g_code_zp_ptr,
|
||||
const T *g_super_scale_ptr,
|
||||
Arguments *args, const int64_t in_stride, bool need_preload) {
|
||||
int tid = threadIdx.x;
|
||||
|
||||
constexpr int kBytesPerThread = 16; // 16B per thread
|
||||
|
||||
constexpr int weight_size = TileRows / 4 * TileColumns;
|
||||
constexpr int local_scale_size = (TileRows + 127) / 128 * TileColumns;
|
||||
constexpr int code_scale_size = sizeof(float) * TileColumns;
|
||||
constexpr int code_zp_size = sizeof(float) * TileColumns;
|
||||
constexpr int super_scale_size = sizeof(T) * TileColumns;
|
||||
|
||||
constexpr int total_size = weight_size + local_scale_size + code_scale_size + code_zp_size + super_scale_size;
|
||||
constexpr int total_tasks = total_size / kBytesPerThread;
|
||||
|
||||
constexpr int cur_num_threads = total_tasks / ((total_tasks + NumThreads - 1) / NumThreads);
|
||||
|
||||
constexpr int weight_threads = weight_size * cur_num_threads / total_size;
|
||||
constexpr int local_scale_threads = local_scale_size * cur_num_threads / total_size;
|
||||
constexpr int code_scale_threads = code_scale_size * cur_num_threads / total_size;
|
||||
constexpr int code_zp_threads = code_zp_size * cur_num_threads / total_size;
|
||||
constexpr int super_scale_threads = super_scale_size * cur_num_threads / total_size;
|
||||
|
||||
static_assert(TileColumns % weight_threads == 0,
|
||||
"TileColumns must be divisible by weight_threads to ensure correct thread mapping.");
|
||||
|
||||
static_assert(TileColumns % local_scale_threads == 0,
|
||||
"TileColumns must be divisible by local_scale_threads to ensure correct thread mapping.");
|
||||
|
||||
if (tid < weight_threads) {
|
||||
constexpr int weight_per_thread_size = weight_size / weight_threads;
|
||||
constexpr int kIterations = (weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread);
|
||||
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->weight_ptr + z_offset, g_weight_ptr + g_offset, true);
|
||||
}
|
||||
} else if (tid < weight_threads + local_scale_threads) {
|
||||
constexpr int start_thread_id = weight_threads;
|
||||
constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads;
|
||||
constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread;
|
||||
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true);
|
||||
}
|
||||
} else if (need_preload) {
|
||||
if (tid < weight_threads + local_scale_threads + code_scale_threads) {
|
||||
constexpr int start_thread_id = weight_threads + local_scale_threads;
|
||||
constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads;
|
||||
constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float);
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->code_scale_ptr + offset, g_code_scale_ptr + offset, true);
|
||||
}
|
||||
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) {
|
||||
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads;
|
||||
constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads;
|
||||
constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int offset = ((tid - start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float);
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->code_zp_ptr + offset, g_code_zp_ptr + offset, true);
|
||||
}
|
||||
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) {
|
||||
if (g_super_scale_ptr) {
|
||||
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads;
|
||||
constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads;
|
||||
constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kIterations; ++i) {
|
||||
int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T);
|
||||
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
|
||||
args->super_scale_ptr + offset, g_super_scale_ptr + offset, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
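// Worked example of the static thread partitioning above, assuming
// TileRows = 64, TileColumns = 128, NumThreads = 128 and T = half:
//   weight_size      = 64 / 4 * 128         = 2048 B
//   local_scale_size = 1 * 128              =  128 B
//   code_scale_size  = code_zp_size         =  512 B each
//   super_scale_size = 2 * 128              =  256 B
//   total_size       = 3456 B, total_tasks  = 3456 / 16 = 216
//   cur_num_threads  = 216 / 2              =  108
//   weight / local_scale / code_scale / code_zp / super_scale threads
//                    = 64 / 4 / 16 / 16 / 8 (sums to 108)
// so 108 of the 128 threads issue 16 B cp.async copies, and each weight
// thread copies 2048 / 64 = 32 B in two iterations.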
|
||||
|
||||
__device__ void Compute(const Arguments &args, T *out_ptr,
|
||||
const int64_t block_start_row) {
|
||||
int32_t shift_bits[4] = {9, 6, 3, 0};
|
||||
|
||||
int tid = threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int col = tid; col < TileColumns; col += NumThreads) {
|
||||
ScaleComputeT super_scale =
|
||||
static_cast<ScaleComputeT>(args.super_scale_ptr[col]);
|
||||
ScaleComputeT code_scale =
|
||||
static_cast<ScaleComputeT>(args.code_scale_ptr[col]);
|
||||
ScaleComputeT code_zp = static_cast<ScaleComputeT>(args.code_zp_ptr[col]);
|
||||
|
||||
#pragma unroll
|
||||
for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
|
||||
int local_scale_offset = (group_id / 2) * TileColumns + col;
|
||||
int32_t local_scale =
|
||||
static_cast<int32_t>(args.local_scale_ptr[local_scale_offset]);
|
||||
|
||||
ScaleComputeT zipped_value[16];
|
||||
|
||||
#pragma unroll
|
||||
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
|
||||
int zipped_offset = (group_id * 16 + zipped_row) * TileColumns + col;
|
||||
zipped_value[zipped_row] =
|
||||
static_cast<ScaleComputeT>(args.weight_ptr[zipped_offset]);
|
||||
}
|
||||
|
||||
int local_scale_shift = ((block_start_row / 64 + group_id + 1) & 1) * 4;
|
||||
int32_t shifted_local_scale =
|
||||
(local_scale >> local_scale_shift) & kLocalScaleMask;
|
||||
ScaleComputeT scale =
|
||||
static_cast<ScaleComputeT>(shifted_local_scale) * super_scale;
|
||||
|
||||
#pragma unroll
|
||||
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
|
||||
int32_t decode_value =
|
||||
static_cast<int32_t>(floor(zipped_value[zipped_row] * code_scale + code_zp +
|
||||
static_cast<ScaleComputeT>(0.5)));
|
||||
|
||||
int row = group_id * 64 + zipped_row * 4;
|
||||
|
||||
#pragma unroll
|
||||
for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
|
||||
int32_t shift_bit = shift_bits[shift_bit_id];
|
||||
int32_t shifted_value = (decode_value >> shift_bit) & kWeightMask;
|
||||
|
||||
ScaleComputeT value =
|
||||
static_cast<ScaleComputeT>(shifted_value - kBZP);
|
||||
out_ptr[(row + shift_bit_id) * TileColumns + col] =
|
||||
static_cast<T>(scale * value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
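// Decoding summary for the 2-bit scheme above (restating the code): each
// packed byte q expands to four consecutive output rows via
//   d   = floor(q * code_scale + code_zp + 0.5)
//   w_j = (((d >> shift_j) & 0x3F) - 32) * local_scale_nibble * super_scale,
// with shift_j in {9, 6, 3, 0} and the 4-bit local scale selected per
// 64-row group from local_scale_ptr.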
|
||||
|
||||
__device__ void ComputeVectorized(const Arguments &args, T *out_ptr,
|
||||
const int64_t block_start_row) {
|
||||
constexpr int kNumWeightsPerThread = TileRows * TileColumns / (4 * NumThreads);
|
||||
constexpr int N = (kNumWeightsPerThread >= 32) ? 4 : 2;
|
||||
constexpr int RowStride = NumThreads * N / TileColumns;
|
||||
constexpr int kNumIters = kNumWeightsPerThread / N;
|
||||
|
||||
static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns.");
|
||||
|
||||
constexpr ScaleComputeT decode_value_zp = static_cast<ScaleComputeT>(0.5);
|
||||
|
||||
int tid = threadIdx.x;
|
||||
int begin_col_id = (tid * N) % TileColumns;
|
||||
int begin_row_id = (tid * N) / TileColumns;
|
||||
|
||||
static_assert(TileRows <= 128, "TileRows is expected to be no more than 128.");
|
||||
|
||||
UnzipArray<uint8_t, N> local_scales =
|
||||
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.local_scale_ptr + begin_col_id);
|
||||
|
||||
UnzipArray<uint8_t, N> zipped_values[2];
|
||||
int zipped_offset = begin_row_id * TileColumns + begin_col_id;
|
||||
zipped_values[0] =
|
||||
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
|
||||
|
||||
UnzipArray<T, N> super_scales =
|
||||
*reinterpret_cast<const UnzipArray<T, N> *>(args.super_scale_ptr + begin_col_id);
|
||||
UnzipArray<float, N> code_scales =
|
||||
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_scale_ptr + begin_col_id);
|
||||
UnzipArray<float, N> code_zps =
|
||||
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_zp_ptr + begin_col_id);
|
||||
|
||||
// Select the local_scale nibble for this 64-row group (specialized for TileRows = 64).
|
||||
int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4;
|
||||
UnzipArray<ScaleComputeT, N> scales;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
int32_t shifted_local_scale =
|
||||
(static_cast<int32_t>(local_scales[i]) >> local_scale_shift) & kLocalScaleMask;
|
||||
scales[i] =
|
||||
static_cast<ScaleComputeT>(shifted_local_scale) * static_cast<ScaleComputeT>(super_scales[i]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int iter_id = 0; iter_id < kNumIters; ++iter_id) {
|
||||
int zipped_row = begin_row_id + iter_id * RowStride;
|
||||
int row = zipped_row * 4;
|
||||
|
||||
if (iter_id < kNumIters - 1) {
|
||||
int zipped_offset = (zipped_row + RowStride) * TileColumns + begin_col_id;
|
||||
zipped_values[(iter_id + 1) & 1] =
|
||||
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
|
||||
}
|
||||
|
||||
UnzipArray<T, N> outs[4];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) {
|
||||
int32_t decode_value =
|
||||
static_cast<int32_t>(floor(static_cast<ScaleComputeT>(zipped_values[iter_id & 1][i]) * code_scales[i]
|
||||
+ code_zps[i] + decode_value_zp));
|
||||
|
||||
ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
|
||||
decode_value >>= 3;
|
||||
ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
|
||||
decode_value >>= 3;
|
||||
ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
|
||||
decode_value >>= 3;
|
||||
ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
|
||||
outs[0][i] = static_cast<T>(scales[i] * value_0);
|
||||
outs[1][i] = static_cast<T>(scales[i] * value_1);
|
||||
outs[2][i] = static_cast<T>(scales[i] * value_2);
|
||||
outs[3][i] = static_cast<T>(scales[i] * value_3);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
|
||||
UnzipArray<T, N> *tmp_out_ptr = reinterpret_cast<UnzipArray<T, N> *>(
|
||||
out_ptr + (row + shift_bit_id) * TileColumns + begin_col_id);
|
||||
*tmp_out_ptr = outs[shift_bit_id];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
140
custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h
Normal file
140
custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h
Normal file
@@ -0,0 +1,140 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
enum WintQuantMethod {
|
||||
kNone = 0,
|
||||
kWeightOnlyInt8 = 1,
|
||||
kWeightOnlyInt4 = 2,
|
||||
kWeightOnlyInt25 = 3,
|
||||
kWeightOnlyInt2 = 4
|
||||
};
|
||||
|
||||
// Convert CUDA data type to cutlass data type
|
||||
template <typename T> struct CutlassDataType {
|
||||
using Type = T;
|
||||
};
|
||||
|
||||
template <> struct CutlassDataType<half> {
|
||||
using Type = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <> struct CutlassDataType<__nv_bfloat16> {
|
||||
using Type = cutlass::bfloat16_t;
|
||||
};
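// Sanity-check sketch of the mapping above (hedged; would require
// <type_traits> if actually compiled):
//   static_assert(std::is_same<CutlassDataType<half>::Type,
//                              cutlass::half_t>::value, "");
//   static_assert(std::is_same<CutlassDataType<__nv_bfloat16>::Type,
//                              cutlass::bfloat16_t>::value, "");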
|
||||
|
||||
template <typename ElementT, WintQuantMethod Method> struct WintQuantTraits;
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kNone> {
|
||||
using WeightType = ElementT;
|
||||
using MmaKernelType = typename CutlassDataType<ElementT>::Type;
|
||||
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod = WintQuantMethod::kNone;
|
||||
|
||||
struct Arguments {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) { return dim; }
|
||||
};
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt8> {
|
||||
using WeightType = uint8_t;
|
||||
using MmaKernelType = uint8_t;
|
||||
using MmaWeightType = uint8_t;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod =
|
||||
WintQuantMethod::kWeightOnlyInt8;
|
||||
|
||||
struct Arguments {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) { return dim; }
|
||||
};
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt4> {
|
||||
using WeightType = cutlass::uint4b_t;
|
||||
using MmaKernelType = cutlass::uint4b_t;
|
||||
using MmaWeightType = cutlass::uint4b_t;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod =
|
||||
WintQuantMethod::kWeightOnlyInt4;
|
||||
|
||||
struct Arguments {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) { return dim; }
|
||||
};
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt25> {
|
||||
using WeightType = uint16_t;
|
||||
using MmaKernelType = typename CutlassDataType<ElementT>::Type;
|
||||
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod =
|
||||
WintQuantMethod::kWeightOnlyInt25;
|
||||
|
||||
static constexpr int32_t kGroupSize = 64;
|
||||
static constexpr int32_t kNumPackedValues = 7;
|
||||
static constexpr int32_t kPackedSize = 10;
|
||||
|
||||
struct Arguments {};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) {
|
||||
return dim * kPackedSize / kGroupSize;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementT>
|
||||
struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt2> {
|
||||
using WeightType = uint8_t;
|
||||
using MmaKernelType = cutlass::uint2b_t;
|
||||
using MmaWeightType = typename CutlassDataType<ElementT>::Type;
|
||||
|
||||
static constexpr WintQuantMethod kQuantMethod =
|
||||
WintQuantMethod::kWeightOnlyInt2;
|
||||
|
||||
static constexpr int32_t kGroupSize = 64;
|
||||
static constexpr int32_t kNumPackedValues = 4;
|
||||
static constexpr int32_t kPackedSize = 16;
|
||||
|
||||
struct Arguments {
|
||||
const uint8_t *local_scale_ptr; // quantized to 4 bits
|
||||
const float *code_scale_ptr;
|
||||
const float *code_zp_ptr;
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static int64_t CaclPackedDim(int64_t dim) {
|
||||
return dim * kPackedSize / kGroupSize;
|
||||
}
|
||||
};
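// Worked example: for kWeightOnlyInt2, CaclPackedDim(dim) = dim * 16 / 64 =
// dim / 4, e.g. a logical K of 4096 rows maps to 1024 packed uint8 rows
// (four 2-bit values per byte).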
|
||||
|
||||
} // namespace cutlass
|
||||