Optimize the performance of moe_expert_ffn_wint2 (#2990)
* Change wint2 to ColumnMajor. (Change-Id: I6b44d02946a685f8fe24d9f2c7be258b51e16da2)
* Unify default_wint2x_mma. (Change-Id: I9e77b0e8e6cecab01fedc0b24b536ee0a1a89ff7)
* Change wint2 to ColumnMajorTileInterleave. (Change-Id: I593cbe36f991c0c5044989d65f0014087587c624)
* Enable async copy for B. (Change-Id: Ia3ac37ad162a8cf3ccce4f268e81bd06c8ac3c46)
* Add wint2x Dequantizer.
* Remove TileDequanterB related code. (Change-Id: Id8e65703b72a8984d367f584ff41b7726017fbb8)
* Implement FastInterleavedAndBiasedNumericArrayConverter for wint2. (Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca)
* Implement Wint2ParamsAccessor to load extra quant params from global memory. (Change-Id: Ic3750cd9b767df8893501820880c3342a4b47233)
* Implement FastInterleavedAndBiasedNumericArrayConverter for wint2. (Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca)
* Use async copy for local_scale. (Change-Id: Ib882ba41c3d2354bda4d25b40e2408ad3b2f7893)
* Check and correct the load and dequantize of weights. (Change-Id: Ie8dca505b39987144964fe6407d465b3b5953790)
* Change for performance tuning. (Change-Id: I1da026fb1d1533a9d70350c7ba23c27e896cfc29)
* Optimize the global memory access size of local_scale reading. (Change-Id: I4cbe3a2ef5951723d415c2d3252ce912394beaf5)
* Specialize mma_tensor_op for wint2 to enable a fine-grained pipeline. (Change-Id: Icbb4d48f90a41136f42d6ffff42d68de32f408da)
* Minor fix. (Change-Id: I14d4ac9d267ee05442a3b47f00c26bee13d79e6f)
* Optimize dequant performance with LOP3.
* Optimize dequant performance with LOP3.
* Avoid redundant dequantization of local_scale and use bf16 as the computing type. (Change-Id: I63239ebc8f8e4a92d6281af59840ba50600b4334)
* Add Multiplier and remove some logs. (Change-Id: Ifa199d81e6aeb472d2247c63f85ef30213684bcd)
* Optimize dequant performance with LOP3.
* Use __byte_perm to implement int8-to-float32 conversion for a performance improvement.
* Use LOP3 to optimize the dequantization of local_scale. (Change-Id: I6189759970cb5b8dcbef769724784b8a7533b63c)
* Minor fix and remove some logs. (Change-Id: I6279ba9926d5041093b1c6aea200acf2e4c49d46)
* Fix stages for test. (Change-Id: I6f7b7cac612ef2c678e9d49f5ffa60eb53d3ae29)
* Fix stages for test and add clock64 to profile. (Change-Id: Iffaf7324beaa910ce9ee56f47ae289de98f1a267)
* Use __byte_perm to replace shift-and-or operations for faster integer merging.
* Split the uint2b convert. (Change-Id: I78da672ce8968e21f685285140ba546a161521b4)
* Optimize the convert of unscale. (Change-Id: I6795da1cdf5e8ab38ddaa9836240921b5312913a)
* Minor optimization. (Change-Id: I1800aec34c3f4621abb02658208108f54da44d88)
* Optimize the mma pipeline and refine the code. (Change-Id: Id3075cf7b88f2813a11ccd1d3b49c62c978f36b8)
* Add missing support. (Change-Id: Id65b7bc2c25fbb1a5b232c6bc9fb8c9093f691a8)
* Accelerate FP16 dequantization performance.
* Support tile shapes of the form Xx64x64. (Change-Id: Ib8fd37e1ba1d06f7d11f2956e7f1367b0a92bcac)
* Remove debugging code and minor optimization. (Change-Id: I6b79bd56a6e8dd823efc169967ecd3cc9a43baf4)
* Fix offset bug. (Change-Id: Id7aeb91e99d6f51836f2aff22187b4f79607395e)
* Fix typo. (Change-Id: I19dde93fc1c1f7e19605905c90dc46298e203952)
* Restore some code and remove some debugging logs. (Change-Id: I8d44daf82ad1c6f8174134d195e7b3fe9a3afdfb)

Co-authored-by: baoqiwen <baoqiwen@baidu.com>
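Several of the commits above replace shift-and-or sequences with __byte_perm when widening packed integer weights during dequantization. The PR's real converters live in FastInterleavedAndBiasedNumericArrayConverter; the snippet below is only a minimal, self-contained sketch of the underlying "magic number" trick, with an illustrative function name that does not appear in the PR.

    // Illustrative sketch (not part of this commit): convert four packed uint8 values to
    // float with one __byte_perm per byte instead of a shift/mask/convert sequence.
    // Placing a byte into the low mantissa bits of 0x4B000000 (== 2^23) yields the float
    // value 2^23 + b, so subtracting 2^23 recovers b exactly.
    __device__ __forceinline__ void uint8x4_to_float4(uint32_t packed, float out[4]) {
        const float kMagic = 8388608.f;  // 2^23
        uint32_t f0 = __byte_perm(packed, 0x4B000000u, 0x7540);  // bits 0x4B0000'b0'
        uint32_t f1 = __byte_perm(packed, 0x4B000000u, 0x7541);  // bits 0x4B0000'b1'
        uint32_t f2 = __byte_perm(packed, 0x4B000000u, 0x7542);  // bits 0x4B0000'b2'
        uint32_t f3 = __byte_perm(packed, 0x4B000000u, 0x7543);  // bits 0x4B0000'b3'
        out[0] = __uint_as_float(f0) - kMagic;
        out[1] = __uint_as_float(f1) - kMagic;
        out[2] = __uint_as_float(f2) - kMagic;
        out[3] = __uint_as_float(f3) - kMagic;
    }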
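The "optimize dequant performance with LOP3" commits refer to fusing the mask-and-insert steps of dequantization into single LOP3.B32 instructions. As a hedged illustration of the general idiom only (this is not the PR's code):

    // Illustrative sketch (not part of this commit): (a & mask) | insert in one LOP3.B32.
    // The immediate LUT for f(a, b, c) = (a & b) | c is (0xF0 & 0xCC) | 0xAA = 0xEA.
    __device__ __forceinline__ uint32_t lop3_and_or(uint32_t a, uint32_t mask, uint32_t insert) {
        uint32_t result;
        asm("lop3.b32 %0, %1, %2, %3, 0xEA;\n" : "=r"(result) : "r"(a), "r"(mask), "r"(insert));
        return result;
    }
    // Typical dequantization use: extract two packed 2-bit weights and merge them into the
    // mantissa of a packed-half "magic" constant in a single instruction, e.g.
    //     uint32_t h2 = lop3_and_or(packed_w, 0x00030003u, 0x64006400u);
    // which produces two fp16 values of the form (1024 + w) that are rescaled afterwards.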
@@ -133,10 +133,18 @@ public:
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::RowMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;  // 64

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint2b_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;  // 8

public:
    // using Layout = layout::ColumnMajor;
    // static constexpr int ElementsPerAccess = 16;  // at least 4-bytes
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint2b_t>::value;  // 64
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

template <typename TypeA, typename Arch>
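For reference, the "// 64" and "// 8" annotations in the specialization above follow directly from the element widths; assuming TypeA is fp16 (an assumption, since TypeA is a template parameter), the arithmetic is:

    // Illustrative arithmetic only (assumes TypeA = half_t):
    static_assert(128 * 8 / 16 == 64, "ThreadblockK: a 128-byte line of fp16 A covers 64 elements in K");
    static_assert(128 * 8 / 2 == 512, "ElementsPerCacheLine: a 128-byte line holds 512 uint2b_t values");
    static_assert(512 / 64 == 8, "ColumnsInterleaved: 8 columns of B share one cache line");
    // ColumnMajorTileInterleave<64, 8> therefore packs 8 consecutive B columns of a 64-row
    // K-tile contiguously, so a single 128-byte access reads whole tiles of the 2-bit weights.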
@@ -18,14 +18,12 @@

#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"

namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

@@ -378,38 +376,23 @@ template <
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
    static cutlass::arch::CacheOperation::Kind const CacheOpA =
        ((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
                                                            : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB =
        ((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                            : cutlass::arch::CacheOperation::Always;
private:
    using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
        ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components
    using MmaCore =
        typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
            LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
            false, CacheOpA, CacheOpB>;
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
        AccessTypeA>;
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
        AccessTypeB>;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

template <
@@ -441,38 +424,23 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
    static cutlass::arch::CacheOperation::Kind const CacheOpA =
        ((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
                                                            : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB =
        ((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                            : cutlass::arch::CacheOperation::Always;
private:
    using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
        ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore =
        typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
            LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
            false, CacheOpA, CacheOpB>;
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
        AccessTypeA>;
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
        AccessTypeB>;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

} // namespace threadblock
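Both specializations above (and the bfloat16_t versions later in this diff) now delegate all construction to a private DefaultWint2xMma and merely re-export its nested types. Stripped of the GEMM-specific parameters, the pattern is just the following sketch (every name here is a placeholder, not code from the PR):

    // Illustrative sketch (not part of this commit).
    struct DetailedBuilder {                 // stands in for DefaultWint2xMma
        using MmaCore        = int;          // placeholder nested types
        using IteratorA      = int;
        using IteratorB      = int;
        using ThreadblockMma = int;
    };

    struct PublicEntryPoint {                // stands in for the DefaultMma specialization
    private:
        using Impl = DetailedBuilder;        // the construction logic lives in exactly one place
    public:
        using MmaCore        = Impl::MmaCore;         // re-export nested types so existing
        using IteratorA      = Impl::IteratorA;       // callers of DefaultMma keep working
        using IteratorB      = Impl::IteratorB;
        using ThreadblockMma = Impl::ThreadblockMma;
    };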
@@ -19,7 +19,7 @@

#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"

namespace cutlass {
namespace gemm {

@@ -379,38 +379,23 @@ template <
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
    static cutlass::arch::CacheOperation::Kind const CacheOpA =
        ((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
                                                                : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB =
        ((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                                : cutlass::arch::CacheOperation::Always;
private:
    using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
        ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, 2, Operator>;

public:
    // Define the MmaCore components
    using MmaCore =
        typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
            LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
            false, CacheOpA, CacheOpB>;
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
        AccessTypeA>;
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
        AccessTypeB>;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

template <
@@ -442,38 +427,23 @@ struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmen
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
    static cutlass::arch::CacheOperation::Kind const CacheOpA =
        ((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
                                                                : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB =
        ((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                                : cutlass::arch::CacheOperation::Always;
private:
    using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
        ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
        WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
    // Define the MmaCore components
    using MmaCore =
        typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
            LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
            false, CacheOpA, CacheOpB>;
    using MmaCore = typename Mma::MmaCore;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
        AccessTypeA>;
    using IteratorA = typename Mma::IteratorA;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
        AccessTypeB>;
    using IteratorB = typename Mma::IteratorB;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

} // namespace threadblock
@@ -0,0 +1,182 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"

namespace cutlass {
namespace gemm {
namespace threadblock {

/// Partial specialization:
///
///   A: row-major
///   B: uint2b_t, column-major
///   Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Number of stages
    int Stages,
    /// Operation performed by MMA
    typename Operator_,
    /// Cache operation of operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Cache operation of operand B
    cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                      layout::RowMajor, uint2b_t, layout::ColumnMajor,
                      ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
                      Operator_, false, CacheOpA, CacheOpB> {
    using Shape = Shape_;
    using WarpShape = WarpShape_;
    using InstructionShape = InstructionShape_;
    using ElementA = ElementA_;
    using LayoutA = layout::RowMajor;
    using ElementB = uint2b_t;
    using LayoutB = layout::ColumnMajor;
    using ElementC = ElementC_;
    using LayoutC = LayoutC_;
    static int const kStages = Stages;
    static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
    static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

    /// Number of warps present
    using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
                                Shape::kN / WarpShape::kN,
                                Shape::kK / WarpShape::kK>;

    // Divisibility requirements
    static_assert(
        !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
        "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");

    /// Number of threads per warp
    static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;

    /// Size of a threadblock-scoped access
    static int const kAccessSizeInBits = 128;

    /// Number of threads total
    static int const kThreads = WarpCount::kCount * kWarpSize;

    /// Size of a threadblock-scoped access of B
    static constexpr int kMaxThreadsForB =
        (Shape::kK * Shape::kN * sizeof_bits<ElementB>::value) / kAccessSizeInBits;
    static constexpr int kThreadsForB =
        kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB;

    /// Default Operator
    using Operator = Operator_;

    // Warp thread arrangement
    static int const kWarpThreadArrangementContiguousA =
        Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);

    static int const kWarpThreadArrangementStridedA =
        kWarpSize / kWarpThreadArrangementContiguousA;

    static int const kWarpThreadArrangementContiguousB =
        Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementB>::value);

    static int const kWarpThreadArrangementStridedB =
        kWarpSize / kWarpThreadArrangementContiguousB;

    //
    // Shared memory layouts
    //

    using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
        sizeof_bits<ElementA>::value, Shape::kK>;

    // Shared memory layout
    using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
        sizeof_bits<ElementB>::value, Shape::kK>;

    //
    // Iterators to write to shared memory
    //

    /// ThreadMap of iterator A
    using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
        layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
        layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
                                 kWarpThreadArrangementStridedA>,
        kAccessSizeInBits / sizeof_bits<ElementA>::value>;

    /// Shared memory iterator to A operand
    using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
        MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
        IteratorThreadMapA>;

    /// ThreadMap of iterator B
    using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
        layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreadsForB,
        layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
                                 kWarpThreadArrangementStridedB>,
        kAccessSizeInBits / sizeof_bits<ElementB>::value>;

    /// Shared memory iterator to B operand
    using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
        MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
        IteratorThreadMapB>;

    //
    // Warp-level matrix multiply operator
    //

    // Define the warp-level tensor op
    using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
        ElementC, LayoutC, Operator, WarpCount::kK>::Type;

    /// Policy used to define MmaPipelined
    using MmaPolicy = MmaPolicy<MmaTensorOp, MatrixShape<0, 0>,
                                MatrixShape<0, 0>, WarpCount::kK>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
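To make the thread-map constants above concrete, here is the arithmetic for one plausible configuration; the tile and warp shapes are assumptions for illustration, not values fixed by the PR:

    // Illustrative arithmetic only: ThreadblockShape = 128x64x64, WarpShape = 64x32x64,
    // ElementA = half_t (16 bits), ElementB = uint2b_t (2 bits).
    constexpr int kAccessSizeInBits = 128;
    constexpr int kThreads          = 4 * 32;                                        // 4 warps
    constexpr int kMaxThreadsForB   = (64 * 64 * 2) / kAccessSizeInBits;             // 64 accesses cover the B tile
    constexpr int kThreadsForB      = kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB;  // 64
    constexpr int kContiguousB      = 64 / (kAccessSizeInBits / 2);                  // Shape::kK / 64 elements = 1
    constexpr int kStridedB         = 32 / kContiguousB;                             // 32
    static_assert(kThreadsForB == 64 && kContiguousB == 1 && kStridedB == 32, "");
    // A 2-bit B tile is tiny: the whole 64x64 tile is only 64 128-bit accesses, so just 64 of
    // the 128 threads participate in the B thread map, while A keeps the full warp-raked map.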
@@ -0,0 +1,246 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_core.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

template <typename ThreadblockShape, typename ElementT, int GroupSize>
struct DefaultQuantParamsIterators {
private:
    static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
    static_assert((ThreadblockShape::kN % kAlignment) == 0, "");

    static constexpr int kRows =
        (GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize;
    static constexpr int kColumns = ThreadblockShape::kN;

    using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
        layout::PitchLinearShape<kColumns, kRows>,
        kColumns / kAlignment, kAlignment>;

public:
    using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
        MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor, 0,
        IteratorThreadMap, kAlignment>;
    using SmemIterator = Iterator;
};

template <typename ThreadblockShape, int GroupSize>
struct DefaultQuantParamsIterators<ThreadblockShape, uint4b_t, GroupSize> {
private:
    static constexpr int kAlignment = 32 / sizeof_bits<uint4b_t>::value;
    static_assert((ThreadblockShape::kN % kAlignment) == 0, "");

    static constexpr int kRows =
        (GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize);
    static constexpr int kColumns =
        (GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2;

    using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
        layout::PitchLinearShape<kColumns, kRows>,
        kColumns / kAlignment, kAlignment>;

public:
    using AccessType = cutlass::Array<uint4b_t, kAlignment>;
    using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        MatrixShape<kRows, kColumns>, uint4b_t, layout::RowMajor,
        0, IteratorThreadMap, AccessType>;

    using SmemIterator = Iterator;
};

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Operator class tag
    typename OperatorClass_,
    /// Tag indicating architecture to tune for
    typename ArchTag_,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
struct DefaultWint2xMma;

////////////////////////////////////////////////////////////////////////////////

template <
    /// Type for element A
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Type for element B
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Stages in GEMM
    int kStages,
    /// Operator performed by GEMM
    typename Operator,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
    kStages, Operator, SharedMemoryClear>
{
public:
    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
        "Element A must be fp16 or bf16");

    static_assert(platform::is_same<ElementB, uint2b_t>::value,
        "Element B must be uint2b_t");

    static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
        "Mma multistage must dequantize after ldsm");

    using ElementSuperScale = ElementA;
    using ElementLocalScale = uint4b_t;
    using ElementCodeScaleZp = float;

    static constexpr int kGroupSize = 64;

    static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    // Define the MmaCore components
    // Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
        std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
        AccessTypeA>;

private:
    static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved;
    static constexpr int kRowsPerTile = LayoutB::kRowsPerTile;
    static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be divisible by kColumnsInterleaved");
    static_assert(kRowsPerTile == MmaCore::Shape::kK, "");

    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement;
    static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), "");

    using IteratorShapeB = MatrixShape<
        MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>;
    using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap<
        layout::PitchLinearShape<IteratorShapeB::kRow, IteratorShapeB::kColumn>,
        ThreadMapB::kThreads,
        layout::PitchLinearShape<WarpArrangement::kContiguous * kColumnsInterleaved,
                                 WarpArrangement::kStrided / kColumnsInterleaved>,
        MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;

public:
    // Define iterators over tiles from the B operand
    using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB,
        AccessTypeB>;

private:
    // Define iterators over tiles from extra quant params for B operand
    using IteratorSuperScale = typename DefaultQuantParamsIterators<
        ThreadblockShape, ElementSuperScale, -1>::Iterator;
    using SmemIteratorSuperScale = typename DefaultQuantParamsIterators<
        ThreadblockShape, ElementSuperScale, -1>::SmemIterator;

    using IteratorLocalScale = typename DefaultQuantParamsIterators<
        ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator;
    using SmemIteratorLocalScale = typename DefaultQuantParamsIterators<
        ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator;

    using IteratorCodeScaleZp = typename DefaultQuantParamsIterators<
        ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
    using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators<
        ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;

public:
    using QuantParamsAccessor = Wint2ParamsAccessor<
        ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale,
        IteratorLocalScale, SmemIteratorLocalScale,
        IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
        typename MmaCore::Shape,
        IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA,
        IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB,
        ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy,
        kStages, QuantParamsAccessor, SharedMemoryClear>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
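A hypothetical instantiation, only to make the template parameters concrete; the real kernels pick these shapes through FastDeploy's own dispatch, so every value below is an assumption:

    // Illustrative instantiation (not taken from the PR).
    using Wint2Mma = cutlass::gemm::threadblock::DefaultWint2xMma<
        cutlass::half_t, cutlass::layout::RowMajor, 8,                    // A: fp16, 128-bit accesses
        cutlass::uint2b_t,
        cutlass::layout::ColumnMajorTileInterleave<64, 8>, 64,            // B: 2-bit, interleaved layout
        float, cutlass::layout::RowMajor,                                 // fp32 accumulators, row-major C
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<64, 64, 64>,                             // threadblock tile
        cutlass::gemm::GemmShape<32, 64, 64>,                             // warp tile
        cutlass::gemm::GemmShape<16, 8, 16>,                              // tensor-core instruction
        4,                                                                // pipeline stages
        cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>;
    using ThreadblockMma = Wint2Mma::ThreadblockMma;                      // the Wint2xMmaMultistage mainloop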
@@ -63,8 +63,8 @@ template <
    typename Policy_,
    /// Number of stages,
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
    /// Size of extra quantized params
    typename QuantParamsShape>
class Wint2xMmaBase {
public:
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
@@ -93,6 +93,14 @@ public:
    static int const kWarpGemmIterations =
        (WarpGemm::kK / Operator::Policy::MmaShape::kK);

    /// Number of warp-level GEMM operations per load for B
    static constexpr int kWarpGemmIterationsPerLoadForB =
        Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
    static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");

    static constexpr int kWarpLoadIterationsForB =
        kWarpGemmIterations / kWarpGemmIterationsPerLoadForB;

    /// Number of stages
    static int const kStages = Stages;

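The kWarpGemmIterationsPerLoadForB constant introduced above is what lets B be read from shared memory less often than A. With assumed shapes (warp K of 64, instruction K of 16, and a B iterator that brings in 32 rows of K per load), the numbers work out as:

    // Illustrative arithmetic only (assumed shapes).
    constexpr int kWarpGemmIterations            = 64 / 16;   // 4 warp-level MMA steps per tile
    constexpr int kWarpGemmIterationsPerLoadForB = 32 / 16;   // each B load feeds 2 MMA steps
    constexpr int kWarpLoadIterationsForB        = kWarpGemmIterations / kWarpGemmIterationsPerLoadForB;  // 2
    static_assert(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB == 0, "");
    // So the mainloop touches shared memory for B only every other warp MMA, leaving room to
    // overlap the in-register dequantization with the tensor-core math.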
@@ -104,8 +112,6 @@ public:
    using TensorRefB =
        TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

    // using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;

    static_assert(kWarpGemmIterations > 1,
        "The pipelined structure requires at least two warp-level "
        "GEMM operations.");
@@ -130,20 +136,11 @@ public:
        Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

    /// Shape of the B matrix operand in shared memory
    using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
    using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
        Shape::kN + Policy::SmemPaddingB::kColumn>;

    // w uint8; local_scale uint8;
    constexpr static int kZippedRowsPerStages =
        Shape::kK / 4 + (Shape::kK + 127) / 128;

    // code_scale float; code_zp float; super_scale ElementB
    constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
        sizeof_bits<typename Operator::ElementB>::value / 8;

    using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;

    using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
    /// Shape of all quant params in shared memory
    using QuantParamsShapeB = QuantParamsShape;

public:
    //
@@ -156,12 +153,8 @@ public:
    /// Buffer for B operand
    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

    /// Buffer for quanted B operand
    AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;

    /// Buffer for unzip B operand
    AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
        operand_unzip_B;
    /// Buffer for extra quant params of B operand
    AlignedBuffer<uint8_t, QuantParamsShapeB::kCount> operand_quant_params_B;

public:
    //
@@ -191,14 +184,6 @@ public:
    TensorRefB operand_B_ref() {
        return TensorRefB{operand_B.data(), LayoutB()};
    }

    CUTLASS_HOST_DEVICE
    uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }

    CUTLASS_HOST_DEVICE
    typename Operator::ElementB *operand_unzip_B_ptr() {
        return operand_unzip_B.data();
    }
};

protected:
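After this change B stays in the regular multistage tile buffers and all extra quantization parameters are collected in one flat buffer owned by Wint2ParamsAccessor; the zipped/unzipped staging buffers are gone. A rough, illustrative picture of the resulting shared storage (sizes are placeholders, not the real AlignedBuffer layout):

    // Illustrative sketch (not the literal SharedStorage class).
    constexpr int kStages = 4, kTileM = 64, kTileN = 64, kTileK = 64;   // assumed sizes
    constexpr int kQuantParamsBytes = 4096;                             // placeholder for QuantParamsShape::kCount
    struct SharedStorageSketch {
        uint16_t operand_A[kStages * kTileM * kTileK];        // staged fp16/bf16 A tiles
        uint8_t  operand_B[kStages * kTileK * kTileN / 4];    // staged 2-bit B tiles, four weights per byte
        uint8_t  operand_quant_params_B[kQuantParamsBytes];   // super_scale / local_scale / code_scale / code_zp
    };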
@@ -45,7 +45,8 @@

#include "cutlass_extensions/arch/memory_copy_sm80.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -86,15 +87,15 @@ template <
    typename Policy_,
    /// Number of stages,
    int Stages,
    /// Accessor for extra quantized params
    typename QuantParamsAccessor_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    /// Used for partial specialization
    typename Enable = bool>
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
class Wint2xMmaMultistage :
    public Wint2xMmaBase<Shape_, Policy_, Stages> {
    public Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape> {
public:
    ///< Base class
    using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
    using Base = Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape>;
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;
    ///< Iterates over tiles of A operand in global memory
@@ -107,8 +108,11 @@ public:
    using LayoutC = LayoutC_;
    ///< Policy describing tuning details
    using Policy = Policy_;
    /// Accessor for extra quantized params
    using QuantParamsAccessor = QuantParamsAccessor_;
    using QuantArguments = typename QuantParamsAccessor::Arguments;

    using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
    static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;
@@ -129,6 +133,18 @@ public:
    /// Minimum architecture is Sm80 to support cp.async
    using ArchTag = arch::Sm80;

    // using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout;
    using LayoutScale = layout::RowMajor;
    using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
    using WarpDequantizer =
        warp::MmaTensorOpWin2xDequantizer<Operator,
                                          typename Base::WarpGemm,
                                          Operand::kB,
                                          typename WarpTransformedFragmentB::Element,
                                          LayoutScale,
                                          QuantParamsAccessor::kGroupSize>;
    static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed");

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

@@ -174,18 +190,37 @@ public:
    using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
    using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;

    using FragmentSuperScale = typename WarpDequantizer::FragmentSuperScale;
    using FragmentCodeScaleZp = typename WarpDequantizer::FragmentCodeScaleZp;
    using FragmentLocalScale = typename WarpDequantizer::FragmentLocalScale;

    /// Temporary accumulator to facilitate staged-accumulation
    FragmentC tmp_accum_;

    /// Pair of A fragments used to overlap shared memory loads and math instructions
    WarpLoadedFragmentA warp_loaded_frag_A_[2];
    WarpTransformedFragmentA warp_transformed_frag_A_[2];
    WarpTransformedFragmentA warp_frag_A_[2];

    /// Pair of B fragments used to overlap shared memory loads and math instructions
    WarpLoadedFragmentB warp_loaded_frag_B_[2];
    WarpTransformedFragmentB warp_transformed_frag_B_[2];
    WarpLoadedFragmentB warp_loaded_frag_B_;
    WarpTransformedFragmentB warp_frag_B_[2];

    /// channel-wise quant params
    FragmentCodeScaleZp warp_frag_code_scale_;
    FragmentCodeScaleZp warp_frag_code_zp_;
    FragmentSuperScale warp_frag_super_scale_;

    /// group-wise quant params
    FragmentLocalScale warp_frag_local_scale_;
};

    using ElementA = typename IteratorA::Element;
    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

    static constexpr bool IsTileInterleaveLayout =
        layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
        "Layout K must match threadblockK");

private:

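The PipeState changes above keep a single raw B fragment straight from shared memory plus a pair of dequantized B fragments, so dequantizing the next warp tile can overlap with the MMA on the current one. A schematic version of that state, with all types reduced to placeholders:

    // Illustrative sketch (fragment sizes are placeholders).
    struct PipeStateSketch {
        float    accum[64];                // FragmentC
        uint32_t frag_A[2][8];             // double-buffered A fragments in mma layout
        uint32_t loaded_frag_B[16];        // single raw 2-bit B fragment from ldsm
        uint32_t frag_B[2][8];             // double-buffered *dequantized* B fragments
        uint32_t local_scale[4];           // group-wise scales, staged one k-group ahead
        float    code_scale[4], code_zp[4];
        uint32_t super_scale[4];           // channel-wise params, loaded once per tile
    };
    // While the MMA consumes frag_B[k % 2], the warp dequantizer fills frag_B[(k + 1) % 2]
    // from loaded_frag_B and the scales, hiding dequantization behind the tensor-core math.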
@@ -202,17 +237,18 @@ public:
    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Accessor for extra quant params for B
    QuantParamsAccessor quant_params_accessor_B_;

    // Wint2 unzip operator
    WarpDequantizer warp_dequantizer_;

    /// Shared memory write stage index
    int smem_write_stage_idx_;

    /// Shared memory read stage index
    int smem_read_stage_idx_;

    uint8_t* column_wise_smem_ptr_B_;

    uint8_t* smem_zipped_ptr_B_;
    int smem_zipped_bytes_per_stage_B_;

public:

    /// Construct from tensor references
@@ -226,10 +262,15 @@ public:
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx
    ):
        Base(shared_storage, thread_idx, warp_idx, lane_idx),
    ) : Base(shared_storage, thread_idx, warp_idx, lane_idx),
        smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
        quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx),
        warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(),
                          quant_params_accessor_B_.local_scale_ref(),
                          quant_params_accessor_B_.code_scale_ref(),
                          quant_params_accessor_B_.code_zp_ref(),
                          (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx),
        smem_write_stage_idx_(0),
        smem_read_stage_idx_(0)
    {
@@ -250,11 +291,6 @@ public:
            {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset(
            {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});

        column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();

        smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
        smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
    }

    /// Advance shared memory read-iterators to the next stage
@@ -266,28 +302,22 @@ public:
        if (smem_read_stage_idx_ == Base::kStages) {
            // Wrap back around to the 'start' of the circular buffer in shared memory
            this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
            // this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
            this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0});
            smem_read_stage_idx_ = 0;
        }
        this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
    }

    /// Advance global memory read-iterators and shared memory write-iterators to the stage
    template <typename TileDequanterB>
    CUTLASS_DEVICE
    void advance_smem_write_stage(
        IteratorA &iterator_A,
        IteratorB &iterator_B,
        TileDequanterB &tile_dequanter_B)
    void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B)
    {
        // Advance global iterators
        iterator_A.add_tile_offset({0, 1});
        // iterator_B.add_tile_offset({1, 0});
        tile_dequanter_B.AddTileOffset({1, 0});
        iterator_B.add_tile_offset({1, 0});

        // Advance shared iterators
        smem_iterator_A_.add_tile_offset({0, 1});
        // smem_iterator_B_.add_tile_offset({1, 0});
        smem_iterator_B_.add_tile_offset({1, 0});

        // Increment shared memory write stage index
        ++smem_write_stage_idx_;
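This hunk and the next keep the standard multistage circular-buffer bookkeeping, now applied directly to iterator_B and smem_iterator_B_ instead of the removed TileDequanterB. Reduced to its essentials (a sketch that folds in the wrap-around from the following hunk):

    // Illustrative sketch of the write-stage advance.
    template <typename SmemIterA, typename SmemIterB>
    void advance_write_stage_sketch(SmemIterA& smem_A, SmemIterB& smem_B, int& write_stage, int num_stages) {
        smem_A.add_tile_offset({0, 1});              // next ring-buffer slot for A
        smem_B.add_tile_offset({1, 0});              // next ring-buffer slot for B (now a plain tile advance)
        if (++write_stage == num_stages) {           // wrap back to slot 0 of the ring
            smem_A.add_tile_offset({0, -num_stages});
            smem_B.add_tile_offset({-num_stages, 0});
            write_stage = 0;
        }
    }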
@@ -295,7 +325,7 @@ public:
        if (smem_write_stage_idx_ == Base::kStages) {
            // Wrap back around to the 'start' of the circular buffer in shared memory
            smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            // smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
            smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
            smem_write_stage_idx_ = 0;
        }
    }
@@ -338,9 +368,14 @@ public:
        }
    }

    template <bool GlobalToSharedB>
    CUTLASS_DEVICE
    void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
        if constexpr (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
            if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
                return;
            }
        }

        iterator_B.set_iteration_index(group_start_B *
                                       IteratorB::kAccessesPerVector);
        this->smem_iterator_B_.set_iteration_index(group_start_B);
@@ -360,13 +395,14 @@ public:
            CUTLASS_PRAGMA_UNROLL
            for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
                auto gmem_ptr = iterator_B.get();
                bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) ? iterator_B.valid() : false;

                if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
                    cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
                        dst_ptr + v, gmem_ptr, iterator_B.valid());
                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                        dst_ptr + v, gmem_ptr, is_valid);
                } else {
                    cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
                        dst_ptr + v, gmem_ptr, iterator_B.valid());
                    cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
                        dst_ptr + v, gmem_ptr, is_valid);
                }

                ++iterator_B;
@@ -375,7 +411,6 @@ public:
            ++this->smem_iterator_B_;
        }
    }
    __syncthreads();
    }

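The copy path above now predicates cp.async on both iterator validity and the reduced B thread count, relying on cp_async_zfill to zero the destination when the source is masked off. Simplified to the 16-byte case, the zfill idea is roughly the following sketch of the PTX-level behaviour (requires sm_80):

    // Illustrative sketch: when `valid` is false the source size is 0, so the full
    // 16-byte shared-memory destination is zero-filled instead of being left stale.
    __device__ __forceinline__ void cp_async_16B_zfill(void* smem_dst, void const* gmem_src, bool valid) {
        unsigned smem_addr = static_cast<unsigned>(__cvta_generic_to_shared(smem_dst));
        int src_bytes = valid ? 16 : 0;
        asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n"
                     :: "r"(smem_addr), "l"(gmem_src), "r"(src_bytes));
    }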
CUTLASS_DEVICE
|
||||
@@ -399,8 +434,6 @@ public:
|
||||
IteratorA::ThreadMap::kElementsPerAccess /
|
||||
IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
|
||||
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
|
||||
dst_ptr + v, iterator_A.get(), iterator_A.valid());
|
||||
|
||||
@@ -411,9 +444,12 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
template <bool GlobalToSharedB, bool InitStage>
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
|
||||
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
|
||||
return;
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(0);
|
||||
this->smem_iterator_B_.set_iteration_index(0);
|
||||
|
||||
@@ -433,35 +469,23 @@ public:
|
||||
IteratorB::ThreadMap::kElementsPerAccess /
|
||||
IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
if (InitStage) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
|
||||
dst_ptr + v, iterator_B.get(), iterator_B.valid());
|
||||
} else {
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
} else {
|
||||
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
|
||||
dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
|
||||
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void prologue(
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
TileDequanterB &tile_dequanter_B,
|
||||
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
|
||||
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
{
|
||||
// Issue several complete stages
|
||||
@@ -476,11 +500,18 @@ public:
|
||||
copy_tiles_and_advance_per_stage_A(iterator_A);
|
||||
|
||||
// Async copy zipped B to shared memory.
|
||||
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, stage);
|
||||
copy_tiles_and_advance_per_stage_B(iterator_B);
|
||||
|
||||
// Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
|
||||
if (stage == 0) {
|
||||
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<true>(mma_quant_args, stage);
|
||||
} else {
|
||||
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
|
||||
}
|
||||
|
||||
// Move to the next write stage
|
||||
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
|
||||
advance_smem_write_stage(iterator_A, iterator_B);
|
||||
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
|
||||
|
||||
// Defines the boundary of a stage of cp.async.
|
||||
cutlass::arch::cp_async_fence();
|
||||
@@ -510,6 +541,10 @@ public:
|
||||
++last_smem_iterator_A;
|
||||
}
|
||||
|
||||
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
|
||||
return;
|
||||
}
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of B operand to shared memory
|
||||
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
|
||||
typename IteratorB::AccessType zero_B;
|
||||
@@ -542,57 +577,57 @@ public:
|
||||
}
|
||||
|
||||
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void mac_loop_iter(
|
||||
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
|
||||
FragmentC &accum, ///< [in|out] destination accumulator tile
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
|
||||
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
|
||||
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
int stage)
|
||||
{
|
||||
const int mma_stage = stage - Base::kStages + 1;
|
||||
|
||||
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
|
||||
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
|
||||
|
||||
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
|
||||
|
||||
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
|
||||
// Load the next warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(((warp_mma_k + 1) % Base::kWarpGemmIterations) / Base::kWarpLoadIterationsForB);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
|
||||
++this->warp_tile_iterator_B_;
|
||||
}
|
||||
|
||||
// load next-tile of group-wise local_scale from shared memory
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
|
||||
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
|
||||
}
|
||||
|
||||
// Load the next warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
// Unpack and dequant the first stage of B.
|
||||
int unpack_stage = stage - Base::kStages + 2;
|
||||
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, unpack_stage);
|
||||
|
||||
// Copy dequatized data to shared memory used by mma core.
|
||||
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
|
||||
}
|
||||
|
||||
// Load the next warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
|
||||
if (warp_mma_k > 0) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
|
||||
}
|
||||
// Dequantize the next warp-tile
|
||||
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
|
||||
pipe_state.warp_frag_code_scale_,
|
||||
pipe_state.warp_frag_code_zp_,
|
||||
pipe_state.warp_frag_super_scale_,
|
||||
pipe_state.warp_loaded_frag_B_,
|
||||
pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2],
|
||||
((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) : mma_stage) * Shape::kK,
|
||||
(warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB);
|
||||
|
||||
// Execute the current warp-tile of MMA operations
|
||||
if (Detail::kStagedAccumulation) {
|
||||
if constexpr (Detail::kStagedAccumulation) {
|
||||
warp_mma_(
|
||||
pipe_state.tmp_accum_,
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.warp_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.tmp_accum_
|
||||
);
|
||||
|
||||
@@ -604,22 +639,22 @@ public:
|
||||
} else {
|
||||
warp_mma_(
|
||||
accum,
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
pipe_state.warp_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_frag_B_[warp_mma_k % 2],
|
||||
accum);
|
||||
}
|
||||
|
||||
// Except for the last warp-tile, all warp-tiles issue their share of
|
||||
// global->shared fragment copies
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
|
||||
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
|
||||
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, stage);
|
||||
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -628,9 +663,15 @@ public:
|
||||
// - moves to the next global fetch stage
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
// Performs the last warp-tile's share of global->shared fragment copies
|
||||
if constexpr (Detail::AsyncCopyIterationsPerStageA >= Base::kWarpGemmIterations) {
|
||||
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
}
|
||||
|
||||
if constexpr (Detail::AsyncCopyIterationsPerStageB >= Base::kWarpGemmIterations) {
|
||||
int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
|
||||
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
|
||||
}
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
@@ -639,69 +680,66 @@ public:
|
||||
gmem_wait();
|
||||
|
||||
// Move to the next global fetch stage
|
||||
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
|
||||
advance_smem_write_stage(iterator_A, iterator_B);
|
||||
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
|
||||
|
||||
advance_smem_read_stage();
|
||||
int byte_offset = quant_params_accessor_B_.advance_smem_read_stage();
|
||||
warp_dequantizer_.add_pointer_offset(byte_offset);
|
||||
|
||||
// Disable global fetching when done with global fetch iterations
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
|
||||
}
|
||||
|
||||
// The last warp-tile also converts the shared memory fragments used by
|
||||
// the first warp-tile of the next iteration, if necessary (so we can
|
||||
// immediately start issuing MMA instructions at the top of the loop)
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the specified number of threadblock mainloop iterations of matrix
|
||||
/// multiply-accumulate. Assumes prologue has been initiated.
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void gemm_iters(
|
||||
int gemm_k_iterations, ///< number of threadblock mainloop iterations
|
||||
FragmentC &accum, ///< [in|out] accumulator tile
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B,
|
||||
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
QuantArguments &mma_quant_args)
|
||||
{
|
||||
PipeState pipe_state;
|
||||
|
||||
// Unpack and dequant the first stage of B.
|
||||
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
|
||||
|
||||
// Disable global fetching if done with global fetch iterations
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
|
||||
|
||||
// Load first warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
// Copy dequantized data to the shared memory used by the mma core.
|
||||
copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
|
||||
|
||||
// Load first warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Transform, if necessary, the first warp-tile's shared memory fragments
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[0],
|
||||
pipe_state.warp_transformed_frag_B_[0],
|
||||
pipe_state.warp_loaded_frag_A_[0],
|
||||
pipe_state.warp_loaded_frag_B_[0]);
|
||||
warp_dequantizer_.load(pipe_state.warp_frag_code_scale_,
|
||||
pipe_state.warp_frag_code_zp_,
|
||||
pipe_state.warp_frag_super_scale_);
|
||||
|
||||
if (Detail::kStagedAccumulation) {
|
||||
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
|
||||
|
||||
// Load first warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index(0);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[0]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
// Dequantize B to in register
|
||||
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
|
||||
pipe_state.warp_frag_code_scale_,
|
||||
pipe_state.warp_frag_code_zp_,
|
||||
pipe_state.warp_frag_super_scale_,
|
||||
pipe_state.warp_loaded_frag_B_,
|
||||
pipe_state.warp_frag_B_[0],
|
||||
0,
|
||||
0);
|
||||
|
||||
if constexpr (Detail::kStagedAccumulation) {
|
||||
pipe_state.tmp_accum_.clear();
|
||||
}
|
||||
|
||||
@@ -715,13 +753,13 @@ public:
|
||||
accum,
|
||||
iterator_A,
|
||||
iterator_B,
|
||||
tile_dequanter_B,
|
||||
mma_quant_args,
|
||||
gemm_k_iterations,
|
||||
stage);
|
||||
stage += 1;
|
||||
}
|
||||
|
||||
if (Detail::kStagedAccumulation) {
|
||||
if constexpr (Detail::kStagedAccumulation) {
|
||||
plus<FragmentC> plus_accum;
|
||||
accum = plus_accum(accum, pipe_state.tmp_accum_);
|
||||
}
|
||||
@@ -761,14 +799,12 @@ public:
|
||||
else
|
||||
{
|
||||
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
|
||||
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
|
||||
this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
|
||||
}
|
||||
smem_read_stage_idx_ = smem_write_stage_idx_;
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
@@ -779,13 +815,13 @@ public:
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< pre-load and dequantize B to shared memory
|
||||
TileDequanterB tile_dequanter_B,
|
||||
///< iterators for extra quant params for B
|
||||
QuantArguments mma_quant_args,
|
||||
///< initial value of accumulator
|
||||
FragmentC const &src_accum) {
|
||||
|
||||
// Prologue (start fetching iterations of global fragments into shared memory)
|
||||
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
|
||||
prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations);
|
||||
|
||||
// Wait until we have at least one completed global fetch stage
|
||||
gmem_wait();
|
||||
@@ -794,7 +830,7 @@ public:
|
||||
accum = src_accum;
|
||||
|
||||
// Perform the MAC-iterations
|
||||
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
|
||||
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args);
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -0,0 +1,315 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/arch/memory_sm80.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <
|
||||
/// Original data type
|
||||
typename T,
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterators over super scales in global memory
|
||||
typename IteratorSuperScale_,
|
||||
/// Iterators over super scales in shared memory
|
||||
typename SmemIteratorSuperScale_,
|
||||
/// Iterators over local scales in global memory
|
||||
typename IteratorLocalScale_,
|
||||
/// Iterators over local scales in shared memory
|
||||
typename SmemIteratorLocalScale_,
|
||||
/// Iterators over code scales and zps in global memory
|
||||
typename IteratorCodeScaleZp_,
|
||||
/// Iterators over code scales and zps in shared memory
|
||||
typename SmemIteratorCodeScaleZp_,
|
||||
/// Number of stages
|
||||
int Stages_,
|
||||
/// Group size for quantization
|
||||
int GroupSize_>
|
||||
class Wint2ParamsAccessor {
|
||||
public:
|
||||
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
|
||||
"T must be fp16 or bf16");
|
||||
|
||||
using ElementType = T;
|
||||
using Shape = Shape_;
|
||||
|
||||
using IteratorSuperScale = IteratorSuperScale_;
|
||||
using SmemIteratorSuperScale = SmemIteratorSuperScale_;
|
||||
|
||||
using IteratorLocalScale = IteratorLocalScale_;
|
||||
using SmemIteratorLocalScale = SmemIteratorLocalScale_;
|
||||
|
||||
using IteratorCodeScaleZp = IteratorCodeScaleZp_;
|
||||
using SmemIteratorCodeScaleZp = SmemIteratorCodeScaleZp_;
|
||||
|
||||
constexpr static int kStages = Stages_;
|
||||
constexpr static int kGroupSize = GroupSize_;
|
||||
|
||||
using ElementSuperScale = typename IteratorSuperScale::Element;
|
||||
using LayoutSuperScale = typename IteratorSuperScale::Layout;
|
||||
|
||||
/// local_scale uint4 and group-wise
|
||||
using ElementLocalScale = typename IteratorLocalScale::Element;
|
||||
using LayoutLocalScale = typename IteratorLocalScale::Layout;
|
||||
static_assert(platform::is_same<ElementLocalScale, uint4b_t>::value,
|
||||
"local_scale's type must be uint4b_t.");
|
||||
|
||||
using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element;
|
||||
using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout;
|
||||
|
||||
/// 2 uint4b_t values are stored in a single uint8_t
|
||||
constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK;
|
||||
constexpr static int kLocalScaleRows =
|
||||
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * sizeof_bits<ElementLocalScale>::value / 8 / Shape::kN;
|
||||
|
||||
using SmemElement = uint8_t;
|
||||
constexpr static int kSmemRows =
|
||||
kLocalScaleRows * kStages + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2;
|
||||
constexpr static int kSmemColumns = Shape::kN;
|
||||
|
||||
using QuantParamsShape = MatrixShape<kSmemRows, kSmemColumns>;
|
||||
|
||||
constexpr static int kSuperScaleSmemOffset = 0;
|
||||
constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale);
|
||||
constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
|
||||
constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
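// Illustrative worked example (not part of the original change; values are assumptions):
// with Shape::kN = 64, half_t super scales and float code scales/zps, the per-tile
// shared-memory byte offsets above resolve to
//   kSuperScaleSmemOffset = 0
//   kCodeScaleSmemOffset  = 64 * sizeof(half_t) = 128
//   kCodeZpSmemOffset     = 128 + 64 * sizeof(float) = 384
//   kLocalScaleSmemOffset = 384 + 64 * sizeof(float) = 640
// followed by kStages tiles of packed local_scale, each of kLocalScaleRows * kSmemColumns bytes.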
|
||||
|
||||
/// TensorRef type for loading elements from a tensor
|
||||
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, LayoutSuperScale>;
|
||||
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, LayoutLocalScale>;
|
||||
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, LayoutCodeScaleZp>;
|
||||
|
||||
struct Arguments {
|
||||
IteratorSuperScale iterator_super_scale;
|
||||
IteratorLocalScale iterator_local_scale;
|
||||
IteratorCodeScaleZp iterator_code_scale;
|
||||
IteratorCodeScaleZp iterator_code_zp;
|
||||
|
||||
int local_scale_pointer_offset;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Arguments(IteratorSuperScale iterator_super_scale,
|
||||
IteratorLocalScale iterator_local_scale,
|
||||
IteratorCodeScaleZp iterator_code_scale,
|
||||
IteratorCodeScaleZp iterator_code_zp,
|
||||
int local_scale_pointer_offset)
|
||||
: iterator_super_scale(iterator_super_scale),
|
||||
iterator_local_scale(iterator_local_scale),
|
||||
iterator_code_scale(iterator_code_scale),
|
||||
iterator_code_zp(iterator_code_zp),
|
||||
local_scale_pointer_offset(local_scale_pointer_offset) {}
|
||||
};
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Begin address of shared memory
|
||||
uint8_t* smem_pointer_;
|
||||
|
||||
/// Iterator to write threadblock-scoped tile of super scale operand to shared memory
|
||||
SmemIteratorSuperScale smem_iterator_super_scale_;
|
||||
/// Iterator to write threadblock-scoped tile of local scale operand to shared memory
|
||||
SmemIteratorLocalScale smem_iterator_local_scale_;
|
||||
/// Iterator to write threadblock-scoped tile of code scale operand to shared memory
|
||||
SmemIteratorCodeScaleZp smem_iterator_code_scale_;
|
||||
/// Iterator to write threadblock-scoped tile of code zp operand to shared memory
|
||||
SmemIteratorCodeScaleZp smem_iterator_code_zp_;
|
||||
|
||||
/// Shared memory write stage index
|
||||
int smem_write_stage_idx_;
|
||||
|
||||
/// Shared memory read stage index
|
||||
int smem_read_stage_idx_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementSuperScale* get_super_scale_smem_ptr() {
|
||||
return reinterpret_cast<ElementSuperScale*>(smem_pointer_ + kSuperScaleSmemOffset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementLocalScale* get_local_scale_smem_ptr() {
|
||||
return reinterpret_cast<ElementLocalScale*>(smem_pointer_ + kLocalScaleSmemOffset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementCodeScaleZp* get_code_scale_smem_ptr() {
|
||||
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeScaleSmemOffset);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
ElementCodeScaleZp* get_code_zp_smem_ptr() {
|
||||
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeZpSmemOffset);
|
||||
}
|
||||
|
||||
public:
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
Wint2ParamsAccessor(
|
||||
///< pointer to shared memory
|
||||
uint8_t* smem_pointer,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: smem_pointer_(smem_pointer),
|
||||
smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn),
|
||||
get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx),
|
||||
smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn),
|
||||
get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx),
|
||||
smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
|
||||
get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
|
||||
smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
|
||||
get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
|
||||
smem_write_stage_idx_(0),
|
||||
smem_read_stage_idx_(0) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
SuperTensorRef super_scale_ref() {
|
||||
return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
LocalTensorRef local_scale_ref() {
|
||||
return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CodeTensorRef code_scale_ref() {
|
||||
return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
CodeTensorRef code_zp_ref() {
|
||||
return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
|
||||
}
|
||||
|
||||
template <bool IsFirstStage>
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) {
|
||||
if constexpr (IsFirstStage) {
|
||||
// Load channel-wise super_scale to shared memory, which only needs to be done once.
|
||||
typename IteratorSuperScale::Fragment tb_frag_super_scale;
|
||||
tb_frag_super_scale.clear();
|
||||
quant_args.iterator_super_scale.load(tb_frag_super_scale);
|
||||
this->smem_iterator_super_scale_.store(tb_frag_super_scale);
|
||||
|
||||
// Load channel-wise code_scale to shared memory, which only needs to be done once.
|
||||
typename IteratorCodeScaleZp::Fragment tb_frag_code_scale;
|
||||
tb_frag_code_scale.clear();
|
||||
quant_args.iterator_code_scale.load(tb_frag_code_scale);
|
||||
this->smem_iterator_code_scale_.store(tb_frag_code_scale);
|
||||
|
||||
// Load channel-wise code_zp to shared memory, which only needs to be done once.
|
||||
typename IteratorCodeScaleZp::Fragment tb_frag_code_zp;
|
||||
tb_frag_code_zp.clear();
|
||||
quant_args.iterator_code_zp.load(tb_frag_code_zp);
|
||||
this->smem_iterator_code_zp_.store(tb_frag_code_zp);
|
||||
}
|
||||
|
||||
if ((stage % kStagesPerLocalScaleLoad) == 0) {
|
||||
// Load group-wise local_scale to shared memory; this load happens once per local-scale stage.
// Since two uint4b_t local_scale values are packed into a single uint8_t, local_scale only
// needs to be loaded from global memory once every kStagesPerLocalScaleLoad stages.
|
||||
using AccessType = typename IteratorLocalScale::AccessType;
|
||||
cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits<AccessType>::value == 128)
|
||||
? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always;
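// In other words, full 128-bit accesses may use the L1-bypassing cp.async.cg variant, while
// narrower accesses fall back to the default cp.async.ca caching behavior.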
|
||||
|
||||
quant_args.iterator_local_scale.set_iteration_index(0);
|
||||
this->smem_iterator_local_scale_.set_iteration_index(0);
|
||||
|
||||
// Async Copy for local_scale
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) {
|
||||
AccessType *dst_ptr =
|
||||
reinterpret_cast<AccessType *>(this->smem_iterator_local_scale_.get());
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) {
|
||||
auto gmem_ptr = quant_args.iterator_local_scale.get();
|
||||
|
||||
int const kSrcBytes =
|
||||
sizeof_bits<typename IteratorLocalScale::Element>::value *
|
||||
IteratorLocalScale::ThreadMap::kElementsPerAccess /
|
||||
IteratorLocalScale::kAccessesPerVector / 8;
|
||||
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
|
||||
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
|
||||
}
|
||||
++quant_args.iterator_local_scale;
|
||||
}
|
||||
++this->smem_iterator_local_scale_;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void advance_smem_write_stage(Arguments &quant_args) {
|
||||
if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
|
||||
// Advance global iterators
|
||||
quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset);
|
||||
|
||||
// Advance shared iterators
|
||||
int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
|
||||
smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset);
|
||||
}
|
||||
|
||||
// Increment shared memory write stage index
|
||||
++smem_write_stage_idx_;
|
||||
|
||||
if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
|
||||
// Wrap back around to the 'start' of the circular buffer in shared memory
|
||||
int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
|
||||
smem_iterator_local_scale_.add_pointer_offset(pointer_offset);
|
||||
smem_write_stage_idx_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
int advance_smem_read_stage() {
|
||||
int byte_offset = 0;
|
||||
|
||||
++smem_read_stage_idx_;
|
||||
|
||||
if (smem_read_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
|
||||
byte_offset = kLocalScaleRows * kSmemColumns;
|
||||
}
|
||||
|
||||
if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
|
||||
smem_read_stage_idx_ = 0;
|
||||
byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns;
|
||||
}
|
||||
|
||||
return byte_offset;
|
||||
}
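// Illustrative trace (assuming kStagesPerLocalScaleLoad = 2 and kStages = 4): the returned
// byte offset advances the warp dequantizer's local_scale pointer by one tile
// (kLocalScaleRows * kSmemColumns bytes) only on every second read stage, once a packed
// uint8_t tile has been fully consumed, and wraps back by (kStages - 1) tiles when
// smem_read_stage_idx_ reaches kStagesPerLocalScaleLoad * kStages.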
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void clear_mask(Arguments &quant_args, bool cond) {
|
||||
quant_args.iterator_local_scale.clear_mask(cond);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
@@ -1,130 +0,0 @@
|
||||
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
|
||||
int Stages, int NumThreads, WintQuantMethod Method>
|
||||
struct TileDequanter {
|
||||
using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
|
||||
using MmaElementT = typename WeightQuantTraits::MmaWeightType;
|
||||
using QuantArguments = typename WeightQuantTraits::Arguments;
|
||||
|
||||
using UnzipAndDequantFunctor =
|
||||
UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;
|
||||
|
||||
static constexpr bool kUseSharedMemory = true;
|
||||
|
||||
static constexpr int kRows = Rows;
|
||||
static constexpr int kColumns = Columns;
|
||||
static constexpr int kStages = Stages;
|
||||
|
||||
MmaElementT *out_smem_ptr{nullptr};
|
||||
|
||||
char *pointer{nullptr};
|
||||
int64_t ldm{0};
|
||||
cutlass::MatrixCoord tb_offset;
|
||||
cutlass::MatrixCoord extent;
|
||||
|
||||
ScaleElementT *super_scale_ptr{nullptr};
|
||||
cutlass::MatrixCoord tb_offset_scale;
|
||||
|
||||
QuantArguments quant_args;
|
||||
|
||||
int64_t block_start_rows[kStages];
|
||||
bool need_preload{true};
|
||||
UnzipAndDequantFunctor unzip_functor;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
|
||||
const cutlass::MatrixCoord &extent,
|
||||
const cutlass::MatrixCoord &tb_offset,
|
||||
ScaleElementT *super_scale_ptr,
|
||||
const cutlass::MatrixCoord &tb_offset_scale,
|
||||
const QuantArguments &quant_args)
|
||||
: out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
|
||||
tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
|
||||
tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaElementT *GetOutPtr() { return out_smem_ptr; }
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
|
||||
tb_offset.row() += tile_offset.row() * kRows;
|
||||
tb_offset.column() += tile_offset.column() * kColumns;
|
||||
tb_offset_scale.column() += tile_offset.column() * kColumns;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
|
||||
int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
|
||||
if (tb_offset.row() >= extent.row() ||
|
||||
tb_offset.column() >= extent.column()) {
|
||||
return;
|
||||
}
|
||||
|
||||
block_start_rows[stage % kStages] = tb_offset.row();
|
||||
|
||||
using ZippedT = typename WeightQuantTraits::WeightType;
|
||||
ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
|
||||
tb_offset.column();
|
||||
ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();
|
||||
|
||||
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
|
||||
const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
|
||||
(tb_offset.row() / 128) * ldm +
|
||||
tb_offset_scale.column();
|
||||
const float *code_scale_ptr =
|
||||
quant_args.code_scale_ptr + tb_offset_scale.column();
|
||||
const float *code_zp_ptr =
|
||||
quant_args.code_zp_ptr + tb_offset_scale.column();
|
||||
|
||||
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
|
||||
unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
|
||||
scale_ptr, &args, ldm, need_preload);
|
||||
need_preload = false;
|
||||
} else {
|
||||
// CUTLASS_TRACE_DEVICE("Not Supported!");
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
|
||||
int64_t block_start_row = block_start_rows[stage % kStages];
|
||||
if (block_start_row >= extent.row()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
|
||||
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
|
||||
unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
|
||||
} else {
|
||||
// CUTLASS_TRACE_DEVICE("Not Supported!");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
@@ -41,12 +41,9 @@
|
||||
#include "cutlass_extensions/arch/mma.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace warp {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -81,7 +78,7 @@ private:
|
||||
// Shape for computing the FP16s
|
||||
using ComputeInstructionShape = InstructionShape_;
|
||||
|
||||
// Chosen so we get K=16 for int8 and K=32 for int4.
|
||||
// Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2.
|
||||
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
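// For example: 128 / sizeof_bits<uint8_t>  = 16,
//              128 / sizeof_bits<uint4b_t> = 32,
//              128 / sizeof_bits<uint2b_t> = 64.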
|
||||
|
||||
// Shape for loading the narrow data type from shared memory
|
||||
|
@@ -58,15 +58,12 @@
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace warp
|
||||
{
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace warp {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
@@ -297,6 +294,235 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
|
||||
/// Specialization for B of uint2b_t.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Data type of A elements
|
||||
typename ElementA_,
|
||||
/// Layout of A matrix (concept: MatrixLayout)
|
||||
typename LayoutA_,
|
||||
/// Layout of B matrix (concept: MatrixLayout)
|
||||
typename LayoutB_,
|
||||
/// Element type of C matrix
|
||||
typename ElementC_,
|
||||
/// Layout of C matrix (concept: MatrixLayout)
|
||||
typename LayoutC_,
|
||||
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
|
||||
typename Policy_,
|
||||
/// Instruction shape to override shared memory iterators with
|
||||
typename SharedMemoryInstructionShape_,
|
||||
/// Number of partitions along K dimension
|
||||
int PartitionsK_,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
bool AccumulatorsInRowMajor>
|
||||
class MmaTensorOpComputeBWithF16<
|
||||
Shape_,
|
||||
ElementA_,
|
||||
LayoutA_,
|
||||
uint2b_t,
|
||||
LayoutB_,
|
||||
ElementC_,
|
||||
LayoutC_,
|
||||
Policy_,
|
||||
SharedMemoryInstructionShape_,
|
||||
PartitionsK_,
|
||||
AccumulatorsInRowMajor>
|
||||
{
|
||||
public:
|
||||
/// Shape of warp-level matrix operation (concept: GemmShape)
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Data type of multiplicand A
|
||||
using ElementA = ElementA_;
|
||||
|
||||
/// Layout of multiplicand A
|
||||
using LayoutA = LayoutA_;
|
||||
|
||||
/// Data type of multiplicand B
|
||||
using ElementB = uint2b_t;
|
||||
|
||||
/// Layout of multiplicand B
|
||||
using LayoutB = LayoutB_;
|
||||
|
||||
/// Data type of accumulator matrix C
|
||||
using ElementC = ElementC_;
|
||||
|
||||
/// Layout of accumulator matrix C
|
||||
using LayoutC = LayoutC_;
|
||||
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
using ArchMmaOperator = typename Policy::Operator;
|
||||
|
||||
/// Indicates math operator
|
||||
using MathOperator = typename ArchMmaOperator::Operator;
|
||||
|
||||
/// Architecture tag from underlying instruction
|
||||
using ArchTag = typename ArchMmaOperator::ArchTag;
|
||||
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|
||||
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
|
||||
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
|
||||
&& ArchTag::kMinComputeCapability >= 80),
|
||||
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
|
||||
|
||||
static_assert(platform::is_same<ElementA, half_t>::value
|
||||
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
|
||||
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
|
||||
|
||||
/// Indicates class of matrix operator
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
|
||||
/// Shape of underlying instruction
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Instruction shape to override shared memory iterators with
|
||||
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
|
||||
|
||||
static_assert(
|
||||
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
|
||||
static_assert(
|
||||
SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");
|
||||
|
||||
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
|
||||
|
||||
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
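// Illustrative example (assumed shapes): with InstructionShape::kK = 16 for the fp16/bf16
// compute MMA and SharedMemoryInstructionShape::kK = 64 for the uint2b_t load,
// kExpansionFactor = 64 / 16 = 4, i.e. one B fragment loaded from shared memory feeds four
// compute MMAs after dequantization.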
|
||||
|
||||
/// Complex transform on A operand
|
||||
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||||
|
||||
/// Complex transform on B operand
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
|
||||
/// Number of threads participating in warp-level matrix product
|
||||
static int const kThreadCount = 32;
|
||||
|
||||
/// Number of partitions along K dimension
|
||||
static int const kPartitionsK = PartitionsK_;
|
||||
|
||||
public:
|
||||
/// Iterates over the A operand in memory
|
||||
using IteratorA
|
||||
= MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
|
||||
MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for A tile
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Storage for transformed A tile
|
||||
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
|
||||
|
||||
/// Iterates over the B operand in memory
|
||||
using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
|
||||
LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
|
||||
kThreadCount, kPartitionsK>;
|
||||
|
||||
/// Storage for B tile
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Storage for transformed B tile
|
||||
using TransformedFragmentB =
|
||||
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements / kExpansionFactor>;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
|
||||
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
|
||||
|
||||
/// Storage for C tile
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
/// Number of mma operations performed
|
||||
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
|
||||
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
|
||||
|
||||
public:
|
||||
/// Underlying matrix multiply operator (concept: arch::Mma)
|
||||
ArchMmaOperator mma;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Ctor
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpComputeBWithF16() {}
|
||||
|
||||
/// Performs a warp-level matrix multiply-accumulate operation
|
||||
CUTLASS_DEVICE
|
||||
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C) const
|
||||
{
|
||||
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
D = C;
|
||||
|
||||
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
|
||||
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
|
||||
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
|
||||
// Serpentine visitation order maximizing reuse of Rb
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < MmaIterations::kRow; ++m)
|
||||
{
|
||||
|
||||
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
|
||||
|
||||
if (AccumulatorsInRowMajor)
|
||||
{ // matrix B is reordered
|
||||
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n],
|
||||
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n],
|
||||
ptr_D[m_serpentine + n * MmaIterations::kRow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
// Serpentine visitation order maximizing reuse of Ra
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int m = 0; m < MmaIterations::kRow; ++m)
|
||||
{
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n)
|
||||
{
|
||||
|
||||
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
|
||||
|
||||
if (AccumulatorsInRowMajor)
|
||||
{ // matrix B is reordered
|
||||
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine],
|
||||
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine],
|
||||
ptr_D[m + n_serpentine * MmaIterations::kRow]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
assert(0);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
|
@@ -0,0 +1,442 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Defines iterators used by warp-level matrix multiply operations
|
||||
targeting Tensor Cores.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/memory_sm75.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/pitch_linear.h"
|
||||
#include "cutlass/layout/tensor.h"
|
||||
#include "cutlass/functional.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace warp {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<bfloat16_t> {
|
||||
using Type = __nv_bfloat16;
|
||||
using DualType = __nv_bfloat162;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<half_t> {
|
||||
using Type = __half;
|
||||
using DualType = __half2;
|
||||
};
|
||||
|
||||
template <typename T, int N, typename Enable = void>
|
||||
struct LocalScaleConverter {
|
||||
using FragmentSource = Array<uint8_t, N>;
|
||||
using FragmentResult = Array<T, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void Apply(FragmentSource const& local_scale_frag,
|
||||
FragmentResult const& super_scale_frag,
|
||||
FragmentResult& scale_frag,
|
||||
int shift_bit) {
|
||||
constexpr uint32_t kLocalScaleMask = 0xf;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N; ++i) {
|
||||
int32_t shifted_value = (static_cast<int32_t>(local_scale_frag[i]) >> shift_bit) & kLocalScaleMask;
|
||||
scale_frag[i] = static_cast<T>(shifted_value) * super_scale_frag[i];
|
||||
}
|
||||
}
|
||||
};
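// Worked example of the generic path above (illustrative values): a packed byte 0xA3 carries
// two 4-bit local scales, 0x3 in the low nibble and 0xA in the high nibble. shift_bit = 0
// selects the low nibble ((0xA3 >> 0) & 0xF = 3) and shift_bit = 4 selects the high nibble
// ((0xA3 >> 4) & 0xF = 10); the selected value is then multiplied by the channel-wise
// super_scale to form the effective dequantization scale.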
|
||||
|
||||
template <int N>
|
||||
struct LocalScaleConverter<half_t, N, typename platform::enable_if<N % 4 == 0>::type> {
|
||||
using FragmentSource = Array<uint8_t, N>;
|
||||
using FragmentResult = Array<half_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void Apply(FragmentSource const& local_scale_frag,
|
||||
FragmentResult const& super_scale_frag,
|
||||
FragmentResult& scale_frag,
|
||||
int shift_bit) {
|
||||
constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
constexpr uint32_t MASK = 0x000f000f;
|
||||
// 2^10 = 1024
|
||||
constexpr uint32_t I4s_TO_FP16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
// -2^10 = -1024
|
||||
constexpr uint32_t FP16_BIAS = 0xE400E400;
|
||||
// 1.0
|
||||
constexpr uint32_t FP16_ONE = 0x3C003C00;
|
||||
|
||||
__half2* scale_ptr = reinterpret_cast<__half2 *>(&scale_frag);
|
||||
__half2 const* super_scale_ptr = reinterpret_cast<__half2 const*>(&super_scale_frag);
|
||||
|
||||
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 4; ++i) {
|
||||
int i4s = local_scale_ptr[i] >> shift_bit;
|
||||
|
||||
// unpack: 0, 1
|
||||
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
|
||||
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_FP16s_MAGIC_NUM);
|
||||
// unpack: 2, 3
|
||||
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
|
||||
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_FP16s_MAGIC_NUM);
|
||||
|
||||
__half2 scale0 = __hfma2(*reinterpret_cast<__half2*>(&unpack0),
|
||||
*reinterpret_cast<const __half2*>(&FP16_ONE),
|
||||
*reinterpret_cast<const __half2*>(&FP16_BIAS));
|
||||
__half2 scale1 = __hfma2(*reinterpret_cast<__half2*>(&unpack1),
|
||||
*reinterpret_cast<const __half2*>(&FP16_ONE),
|
||||
*reinterpret_cast<const __half2*>(&FP16_BIAS));
|
||||
|
||||
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
|
||||
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
|
||||
}
|
||||
}
|
||||
};
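// Note on the fast path above: OR-ing a 4-bit value v into the low mantissa bits of the fp16
// constant 1024 (0x6400) produces the fp16 number 1024 + v, so the following
// __hfma2(x, 1.0, -1024.0) recovers v exactly without any integer-to-float conversion
// instructions; the final __hmul2 folds in the channel-wise super_scale.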
|
||||
|
||||
template <int N>
|
||||
struct LocalScaleConverter<bfloat16_t, N, typename platform::enable_if<N % 4 == 0>::type> {
|
||||
using FragmentSource = Array<uint8_t, N>;
|
||||
using FragmentResult = Array<bfloat16_t, N>;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void Apply(FragmentSource const& local_scale_frag,
|
||||
FragmentResult const& super_scale_frag,
|
||||
FragmentResult& scale_frag,
|
||||
int shift_bit) {
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
|
||||
constexpr uint32_t MASK = 0x000F000F;
|
||||
constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
|
||||
|
||||
constexpr uint32_t BF16_BIAS = 0xC300C300;
|
||||
constexpr uint32_t BF16_ONE = 0x3F803F80;
|
||||
|
||||
__nv_bfloat162* scale_ptr = reinterpret_cast<__nv_bfloat162 *>(&scale_frag);
|
||||
__nv_bfloat162 const* super_scale_ptr = reinterpret_cast<__nv_bfloat162 const*>(&super_scale_frag);
|
||||
|
||||
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 4; ++i) {
|
||||
int i4s = local_scale_ptr[i] >> shift_bit;
|
||||
|
||||
// unpack: 0, 1
|
||||
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
|
||||
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_BF16s_MAGIC_NUM);
|
||||
// unpack: 2, 3
|
||||
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
|
||||
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_BF16s_MAGIC_NUM);
|
||||
|
||||
nv_bfloat162 scale0 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack0),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
|
||||
nv_bfloat162 scale1 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack1),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
|
||||
|
||||
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
|
||||
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
|
||||
}
|
||||
#else
|
||||
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
|
||||
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
|
||||
// numerous conversion instructions in GEMM main loop.
|
||||
arch::device_breakpoint();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <
|
||||
/// Matrix multiply operator
|
||||
typename MmaOperator_,
|
||||
/// Size of the matrix to load (concept: MatrixShape)
|
||||
typename Shape_,
|
||||
/// Operand identity
|
||||
Operand Operand,
|
||||
/// Data type of Scale elements
|
||||
typename ElementOperand_,
|
||||
/// Layout of operand
|
||||
typename Layout_,
|
||||
/// Group size for quantization
|
||||
int GroupSize_,
|
||||
///
|
||||
typename Enable = void>
|
||||
class MmaTensorOpWin2xDequantizer {
|
||||
//static_assert(false, "Not Supported!");
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Bfloat specialization for Ampere
|
||||
template <
|
||||
/// Underlying matrix multiply operator (concept: MmaTensorOp)
|
||||
typename MmaOperator_,
|
||||
/// Shape of the warp level matrix multiply (concept: GemmShape)
|
||||
typename Shape_,
|
||||
/// Data type of Scale elements
|
||||
typename ElementOperand_,
|
||||
/// Group size for quantization
|
||||
int GroupSize_>
|
||||
class MmaTensorOpWin2xDequantizer<
|
||||
MmaOperator_,
|
||||
Shape_,
|
||||
Operand::kB,
|
||||
ElementOperand_,
|
||||
layout::RowMajor,
|
||||
GroupSize_>
|
||||
//typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
|
||||
// && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
|
||||
{
|
||||
public:
|
||||
static_assert(platform::is_same<ElementOperand_, half_t>::value || platform::is_same<ElementOperand_, bfloat16_t>::value,
|
||||
"T must be fp16 or bf16");
|
||||
|
||||
/// Mma Operator
|
||||
using MmaOperator = MmaOperator_;
|
||||
|
||||
// The architecture-specific mma operator being used
|
||||
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
|
||||
|
||||
// Mma Instruction Shape
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
/// Warp mma shape
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Type of mma operand
|
||||
using ElementOperand = ElementOperand_;
|
||||
|
||||
/// Layout of the scales in shared memory
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
/// Group size for quantization
|
||||
static constexpr int kGroupSize = GroupSize_;
|
||||
|
||||
/// Type of input
|
||||
using ElementB = typename MmaOperator::FragmentB::Element;
|
||||
static_assert(platform::is_same<ElementB, uint2b_t>::value, "ElementB must be uint2b_t");
|
||||
|
||||
/// Type of the scales
|
||||
using ElementLocalScale = uint4b_t;
|
||||
using ElementSuperScale = ElementOperand;
|
||||
using ElementCodeScaleZp = float;
|
||||
|
||||
// Fragment to hold scale data to apply to B before mma
|
||||
// We need one scale element (fp16 or bf16) per MMA iteration in the N dimension
|
||||
static constexpr int kWarpIterationsAlongN = MmaOperator::MmaIterations::kColumn;
|
||||
|
||||
// Use uint8_t to store two 4-bit local scales
|
||||
using FragmentLocalScale = Array<uint8_t, kWarpIterationsAlongN>;
|
||||
using FragmentSuperScale = Array<ElementSuperScale, kWarpIterationsAlongN>;
|
||||
using FragmentCodeScaleZp = Array<ElementCodeScaleZp, kWarpIterationsAlongN>;
|
||||
|
||||
/// Fragment to hold B data before Mma
|
||||
using FragmentInput = Array<ElementB, MmaOperator::FragmentB::kElements>;
|
||||
|
||||
// This is the ratio of the load instruction vs the compute instruction.
|
||||
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
|
||||
|
||||
static constexpr int kNumPacks = sizeof_bits<uint8_t>::value / sizeof_bits<ElementB>::value;
|
||||
static constexpr int kUnpackFactor = MmaOperator::FragmentB::kElements / (kWarpIterationsAlongN * kNumPacks);
|
||||
static constexpr int kUnpackInterval = kExpansionFactor / kUnpackFactor;
|
||||
|
||||
/// Unpack 4 uint2b_t values compressed in a uint8_t to floating point.
|
||||
using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter<
|
||||
ElementOperand, ElementB, MmaOperator::FragmentB::kElements / kUnpackFactor>;
|
||||
using FragmentInputUnpack = typename Uint2Converter::result_type;
|
||||
|
||||
/// Fragment to hold internal scales before Mma
|
||||
using FragmentScale = Array<ElementOperand, FragmentLocalScale::kElements>;
|
||||
|
||||
/// Fragment of dequantized B
|
||||
using FragmentOutput = Array<ElementOperand, MmaOperator::FragmentB::kElements / kExpansionFactor>;
|
||||
|
||||
/// TensorRef type for loading elements from a tensor
|
||||
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, Layout>;
|
||||
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, Layout>;
|
||||
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, Layout>;
|
||||
|
||||
private:
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
uint8_t* pointer_local_scale_;
|
||||
ElementCodeScaleZp* pointer_code_scale_;
|
||||
ElementCodeScaleZp* pointer_code_zp_;
|
||||
ElementSuperScale* pointer_super_scale_;
|
||||
|
||||
//FragmentInputUnpack unpacked_frag_;
|
||||
FragmentScale scale_frag_;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale,
|
||||
LocalTensorRef smem_local_scale,
|
||||
CodeTensorRef smem_code_scale,
|
||||
CodeTensorRef smem_code_zp,
|
||||
int warp_idx_n,
|
||||
int lane_idx) {
|
||||
int warp_offset = warp_idx_n * Shape::kN;
|
||||
int quad = lane_idx / 4;
|
||||
int thread_offset = warp_offset + quad;
|
||||
pointer_super_scale_ = smem_super_scale.data() + thread_offset;
|
||||
pointer_code_scale_ = smem_code_scale.data() + thread_offset;
|
||||
pointer_code_zp_ = smem_code_zp.data() + thread_offset;
|
||||
pointer_local_scale_ = reinterpret_cast<uint8_t *>(smem_local_scale.data()) + thread_offset;
|
||||
}
|
||||
|
||||
/// Channel-wise params, which only need to be loaded once
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentCodeScaleZp& code_scale_frag,
|
||||
FragmentCodeScaleZp& code_zp_frag,
|
||||
FragmentSuperScale& super_scale_frag) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
|
||||
super_scale_frag[mma_n_iter] = pointer_super_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
|
||||
code_scale_frag[mma_n_iter] = pointer_code_scale_[mma_n_iter * InstructionShape::kN];
|
||||
code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
|
||||
/// Group-wise params, which need to be reloaded as the mainloop advances through K groups
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentLocalScale& local_scale_frag) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
|
||||
local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(const FragmentLocalScale& local_scale_frag,
|
||||
const FragmentCodeScaleZp& code_scale_frag,
|
||||
const FragmentCodeScaleZp& code_zp_frag,
|
||||
const FragmentSuperScale& super_scale_frag,
|
||||
const FragmentInput& input_frag,
|
||||
FragmentOutput& output_frag,
|
||||
int tb_offset_k,
|
||||
int warp_k_compute_offset) {
|
||||
if constexpr (kUnpackInterval != 1) {
|
||||
// unsupported for now
|
||||
arch::device_breakpoint();
|
||||
}
|
||||
|
||||
typename Uint2Converter::source_type source_frag;
|
||||
|
||||
int in_offset = warp_k_compute_offset * kUnpackInterval;
|
||||
|
||||
uint8_t const* ptr_input = reinterpret_cast<uint8_t const*>(&input_frag);
|
||||
uint8_t* ptr_source = reinterpret_cast<uint8_t *>(&source_frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
|
||||
ptr_source[mma_n_iter] = ptr_input[mma_n_iter * kUnpackFactor + in_offset];
|
||||
}
|
||||
FragmentInputUnpack unpacked_frag = Uint2Converter::convert(source_frag, code_scale_frag, code_zp_frag);
|
||||
|
||||
// dequantize local_scale
|
||||
if (warp_k_compute_offset == 0) {
|
||||
using LocalScaleConverter = detail::LocalScaleConverter<ElementOperand, FragmentLocalScale::kElements>;
|
||||
|
||||
// special for TileRows = 64
|
||||
int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4;
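// Consequence of the shift above (assuming kGroupSize matches the K extent of a stage):
// groups with an even index (tb_offset_k / kGroupSize) read the high nibble (shift 4) of the
// packed uint8_t local_scale, while groups with an odd index read the low nibble (shift 0).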
|
||||
LocalScaleConverter::Apply(local_scale_frag, super_scale_frag, scale_frag_, local_scale_shift);
|
||||
}
|
||||
|
||||
// unscale
|
||||
// After applying LOP3 optimizations for performance, the B operand requires data rearrangement.
|
||||
// reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15]
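// In other words, each DualType element of unpacked_frag holds the values for an adjacent pair of
// N-iterations, so the loop below walks mma_n_iter in steps of 2 and one __hmul2 rescales both lanes.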
const int kWarpIterationsAlongK = FragmentOutput::kElements / kWarpIterationsAlongN;

using Type = typename detail::DataTypeTraits<ElementOperand>::Type;
using DualType = typename detail::DataTypeTraits<ElementOperand>::DualType;

Type* output_ptr = reinterpret_cast<Type *>(&output_frag);
DualType const* unpacked_ptr = reinterpret_cast<DualType const*>(&unpacked_frag);
DualType const* scale_ptr = reinterpret_cast<DualType const*>(&scale_frag_);

CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; mma_n_iter += 2) {
int mapped_idx_base = (mma_n_iter / 2) * kWarpIterationsAlongK;

DualType scalex2 = scale_ptr[mma_n_iter / 2];

CUTLASS_PRAGMA_UNROLL
for (int mma_k_iter = 0; mma_k_iter < kWarpIterationsAlongK; ++mma_k_iter) {
DualType unpacked_valuex2 = unpacked_ptr[mapped_idx_base + mma_k_iter];
DualType scaled_value = __hmul2(unpacked_valuex2, scalex2);
output_ptr[mma_n_iter * kWarpIterationsAlongK + mma_k_iter] = scaled_value.x;
output_ptr[(mma_n_iter + 1) * kWarpIterationsAlongK + mma_k_iter] = scaled_value.y;
}
}
}

/// Add an offset to the pointer in units of elements.
/// Only the group-wise params need it.
CUTLASS_DEVICE
void add_pointer_offset(int64_t const& offset) {
pointer_local_scale_ += offset;
}
};

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
@@ -39,18 +39,25 @@
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"
#include "cutlass/trace.h"

namespace cutlass
{
namespace cutlass {

template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}

// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
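// The uint2b_t specializations added below differ from the uint8/uint4 paths: each packed byte is
// first mapped to an integer code with a per-column code_scale/code_zp fused multiply-add, and the
// weight values are then extracted from that code.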
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter
{
};
struct FastInterleavedAndBiasedNumericArrayConverter;

template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
@@ -440,6 +447,329 @@ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
}
};

template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint2b_t, 16>
{
using result_type = Array<half_t, 16>;
using source_type = Array<uint2b_t, 16>;

using ScaleComputeT = float;
using code_type = Array<ScaleComputeT, 4>;

CUTLASS_DEVICE
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
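// __byte_perm places one source byte into the low mantissa bits of the fp32 constant 2^23
// (0x4B0000xx is the float 8388608 + xx), so the sub.f32 below yields the byte value as a float
// without an explicit int-to-float conversion.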

float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);

asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));

int32_t decode_value[4];
ScaleComputeT new_code_zp = code_zp + 0.5f;
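// Adding 0.5 and truncating with __float2int_rd (round toward -inf) rounds
// fp32_intermediates[i] * code_scale + code_zp to the nearest integer.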

decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));

return convert_impl(decode_value);
}

CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;

float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);

asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));

int32_t decode_value[4];

decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));

return convert_impl(decode_value);
}

CUTLASS_DEVICE
static result_type convert_impl(int32_t* decode_value)
{
result_type result;
static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;

static constexpr uint32_t MASK = 0x003F003F;
// 2^10 = 1024
static constexpr uint32_t EX = 0x64006400;
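// immLut = 0xEA selects (a & b) | c, so each lop3 below takes a 6-bit window of the decoded code
// (bit offsets 9, 6, 3, 0), masks it with MASK and ORs it into the mantissa of the fp16 value
// 1024.0 (0x6400), giving 1024 + w per half; the sub.f16x2 with 1056 further down leaves w - 32.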

uint32_t* h = reinterpret_cast<uint32_t*>(&result);

int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);

h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
h[3] = lop3<immLut>(q0, MASK, EX);

h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
h[7] = lop3<immLut>(q1, MASK, EX);

// 1024 + 32 = 1056
static constexpr uint32_t SUB = 0x64206420;

asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));

asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));

return result;
}

CUTLASS_DEVICE
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
return convert(s, code_scale, code_zp);
}
};

template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint2b_t, 16>
{
using result_type = Array<bfloat16_t, 16>;
using source_type = Array<uint2b_t, 16>;

using ScaleComputeT = float;
using code_type = Array<ScaleComputeT, 4>;

CUTLASS_DEVICE
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;

float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);

asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));

int32_t decode_value[4];
ScaleComputeT new_code_zp = code_zp + 0.5f;

decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));

return convert_impl(decode_value);
}

CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;

float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);

asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));

int32_t decode_value[4];

decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));

return convert_impl(decode_value);
}

CUTLASS_DEVICE
static result_type convert_impl(int32_t* decode_value)
{
result_type result;

static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
static constexpr uint32_t MASK = 0x003F003F;
// 2^7 = 128
static constexpr uint32_t EX = 0x43004300;
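// Same scheme as the fp16 path above: 0x4300 is bf16 128.0, ORing a 6-bit window into its
// mantissa gives 128 + w, and removing the 160 bias below leaves w - 32.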

uint32_t* h = reinterpret_cast<uint32_t*>(&result);

int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);

h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
h[3] = lop3<immLut>(q0, MASK, EX);

h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
h[7] = lop3<immLut>(q1, MASK, EX);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16))
// 128 + 32 = 160
static constexpr uint32_t SUB = 0x43204320;

asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));

asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
#else
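// Fallback when sub.bf16x2 is not used: fma.rn.bf16x2 computes h * 1.0 + (-160.0), removing the
// same 128 + 32 bias.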
// 1.0
static constexpr uint32_t MUL = 0x3F803F80;
// -160
static constexpr uint32_t ADD = 0xC320C320;

asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MUL), "r"(ADD));

asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(h[4]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(MUL), "r"(ADD));
#endif

return result;
}

CUTLASS_DEVICE
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
return convert(s, code_scale, code_zp);
}
};

template <typename T, int N>
struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, N>
{
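// Generic width: dispatches one 32-bit word (16 x 2-bit values) at a time to the 16-wide
// specializations above, with one code_scale / code_zp entry per word.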
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");

static constexpr int kVecWidth = 16;
static_assert(!(N % kVecWidth), "N must be multiple of 16.");

using result_type = Array<T, N>;
using source_type = Array<uint2b_t, N>;
using code_type = Array<float, N / kVecWidth>;

CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>
convert_vector_;

result_type result;
using vec_result = Array<scalar_result_type, kVecWidth>;
using vec_source = Array<scalar_source_type, kVecWidth>;

vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / kVecWidth; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]);
}

return result;
}

CUTLASS_DEVICE
static result_type convert(source_type const& source, Array<float, N / 4> const& code_scale, Array<float, N / 4> const& code_zp)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>;

result_type result;
using vec_result = typename Converter::result_type;
using vec_source = typename Converter::source_type;
using vec_code = typename Converter::code_type;

vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
vec_code const* code_scale_ptr = reinterpret_cast<vec_code const*>(&code_scale);
vec_code const* code_zp_ptr = reinterpret_cast<vec_code const*>(&code_zp);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / kVecWidth; ++i)
{
result_ptr[i] = Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]);
}

return result;
}

CUTLASS_DEVICE
result_type operator()(source_type const& s, code_type const& code_scale, code_type const& code_zp)
{
return convert(s, code_scale, code_zp);
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass
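For reference, the arithmetic the uint2b_t converters above perform on each packed byte can be summarized by the following scalar sketch (the function name and layout are illustrative only and not part of this change; the interleaved output ordering and the later local/super scaling in the dequantizer are omitted):

// Decode one packed byte into four unscaled weights, mirroring convert() + convert_impl() above.
#include <cmath>
#include <cstdint>

inline void decode_wint2_byte_reference(uint8_t b, float code_scale, float code_zp, float out[4]) {
    // Same rounding as __float2int_rd(fmaf(b, code_scale, code_zp + 0.5f)).
    int32_t d = static_cast<int32_t>(std::floor(std::fmaf(static_cast<float>(b), code_scale, code_zp + 0.5f)));
    const int shifts[4] = {9, 6, 3, 0};  // the same bit windows used by h[0..3] / h[4..7] above
    for (int i = 0; i < 4; ++i) {
        out[i] = static_cast<float>(((d >> shifts[i]) & 0x3F) - 32);  // 6-bit window minus the bias of 32
    }
}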
@@ -125,10 +125,13 @@ struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt2> {
static constexpr int32_t kNumPackedValues = 4;
static constexpr int32_t kPackedSize = 16;

using LocalScaleType = uint4b_t;
using CodeScaleZpType = float;

struct Arguments {
const uint8_t *local_scale_ptr; // quantized 4-bit values
const float *code_scale_ptr;
const float *code_zp_ptr;
uint8_t *local_scale_ptr; // quantized 4-bit values
float *code_scale_ptr;
float *code_zp_ptr;
};

CUTLASS_DEVICE
@@ -43,7 +43,6 @@
#include "cutlass/trace.h"

#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
#include "cutlass_extensions/tile_interleaved_layout.h"

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -775,17 +774,54 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
template <WintQuantMethod QuantMethod, typename dummy>
struct KernelRunner<QuantMethod, true, dummy> {
using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
using QuantArguments = typename WeightQuantTraits::Arguments;
using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments;

CUTLASS_DEVICE
static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
QuantArguments quant_args;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
}
return quant_args;
static MmaQuantArguments prepare_quant_args(
Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset,
int64_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) {
// the beginning threadblock offset of the scale, which shares the same column id as C but has no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2};
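// local_scale packs two 4-bit values per byte, so its logical column extent is gemm_n * 2 and the
// column offset is doubled relative to the other scale tensors.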

ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale(
Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
weight_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);

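// Advance of the group-wise local_scale pointer per threadblock K tile: one row of (gemm_n * 2)
// uint4 values for every 128 rows of K.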
int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128;
uint4b_t *local_scale_ptr = reinterpret_cast<uint4b_t *>(params.local_scale + offset_in_bytes);

typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale(
Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
local_scale_ptr,
{(gemm_k + 127) / 128, gemm_n * 2},
thread_idx,
tb_offset_local_scale);

float* code_scale_ptr = params.code_scale + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);

float* code_zp_ptr = params.code_zp + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_zp_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);

MmaQuantArguments mma_quant_args(
iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset);
return mma_quant_args;
}

CUTLASS_DEVICE
@@ -814,9 +850,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");

// LayoutB should be RowMajor
using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;

//
// Problem visitor.
//
@@ -843,12 +876,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
0);

// beginning address offset for weight_scale.
ElementScale* weight_scale_ptr =
params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
// the beginning threadblock offset of the scale, which shares the same column id as C but has no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};

// Load element pointers. Exchange pointers and strides if working on
// the transpose
int64_t rows_to_jump = 0;
@@ -866,42 +893,20 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,

// Compute initial location in logical coordinates
// the beginning threadblock offset of A, which shares the same row id as C
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_A{threadblock_offset.m(), 0};

// beginning address offset of B for the current problem_idx (there are num_experts problems in total)
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
problem_idx * bytes_per_expert_matrix; // NOLINT

ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B =
platform::is_same<layout::RowMajor, LayoutB>::value
? gemm_n
: gemm_k * kInterleave;
typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;

// the beginning threadblock offset of B, which shares the same column id as C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};

cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};

MmaElementB* smem_unzip_B_ptr = nullptr;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
}
QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
byte_ptr_B,
ldm_B,
extent_B,
tb_offset_B,
weight_scale_ptr,
tb_offset_scale,
quant_args);
MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();

// Compute position within threadblock
int thread_idx = threadIdx.x;
@@ -914,20 +919,21 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
tb_offset_A);

typename Mma::IteratorB iterator_B(
LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
LayoutB(ldm_B),
ptr_B,
TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
extent_B,
thread_idx,
TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);
tb_offset_B);

MmaQuantArguments mma_quant_args = prepare_quant_args(
params, threadblock_offset, problem_idx, gemm_k, gemm_n, thread_idx);

typename Mma::FragmentC accumulators;

accumulators.clear();

// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

int lane_idx = threadIdx.x % 32;

//
@@ -950,7 +956,7 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
accumulators,
iterator_A,
iterator_B,
tile_dequanter_B,
mma_quant_args,
accumulators);

//
@@ -205,7 +205,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
threadblock_count,
epilogue_op,
reinterpret_cast<const ElementType*>(A),
reinterpret_cast<const CutlassMmaWeightType*>(B),
reinterpret_cast<const CutlassMmaKernelType*>(B),
reinterpret_cast<const ElementType*>(weight_scales),
reinterpret_cast<const ElementType*>(biases),
reinterpret_cast<ElementType*>(C),
@@ -49,12 +49,13 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
typename WeightOnlyTraits::Arguments up_gate_proj_quant_args;
typename WeightOnlyTraits::Arguments down_proj_quant_args;
if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
up_gate_proj_quant_args.local_scale_ptr = up_gate_proj_local_scale->data<uint8_t>();
up_gate_proj_quant_args.code_scale_ptr = up_gate_proj_code_scale->data<float>();
up_gate_proj_quant_args.code_zp_ptr = up_gate_proj_code_zp->data<float>();
down_proj_quant_args.local_scale_ptr = down_proj_local_scale->data<uint8_t>();
down_proj_quant_args.code_scale_ptr = down_proj_code_scale->data<float>();
down_proj_quant_args.code_zp_ptr = down_proj_code_zp->data<float>();
up_gate_proj_quant_args.local_scale_ptr = const_cast<uint8_t*>(up_gate_proj_local_scale->data<uint8_t>());
up_gate_proj_quant_args.code_scale_ptr = const_cast<float*>(up_gate_proj_code_scale->data<float>());
up_gate_proj_quant_args.code_zp_ptr = const_cast<float*>(up_gate_proj_code_zp->data<float>());

down_proj_quant_args.local_scale_ptr = const_cast<uint8_t*>(down_proj_local_scale->data<uint8_t>());
down_proj_quant_args.code_scale_ptr = const_cast<float*>(down_proj_code_scale->data<float>());
down_proj_quant_args.code_zp_ptr = const_cast<float*>(down_proj_code_zp->data<float>());
}

auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();