mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Sync v2.0 version of code to github repo
@@ -77,6 +77,7 @@ public:
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
template <
    /// Layout type for A matrix operand
@@ -125,6 +126,7 @@ public:
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
template <
@@ -148,7 +150,7 @@ template <
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    ///
    /// Number of stages used in the multistage mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
@@ -179,6 +181,7 @@ public:
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
@@ -234,6 +237,7 @@ public:

#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
template <
@@ -346,6 +350,131 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, El
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int2 weight, mma multistage

template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
    static cutlass::arch::CacheOperation::Kind const CacheOpA =
        ((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
                                                            : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB =
        ((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                            : cutlass::arch::CacheOperation::Always;

    // Define the MmaCore components
    using MmaCore =
        typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
            LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
            false, CacheOpA, CacheOpB>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
        AccessTypeA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
        AccessTypeB>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};
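// Note on this specialization (added commentary): although the weights enter as uint2b_t,
// every iterator and shared-memory type above is built over half_t. The packed 2-bit data is
// staged in a separate "zipped" buffer and expanded to half_t by a tile dequantizer before the
// tensor-core MMA consumes it (see Wint2xMmaMultistage below), so DefaultMmaCore is configured
// as if B were an fp16 operand.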

template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the multistage mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
    static cutlass::arch::CacheOperation::Kind const CacheOpA =
        ((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
                                                            : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB =
        ((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                            : cutlass::arch::CacheOperation::Always;

    // Define the MmaCore components
    using MmaCore =
        typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
            LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
            false, CacheOpA, CacheOpB>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
        AccessTypeA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
        AccessTypeB>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
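For reference, a caller reaches the new specializations simply by naming uint2b_t as the B element; the sketch below is illustrative only, and every concrete template argument shown (shapes, alignments, accumulator and operator types) is an assumption rather than a value taken from this commit.

// Illustrative sketch (all concrete arguments are assumed, not from this commit):
using ExampleMma = cutlass::gemm::threadblock::DefaultMma<
    cutlass::half_t, cutlass::layout::RowMajor, 8,        // A: fp16 activations
    cutlass::uint2b_t, cutlass::layout::ColumnMajor, 16,  // B: packed 2-bit weights
    float, cutlass::layout::RowMajor,                     // accumulator type / output layout
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,               // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 64>,                 // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,                  // instruction shape
    2, cutlass::arch::OpMultiplyAdd>;                     // stages, operator
using ExampleThreadblockMma = ExampleMma::ThreadblockMma;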

@@ -19,13 +19,11 @@
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"

namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

@@ -197,6 +195,7 @@ public:
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template <
    /// Layout type for A matrix operand
@@ -244,6 +243,9 @@ public:
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template <
    /// Layout type for A matrix operand
    typename LayoutA,
@@ -265,7 +267,7 @@ template <
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    ///
    /// Number of stages used in the multistage mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
@@ -296,6 +298,7 @@ public:
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight
template <
    /// Layout type for A matrix operand
@@ -318,11 +321,11 @@ template <
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    ///
    /// Number of stages used in the multistage mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
@@ -348,6 +351,131 @@ public:
    using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int2 weight, mma multistage

template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
    static cutlass::arch::CacheOperation::Kind const CacheOpA =
        ((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
                                                                : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB =
        ((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                                : cutlass::arch::CacheOperation::Always;

    // Define the MmaCore components
    using MmaCore =
        typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
            LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
            false, CacheOpA, CacheOpB>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
        AccessTypeA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
        AccessTypeB>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};

template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator,
    /// Number of stages used in the multistage mainloop
    int kStages,
    /// Shared memory clear option
    SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
    false, SharedMemoryClear>
{
    static cutlass::arch::CacheOperation::Kind const CacheOpA =
        ((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
                                                                : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB =
        ((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
                                                                : cutlass::arch::CacheOperation::Always;

    // Define the MmaCore components
    using MmaCore =
        typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
            LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
            false, CacheOpA, CacheOpB>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
        AccessTypeA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
        AccessTypeB>;

    // Define the threadblock-scoped multistage matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,237 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/gemm/threadblock/mma_base.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages,
    int Stages,
    /// Used for partial specialization
    typename Enable = bool>
class Wint2xMmaBase {
public:
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;

    ///< Policy describing tuning details
    using Policy = Policy_;

    //
    // Dependent types
    //

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Shape describing the overall GEMM computed from shared memory
    /// by each warp.
    using WarpGemm = typename Policy::Operator::Shape;

    /// Shape describing the number of warps filling the CTA
    using WarpCount =
        GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
                  Shape::kK / WarpGemm::kK>;

    /// Number of warp-level GEMM operations
    static int const kWarpGemmIterations =
        (WarpGemm::kK / Operator::Policy::MmaShape::kK);

    /// Number of stages
    static int const kStages = Stages;

    /// Tensor reference to the A operand
    using TensorRefA =
        TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

    /// Tensor reference to the B operand
    using TensorRefB =
        TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

    // using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;

    static_assert(kWarpGemmIterations > 1,
                  "The pipelined structure requires at least two warp-level "
                  "GEMM operations.");

    static_assert((kWarpGemmIterations % 2) == 0,
                  "Inner loop iteration must be an even number.");
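    // Added illustration (values assumed, not taken from this commit): with a warp tile of
    // WarpGemm::kK = 64 and an instruction tile of Operator::Policy::MmaShape::kK = 16,
    // kWarpGemmIterations = 64 / 16 = 4, which satisfies both static_asserts above.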

    //
    // Nested structs
    //

    /// Shared storage object needed by threadblock-scoped GEMM
    class SharedStorage {
    public:
        //
        // Type definitions
        //

        /// Shape of the A matrix operand in shared memory
        using ShapeA =
            MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
                        Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

        /// Shape of the B matrix operand in shared memory
        using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
                                   Shape::kN + Policy::SmemPaddingB::kColumn>;

        // w uint8; local_scale uint8;
        constexpr static int kZippedRowsPerStages =
            Shape::kK / 4 + (Shape::kK + 127) / 128;

        // code_scale float; code_zp float; super_scale ElementB
        constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
            sizeof_bits<typename Operator::ElementB>::value / 8;

        using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;

        using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
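        // Added sizing illustration (numbers assumed): for Shape::kK = 64, one stage of zipped B
        // holds 64 / 4 = 16 bytes of packed 2-bit weights plus (64 + 127) / 128 = 1 local_scale
        // byte per column, so kZippedRowsPerStages = 17; the per-column parameters take
        // 2 * sizeof(float) = 8 bytes for code_scale/code_zp plus 2 bytes for a half_t
        // super_scale, so kColumnWiseParamsRows = 10.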

    public:
        //
        // Data members
        //

        /// Buffer for A operand
        AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

        /// Buffer for B operand
        AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

        /// Buffer for the quantized (zipped) B operand
        AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;

        /// Buffer for the unzipped (dequantized) B operand
        AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
            operand_unzip_B;

    public:
        //
        // Methods
        //

        /// Returns a layout object for the A matrix
        CUTLASS_DEVICE
        static typename Operator::LayoutA LayoutA() {
            return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
        }

        /// Returns a layout object for the B matrix
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutB LayoutB() {
            return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
        }

        /// Returns a TensorRef to the A operand
        CUTLASS_HOST_DEVICE
        TensorRefA operand_A_ref() {
            return TensorRefA{operand_A.data(), LayoutA()};
        }

        /// Returns a TensorRef to the B operand
        CUTLASS_HOST_DEVICE
        TensorRefB operand_B_ref() {
            return TensorRefB{operand_B.data(), LayoutB()};
        }

        CUTLASS_HOST_DEVICE
        uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }

        CUTLASS_HOST_DEVICE
        typename Operator::ElementB *operand_unzip_B_ptr() {
            return operand_unzip_B.data();
        }
    };

protected:
    //
    // Data members
    //

    /// Iterator to load a warp-scoped tile of A operand from shared memory
    typename Operator::IteratorA warp_tile_iterator_A_;

    /// Iterator to load a warp-scoped tile of B operand from shared memory
    typename Operator::IteratorB warp_tile_iterator_B_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    Wint2xMmaBase(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        SharedStorage &shared_storage,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx)
        : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
          warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,807 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass_extensions/arch/memory_copy_sm80.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages,
    int Stages,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    /// Used for partial specialization
    typename Enable = bool>
class Wint2xMmaMultistage :
    public Wint2xMmaBase<Shape_, Policy_, Stages> {
public:
    ///< Base class
    using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;
    ///< Iterates over tiles of A operand in global memory
    using IteratorA = IteratorA_;
    ///< Iterates over tiles of B operand in global memory
    using IteratorB = IteratorB_;
    ///< Data type of accumulator matrix
    using ElementC = ElementC_;
    ///< Layout of accumulator matrix
    using LayoutC = LayoutC_;
    ///< Policy describing tuning details
    using Policy = Policy_;

    using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;

    static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
    static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

    //
    // Dependent types
    //

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Minimum architecture is Sm80 to support cp.async
    using ArchTag = arch::Sm80;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    /// Internal structure exposed for introspection.
    struct Detail {
        /// Number of cp.async instructions to load one stage of operand A
        static int const AsyncCopyIterationsPerStageA =
            IteratorA::ThreadMap::Iterations::kCount;

        /// Number of cp.async instructions to load one stage of operand B
        static int const AsyncCopyIterationsPerStageB =
            IteratorB::ThreadMap::Iterations::kCount;

        /// Number of stages
        static int const kStages = Stages;

        /// Number of cp.async instructions to load one group of operand A
        static int const kAccessesPerGroupA =
            (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;

        /// Number of cp.async instructions to load one group of operand B
        static int const kAccessesPerGroupB =
            (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;

        // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical
        // accuracy, where each mainloop iteration first accumulates into a temporary
        // set of freshly-cleared accumulators, which are subsequently added to the
        // final accumulator set.
        static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation<Operator>::value;
    };

private:
    // Structure encapsulating pipeline state live from one iteration to the next
    struct PipeState {
        using WarpLoadedFragmentA = typename Operator::FragmentA;
        using WarpLoadedFragmentB = typename Operator::FragmentB;
        using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
        using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;

        /// Temporary accumulator to facilitate staged-accumulation
        FragmentC tmp_accum_;

        /// Pair of A fragments used to overlap shared memory loads and math instructions
        WarpLoadedFragmentA warp_loaded_frag_A_[2];
        WarpTransformedFragmentA warp_transformed_frag_A_[2];

        /// Pair of B fragments used to overlap shared memory loads and math instructions
        WarpLoadedFragmentB warp_loaded_frag_B_[2];
        WarpTransformedFragmentB warp_transformed_frag_B_[2];
    };

private:
    //
    // Data members
    //

    /// Warp-level MMA operator
    Operator warp_mma_;

    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Shared memory write stage index
    int smem_write_stage_idx_;

    /// Shared memory read stage index
    int smem_read_stage_idx_;

    uint8_t* column_wise_smem_ptr_B_;

    uint8_t* smem_zipped_ptr_B_;
    int smem_zipped_bytes_per_stage_B_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    Wint2xMmaMultistage(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        typename Base::SharedStorage &shared_storage,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx
    ):
        Base(shared_storage, thread_idx, warp_idx, lane_idx),
        smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
        smem_write_stage_idx_(0),
        smem_read_stage_idx_(0)
    {
        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset(
            {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset(
            {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});

        column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();

        smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
        smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
    }

    /// Advance shared memory read-iterators to the next stage
    CUTLASS_DEVICE
    void advance_smem_read_stage()
    {
        ++smem_read_stage_idx_;

        if (smem_read_stage_idx_ == Base::kStages) {
            // Wrap back around to the 'start' of the circular buffer in shared memory
            this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
            // this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
            smem_read_stage_idx_ = 0;
        }
        this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
    }
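    // Added note: only the A warp iterator wraps around the kStages-deep circular buffer above.
    // The dequantized B tile appears to live in a single-stage shared-memory buffer that is
    // rewritten every stage, so its warp iterator is stepped back by one stage of k-groups each
    // time instead of wrapping (the commented-out line preserves the usual multistage behavior).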

    /// Advance global memory read-iterators and shared memory write-iterators to the stage
    template <typename TileDequanterB>
    CUTLASS_DEVICE
    void advance_smem_write_stage(
        IteratorA &iterator_A,
        IteratorB &iterator_B,
        TileDequanterB &tile_dequanter_B)
    {
        // Advance global iterators
        iterator_A.add_tile_offset({0, 1});
        // iterator_B.add_tile_offset({1, 0});
        tile_dequanter_B.AddTileOffset({1, 0});

        // Advance shared iterators
        smem_iterator_A_.add_tile_offset({0, 1});
        // smem_iterator_B_.add_tile_offset({1, 0});

        // Increment shared memory write stage index
        ++smem_write_stage_idx_;

        if (smem_write_stage_idx_ == Base::kStages) {
            // Wrap back around to the 'start' of the circular buffer in shared memory
            smem_iterator_A_.add_tile_offset({0, -Base::kStages});
            // smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
            smem_write_stage_idx_ = 0;
        }
    }

    CUTLASS_DEVICE
    void copy_tiles_and_advance_A(IteratorA &iterator_A, int group_start_A = 0) {
        iterator_A.set_iteration_index(group_start_A *
                                       IteratorA::kAccessesPerVector);
        this->smem_iterator_A_.set_iteration_index(group_start_A);

        // Async Copy for operand A
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
            if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
                typename IteratorA::AccessType *dst_ptr =
                    reinterpret_cast<typename IteratorA::AccessType *>(
                        this->smem_iterator_A_.get());

                int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
                                      IteratorA::ThreadMap::kElementsPerAccess /
                                      IteratorA::kAccessesPerVector / 8;

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
                    auto gmem_ptr = iterator_A.get();

                    if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
                        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                            dst_ptr + v, gmem_ptr, iterator_A.valid());
                    } else {
                        cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
                            dst_ptr + v, gmem_ptr, iterator_A.valid());
                    }

                    ++iterator_A;
                }

                ++this->smem_iterator_A_;
            }
        }
    }
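    // Added illustration (values assumed): for half_t elements (16 bits), 8 elements per access
    // and kAccessesPerVector = 1, kSrcBytes = 16 * 8 / 1 / 8 = 16 bytes, i.e. each cp.async
    // issued above moves one 128-bit vector.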

    template <bool GlobalToSharedB>
    CUTLASS_DEVICE
    void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
        iterator_B.set_iteration_index(group_start_B *
                                       IteratorB::kAccessesPerVector);
        this->smem_iterator_B_.set_iteration_index(group_start_B);

        // Async Copy for operand B
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
            if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
                typename IteratorB::AccessType *dst_ptr =
                    reinterpret_cast<typename IteratorB::AccessType *>(
                        this->smem_iterator_B_.get());

                int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
                                      IteratorB::ThreadMap::kElementsPerAccess /
                                      IteratorB::kAccessesPerVector / 8;

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
                    auto gmem_ptr = iterator_B.get();

                    if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
                        cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
                            dst_ptr + v, gmem_ptr, iterator_B.valid());
                    } else {
                        cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
                            dst_ptr + v, gmem_ptr, iterator_B.valid());
                    }

                    ++iterator_B;
                }

                ++this->smem_iterator_B_;
            }
        }
        __syncthreads();
    }

    CUTLASS_DEVICE
    void copy_tiles_and_advance_per_stage_A(IteratorA &iterator_A) {
        iterator_A.set_iteration_index(0);
        this->smem_iterator_A_.set_iteration_index(0);

        // Async Copy for operand A
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
            typename IteratorA::AccessType *dst_ptr =
                reinterpret_cast<typename IteratorA::AccessType *>(
                    this->smem_iterator_A_.get());

            CUTLASS_PRAGMA_UNROLL
            for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
                auto gmem_ptr = iterator_A.get();

                int const kSrcBytes =
                    sizeof_bits<typename IteratorA::Element>::value *
                    IteratorA::ThreadMap::kElementsPerAccess /
                    IteratorA::kAccessesPerVector / 8;

                int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);

                cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                    dst_ptr + v, iterator_A.get(), iterator_A.valid());

                ++iterator_A;
            }

            ++this->smem_iterator_A_;
        }
    }

    template <bool GlobalToSharedB, bool InitStage>
    CUTLASS_DEVICE
    void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
        iterator_B.set_iteration_index(0);
        this->smem_iterator_B_.set_iteration_index(0);

        // Async Copy for operand B
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
            typename IteratorB::AccessType *dst_ptr =
                reinterpret_cast<typename IteratorB::AccessType *>(
                    this->smem_iterator_B_.get());

            CUTLASS_PRAGMA_UNROLL
            for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
                auto gmem_ptr = iterator_B.get();

                int const kSrcBytes =
                    sizeof_bits<typename IteratorB::Element>::value *
                    IteratorB::ThreadMap::kElementsPerAccess /
                    IteratorB::kAccessesPerVector / 8;

                if (InitStage) {
                    cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
                        dst_ptr + v, iterator_B.get(), iterator_B.valid());
                } else {
                    if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
                        cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
                            dst_ptr + v, gmem_ptr, iterator_B.valid());
                    } else {
                        cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
                            dst_ptr + v, gmem_ptr, iterator_B.valid());
                    }
                }

                ++iterator_B;
            }

            ++this->smem_iterator_B_;
        }
        __syncthreads();
    }

    /// GEMM prologue.  Bootstrap the global->shared memory pipeline by fetching
    /// the global fragments needed by the first kStages-1 threadblock mainloop iterations
    template <typename TileDequanterB>
    CUTLASS_DEVICE
    void prologue(
        IteratorA &iterator_A,      ///< [in|out] iterator over A operand in global memory
        IteratorB &iterator_B,      ///< [in|out] iterator over B operand in global memory
        TileDequanterB &tile_dequanter_B,
        int &gemm_k_iterations)     ///< [in|out] number of threadblock mainloop iterations remaining
    {
        // Issue several complete stages
        CUTLASS_PRAGMA_UNROLL
        for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
            // Disable global fetching if done with global fetch iterations
            iterator_A.clear_mask(gemm_k_iterations == 0);
            iterator_B.clear_mask(gemm_k_iterations == 0);

            // Async copy A to shared memory.
            copy_tiles_and_advance_per_stage_A(iterator_A);

            // Async copy zipped B to shared memory.
            tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
                                  column_wise_smem_ptr_B_, stage);

            // Move to the next write stage
            advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);

            // Defines the boundary of a stage of cp.async.
            cutlass::arch::cp_async_fence();
        }

        // Optionally clear the remaining stages of SMEM. This is a functional requirement for
        // some kernels so that all accumulator elements outside the GEMM footprint are zero.
        if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
            /// Iterator to write threadblock-scoped tile of A operand to shared memory
            SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
            typename IteratorA::AccessType zero_A;

            zero_A.clear();
            last_smem_iterator_A.set_iteration_index(0);

            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
                typename IteratorA::AccessType *dst_ptr =
                    reinterpret_cast<typename IteratorA::AccessType *>(
                        last_smem_iterator_A.get());

                *dst_ptr = zero_A;

                ++last_smem_iterator_A;
            }

            /// Iterator to write threadblock-scoped tile of B operand to shared memory
            SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
            typename IteratorB::AccessType zero_B;

            zero_B.clear();
            last_smem_iterator_B.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
                typename IteratorB::AccessType *dst_ptr =
                    reinterpret_cast<typename IteratorB::AccessType *>(
                        last_smem_iterator_B.get());

                *dst_ptr = zero_B;

                ++last_smem_iterator_B;
            }
        }
    }
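    // Added note: the prologue issues kStages - 1 stages of cp.async traffic before any math
    // starts; clear_mask() turns the A/B fetches into no-ops once gemm_k_iterations reaches
    // zero so that short problems do not read out of bounds.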

    /// Wait until we have at least one completed global fetch stage
    CUTLASS_DEVICE
    void gmem_wait()
    {
        // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
        cutlass::arch::cp_async_wait<Base::kStages - 2>();
        __syncthreads();
    }
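    // Added note: cp_async_wait<Base::kStages - 2>() blocks until at most kStages - 2 cp.async
    // stage groups are still in flight, i.e. at least one of the kStages - 1 prologue stages has
    // landed in shared memory; the __syncthreads() then makes it visible to the whole CTA.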
|
||||
|
||||
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
|
||||
template <typename TileDequanterB>
|
||||
CUTLASS_DEVICE
|
||||
void mac_loop_iter(
|
||||
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
|
||||
FragmentC &accum, ///< [in|out] destination accumulator tile
|
||||
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
|
||||
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
|
||||
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
|
||||
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
|
||||
int stage)
|
||||
{
|
||||
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
|
||||
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
|
||||
|
||||
// Load the next warp-tile's A fragment from shared memory
|
||||
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
// Unpack and dequant the first stage of B.
|
||||
int unpack_stage = stage - Base::kStages + 2;
|
||||
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, unpack_stage);
|
||||
|
||||
// Copy dequatized data to shared memory used by mma core.
|
||||
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
|
||||
}
|
||||
|
||||
// Load the next warp-tile's B fragment from shared memory
|
||||
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
|
||||
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_B_;
|
||||
|
||||
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
|
||||
if (warp_mma_k > 0) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
|
||||
}
|
||||
|
||||
// Execute the current warp-tile of MMA operations
|
||||
if (Detail::kStagedAccumulation) {
|
||||
warp_mma_(
|
||||
pipe_state.tmp_accum_,
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
pipe_state.tmp_accum_
|
||||
);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
plus<FragmentC> plus_accum;
|
||||
accum = plus_accum(accum, pipe_state.tmp_accum_);
|
||||
pipe_state.tmp_accum_.clear();
|
||||
}
|
||||
} else {
|
||||
warp_mma_(
|
||||
accum,
|
||||
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
|
||||
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
|
||||
accum
|
||||
);
|
||||
}
|
||||
|
||||
// Except for the last warp-tile, all warp-tiles issue their share of
|
||||
// global->shared fragment copies
|
||||
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
|
||||
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
|
||||
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
|
||||
if (warp_mma_k == 0) {
|
||||
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
|
||||
column_wise_smem_ptr_B_, stage);
|
||||
}
|
||||
}
|
||||
|
||||
// The second-to-last warp-tile also:
|
||||
// - performs the last warp-tile's share of global->shared fragment copies
|
||||
// - moves to the next global fetch stage
|
||||
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
|
||||
// Performs the last warp-tile's share of global->shared fragment copies
|
||||
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
|
||||
|
||||
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
|
||||
|
||||
// Inserts a memory fence between stages of cp.async instructions.
|
||||
cutlass::arch::cp_async_fence();
|
||||
|
||||
// Wait until we have at least one completed global fetch stage
|
||||
gmem_wait();
|
||||
|
||||
// Move to the next global fetch stage
|
||||
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
|
||||
advance_smem_read_stage();
|
||||
|
||||
// Disable global fetching when done with global fetch iterations
|
||||
--gemm_k_iterations;
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
|
||||
}
|
||||
|
||||
// The last warp-tile also converts the shared memory fragments used by
|
||||
// the first warp-tile of the next iteration, if necessary (so we can
|
||||
// immediately start issuing MMA instructions at the top of the loop )
|
||||
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
|
||||
warp_mma_.transform(
|
||||
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
|
||||
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
|
||||
}
|
||||
}
|
||||
}

  /// Perform the specified number of threadblock mainloop iterations of matrix
  /// multiply-accumulate. Assumes the prologue has been initiated.
  template <typename TileDequanterB>
  CUTLASS_DEVICE
  void gemm_iters(
      int gemm_k_iterations,             ///< number of threadblock mainloop iterations
      FragmentC &accum,                  ///< [in|out] accumulator tile
      IteratorA &iterator_A,             ///< [in|out] iterator over A operand in global memory
      IteratorB &iterator_B,             ///< [in|out] iterator over B operand in global memory
      TileDequanterB &tile_dequanter_B)  ///< [in|out] dequantizer that pre-loads B to shared memory
  {
    PipeState pipe_state;

    // Unpack and dequantize the first stage of B.
    tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);

    // Disable global fetching if done with global fetch iterations
    iterator_A.clear_mask(gemm_k_iterations == 0);
    iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));

    // Load first warp-tile's A fragment from shared memory
    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
    ++this->warp_tile_iterator_A_;

    // Copy dequantized data to the shared memory used by the mma core.
    copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);

    // Load first warp-tile's B fragment from shared memory
    this->warp_tile_iterator_B_.set_kgroup_index(0);
    this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
    ++this->warp_tile_iterator_B_;

    // Transform, if necessary, the first warp-tile's shared memory fragments
    warp_mma_.transform(
        pipe_state.warp_transformed_frag_A_[0],
        pipe_state.warp_transformed_frag_B_[0],
        pipe_state.warp_loaded_frag_A_[0],
        pipe_state.warp_loaded_frag_B_[0]);

    if (Detail::kStagedAccumulation) {
      pipe_state.tmp_accum_.clear();
    }

    int stage = Base::kStages - 1;

    // Mainloop
    CUTLASS_GEMM_LOOP
    for (; gemm_k_iterations > (-Base::kStages + 1);) {
      mac_loop_iter(
          pipe_state,
          accum,
          iterator_A,
          iterator_B,
          tile_dequanter_B,
          gemm_k_iterations,
          stage);
      stage += 1;
    }

    if (Detail::kStagedAccumulation) {
      plus<FragmentC> plus_accum;
      accum = plus_accum(accum, pipe_state.tmp_accum_);
    }

    // Commit and drain all pending and predicated cp.async operations from the GEMM mainloop
    cutlass::arch::cp_async_fence();
    cutlass::arch::cp_async_wait<0>();
    __syncthreads();
  }

  /// Prepares the class for another prologue.
  CUTLASS_DEVICE
  void wind_down()
  {
    // Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue)

    // First, increment remaining warp tiles to get to the next full stage. (Ideally we would
    // just decrement one tile, but not all iterators implement --() decrement.)
    #pragma unroll
    for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
    {
      this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k);
      this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k);

      ++this->warp_tile_iterator_A_;
      ++this->warp_tile_iterator_B_;
    }
    smem_read_stage_idx_++;

    // Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators)
    static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations;
    if (smem_read_stage_idx_ > 1)
    {
      this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)});
      this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
    }
    else
    {
      this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
      //this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
      this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
    }
    smem_read_stage_idx_ = smem_write_stage_idx_;
  }

  /// Perform a threadblock-scoped matrix multiply-accumulate, pre-loading B to shared memory.
  template <typename TileDequanterB>
  CUTLASS_DEVICE
  void operator()(
      ///< problem size of GEMM
      int gemm_k_iterations,
      ///< destination accumulator tile
      FragmentC &accum,
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
      IteratorB iterator_B,
      ///< pre-load and dequantize B to shared memory
      TileDequanterB tile_dequanter_B,
      ///< initial value of accumulator
      FragmentC const &src_accum) {

    // Prologue (start fetching iterations of global fragments into shared memory)
    prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);

    // Wait until we have at least one completed global fetch stage
    gmem_wait();

    // Initialize destination accumulators with source accumulators
    accum = src_accum;

    // Perform the MAC-iterations
    gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,130 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cutlass/gemm_coord.h"
#include "cutlass/trace.h"

#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"

namespace cutlass {
namespace gemm {
namespace threadblock {

template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
          int Stages, int NumThreads, WintQuantMethod Method>
struct TileDequanter {
  using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
  using MmaElementT = typename WeightQuantTraits::MmaWeightType;
  using QuantArguments = typename WeightQuantTraits::Arguments;

  using UnzipAndDequantFunctor =
      UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;

  static constexpr bool kUseSharedMemory = true;

  static constexpr int kRows = Rows;
  static constexpr int kColumns = Columns;
  static constexpr int kStages = Stages;

  MmaElementT *out_smem_ptr{nullptr};

  char *pointer{nullptr};
  int64_t ldm{0};
  cutlass::MatrixCoord tb_offset;
  cutlass::MatrixCoord extent;

  ScaleElementT *super_scale_ptr{nullptr};
  cutlass::MatrixCoord tb_offset_scale;

  QuantArguments quant_args;

  int64_t block_start_rows[kStages];
  bool need_preload{true};
  UnzipAndDequantFunctor unzip_functor;

  CUTLASS_DEVICE
  TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
                const cutlass::MatrixCoord &extent,
                const cutlass::MatrixCoord &tb_offset,
                ScaleElementT *super_scale_ptr,
                const cutlass::MatrixCoord &tb_offset_scale,
                const QuantArguments &quant_args)
      : out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
        tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
        tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}

  CUTLASS_DEVICE
  MmaElementT *GetOutPtr() { return out_smem_ptr; }

  CUTLASS_DEVICE
  void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
    tb_offset.row() += tile_offset.row() * kRows;
    tb_offset.column() += tile_offset.column() * kColumns;
    tb_offset_scale.column() += tile_offset.column() * kColumns;
  }

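  // Load() stages the still-packed weights and their per-column scales of one
  // pipeline stage into shared memory; UnpackAndDequant() later decodes that
  // stage into out_smem_ptr, the dequantized tile consumed by the mma core.
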
  CUTLASS_DEVICE
  void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
    int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
    if (tb_offset.row() >= extent.row() ||
        tb_offset.column() >= extent.column()) {
      return;
    }

    block_start_rows[stage % kStages] = tb_offset.row();

    using ZippedT = typename WeightQuantTraits::WeightType;
    ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
                      tb_offset.column();
    ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();

    if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
      const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
                                       (tb_offset.row() / 128) * ldm +
                                       tb_offset_scale.column();
      const float *code_scale_ptr =
          quant_args.code_scale_ptr + tb_offset_scale.column();
      const float *code_zp_ptr =
          quant_args.code_zp_ptr + tb_offset_scale.column();

      typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
      unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
                              scale_ptr, &args, ldm, need_preload);
      need_preload = false;
    } else {
      // CUTLASS_TRACE_DEVICE("Not Supported!");
    }
  }

  CUTLASS_DEVICE
  void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
    int64_t block_start_row = block_start_rows[stage % kStages];
    if (block_start_row >= extent.row()) {
      return;
    }

    if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
      typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
      unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
    } else {
      // CUTLASS_TRACE_DEVICE("Not Supported!");
    }
  }
};
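
// Illustrative usage sketch (the concrete template parameters below are
// assumptions, not taken from this file): a 64x128 B tile, 3 pipeline stages,
// 128 threads and 2-bit weights would be dequantized roughly as
//   TileDequanter<cutlass::half_t, cutlass::half_t, 64, 128, 3, 128,
//                 WintQuantMethod::kWeightOnlyInt2>
//       tile_dequanter_B(out_smem, raw_weight_ptr, ldm, extent, tb_offset,
//                        super_scale_ptr, tb_offset_scale, quant_args);
// Each stage the mainloop calls Load(...) to issue cp.async copies of the packed
// data, then UnpackAndDequant(...) once that stage's async group has completed.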

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,447 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>

#include <cuda_runtime.h>

#include "cutlass/arch/memory.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/wint_type_traits.h"

namespace cutlass {
namespace gemm {
namespace threadblock {

template <typename T, int N>
using UnzipArray = cutlass::AlignedArray<T, N, (N * cutlass::sizeof_bits<T>::value / 8)>;

// Primary template: no-op fallback for quantization methods without a specialization.
template <typename T, WintQuantMethod QuantMethod, int TileRows,
          int TileColumns, int NumThreads = 128>
struct UnzipAndDequantFunctor {
  __device__ void operator()(const T *in_ptr, const T *super_scale_ptr,
                             T *out_ptr, const int64_t in_stride) {}
};

template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
                              TileColumns, NumThreads> {
  using ZippedT = uint16_t;
  using ScaleComputeT = float;

  static constexpr int32_t kGroupSize = 64;
  static constexpr int32_t kZippedGroupSize = 10;
  static constexpr int32_t kNumPackedValues = 7;

  static constexpr int32_t kWeightMask = 0x7;
  static constexpr int32_t kLocalScaleMask = 0x1FFF;
  static constexpr int32_t kBZP = 4;

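  // Packing layout (derived from the loops below): each group of kGroupSize = 64
  // rows in a column is stored as kZippedGroupSize = 10 uint16_t values, i.e.
  // 2.5 bits per weight. The first 9 packed values each carry kNumPackedValues = 7
  // 3-bit weights; the 10th packed value carries one more 3-bit weight in its top
  // bits and the 13-bit local scale (kLocalScaleMask) in its low bits.
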
  __device__ inline T Compute(int32_t zipped_value, int32_t shift_bit,
                              ScaleComputeT scale) {
    int32_t shifted_value = (zipped_value >> shift_bit) & kWeightMask;
    int32_t value = shifted_value - kBZP;

    ScaleComputeT scaled_value = static_cast<ScaleComputeT>(value) * scale;
    return static_cast<T>(scaled_value);
  }

  __device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr,
                             T *out_ptr, const int64_t in_stride) {
    int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0};

    int tid = threadIdx.x;

    #pragma unroll
    for (int col = tid; col < TileColumns; col += NumThreads) {
      ScaleComputeT super_scale =
          static_cast<ScaleComputeT>(super_scale_ptr[col]);

      #pragma unroll
      for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
        // The last packed row in the group also carries the local scale.
        int zipped_row_last = group_id * 10 + 9;
        int zipped_offset_last = zipped_row_last * in_stride + col;
        int32_t zipped_value_last =
            static_cast<int32_t>(in_ptr[zipped_offset_last]);

        ScaleComputeT local_scale =
            static_cast<ScaleComputeT>(zipped_value_last & kLocalScaleMask);
        ScaleComputeT scale = local_scale * super_scale;

        #pragma unroll
        for (int zipped_row_in_group = 0; zipped_row_in_group < 9;
             ++zipped_row_in_group) {
          int zipped_row = group_id * 10 + zipped_row_in_group;
          int zipped_offset = zipped_row * in_stride + col;
          int32_t zipped_value = static_cast<int32_t>(in_ptr[zipped_offset]);

          int row_in_group = group_id * 64 + zipped_row_in_group * 7;

          #pragma unroll
          for (int shift_bit_id = 0; shift_bit_id < 7; ++shift_bit_id) {
            int32_t shift_bit = shift_bits[shift_bit_id];
            T value = Compute(zipped_value, shift_bit, scale);
            out_ptr[(row_in_group + shift_bit_id) * TileColumns + col] = value;
          }
        }

        int row_in_group_last = group_id * 64 + 63;
        T value_last = Compute(zipped_value_last, shift_bits[0], scale);
        out_ptr[row_in_group_last * TileColumns + col] = value_last;
      }
    }
    __syncthreads();
  }
};

template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
                              TileColumns, NumThreads> {
  using ZippedT = uint8_t;
  using ScaleComputeT = float;

  static constexpr int32_t kGroupSize = 64;
  static constexpr int32_t kPackNum = 4;
  static constexpr int32_t kWeightMask = 0x3F;
  static constexpr int32_t kLocalScaleMask = 0xF;
  static constexpr int32_t kBZP = 32;

  // weight      [16, N] uint8_t
  // local_scale [1, N]  uint8_t
  // code_scale  [N]     float
  // code_zp     [N]     float
  // super_scale [N]     T

  // code_scale, code_zp and super_scale
  static constexpr int32_t kColumnWiseSmemBytes = (2 * sizeof(float) + sizeof(T)) * TileColumns;
  // zipped weights and local_scale
  static constexpr int32_t kZippedSmemBytes = (TileRows / 4 + (TileRows + 127) / 128) * TileColumns;

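  // Worked example (illustrative values, not fixed by this header): for
  // TileRows = 64, TileColumns = 128 and T = half_t,
  //   kZippedSmemBytes     = (64 / 4 + (64 + 127) / 128) * 128 = 17 * 128 = 2176 bytes
  //   kColumnWiseSmemBytes = (2 * 4 + 2) * 128                 = 1280 bytes
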
  struct Arguments {
    uint8_t *weight_ptr;
    uint8_t *local_scale_ptr;
    float *code_scale_ptr;
    float *code_zp_ptr;
    T *super_scale_ptr;

    __device__ Arguments() : weight_ptr(nullptr), local_scale_ptr(nullptr), code_scale_ptr(nullptr), code_zp_ptr(nullptr), super_scale_ptr(nullptr) {}

    __device__ explicit Arguments(uint8_t *smem_ptr) {
      SetZippedPtrs(smem_ptr);
      SetColumnWisePtrs(smem_ptr + kZippedSmemBytes);
    }

    __device__ Arguments(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr) {
      SetZippedPtrs(zipped_smem_ptr);
      SetColumnWisePtrs(column_wise_smem_ptr);
    }

    __device__ void SetZippedPtrs(uint8_t *zipped_smem_ptr) {
      weight_ptr = zipped_smem_ptr;
      local_scale_ptr = zipped_smem_ptr + (TileRows / 4) * TileColumns;
    }

    __device__ void SetColumnWisePtrs(uint8_t *column_wise_smem_ptr) {
      code_scale_ptr = reinterpret_cast<float *>(column_wise_smem_ptr);
      code_zp_ptr = reinterpret_cast<float *>(column_wise_smem_ptr + sizeof(float) * TileColumns);
      super_scale_ptr = reinterpret_cast<T *>(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns);
    }
  };

  __device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr,
                       const float *g_code_scale_ptr, const float *g_code_zp_ptr,
                       const T *g_super_scale_ptr,
                       Arguments *args, const int64_t in_stride, bool need_preload) {
    int tid = threadIdx.x;

    #pragma unroll
    for (int col = tid; col < TileColumns; col += NumThreads) {
      if (need_preload) {
        if (g_super_scale_ptr) {
          args->super_scale_ptr[col] = g_super_scale_ptr[col];
        } else {
          args->super_scale_ptr[col] = static_cast<T>(1);
        }

        args->code_scale_ptr[col] = g_code_scale_ptr[col];
        args->code_zp_ptr[col] = g_code_zp_ptr[col];
      }

      #pragma unroll
      for (int ls_row_id = 0; ls_row_id < TileRows / 128; ++ls_row_id) {
        int local_scale_offset = ls_row_id * in_stride + col;
        args->local_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset];
      }

      #pragma unroll
      for (int zipped_row = 0; zipped_row < TileRows / 4; ++zipped_row) {
        int s_zipped_offset = zipped_row * TileColumns + col;
        int g_zipped_offset = zipped_row * 4 * in_stride + col;

        args->weight_ptr[s_zipped_offset] = g_weight_ptr[g_zipped_offset];
      }
    }
    __syncthreads();
  }

  __device__ void LoadAsync(const uint8_t *g_weight_ptr,
                            const uint8_t *g_local_scale_ptr,
                            const float *g_code_scale_ptr,
                            const float *g_code_zp_ptr,
                            const T *g_super_scale_ptr,
                            Arguments *args, const int64_t in_stride, bool need_preload) {
    int tid = threadIdx.x;

    constexpr int kBytesPerThread = 16;  // 16B per thread

    constexpr int weight_size = TileRows / 4 * TileColumns;
    constexpr int local_scale_size = (TileRows + 127) / 128 * TileColumns;
    constexpr int code_scale_size = sizeof(float) * TileColumns;
    constexpr int code_zp_size = sizeof(float) * TileColumns;
    constexpr int super_scale_size = sizeof(T) * TileColumns;

    constexpr int total_size = weight_size + local_scale_size + code_scale_size + code_zp_size + super_scale_size;
    constexpr int total_tasks = total_size / kBytesPerThread;

    constexpr int cur_num_threads = total_tasks / ((total_tasks + NumThreads - 1) / NumThreads);

    constexpr int weight_threads = weight_size * cur_num_threads / total_size;
    constexpr int local_scale_threads = local_scale_size * cur_num_threads / total_size;
    constexpr int code_scale_threads = code_scale_size * cur_num_threads / total_size;
    constexpr int code_zp_threads = code_zp_size * cur_num_threads / total_size;
    constexpr int super_scale_threads = super_scale_size * cur_num_threads / total_size;

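    // Worked example (illustrative, assuming NumThreads = 128, TileRows = 64,
    // TileColumns = 128 and T = half_t):
    //   weight_size = 2048, local_scale_size = 128, code_scale_size = 512,
    //   code_zp_size = 512, super_scale_size = 256  ->  total_size = 3456 bytes
    //   total_tasks = 3456 / 16 = 216, cur_num_threads = 216 / 2 = 108
    //   weight_threads = 64, local_scale_threads = 4, code_scale_threads = 16,
    //   code_zp_threads = 16, super_scale_threads = 8 (64 + 4 + 16 + 16 + 8 = 108)
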
    static_assert(TileColumns % weight_threads == 0,
                  "TileColumns must be divisible by weight_threads to ensure correct thread mapping.");

    static_assert(TileColumns % local_scale_threads == 0,
                  "TileColumns must be divisible by local_scale_threads to ensure correct thread mapping.");

    if (tid < weight_threads) {
      constexpr int weight_per_thread_size = weight_size / weight_threads;
      constexpr int kIterations = (weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kIterations; ++i) {
        int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread);
        int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
        cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
            args->weight_ptr + z_offset, g_weight_ptr + g_offset, true);
      }
    } else if (tid < weight_threads + local_scale_threads) {
      constexpr int start_thread_id = weight_threads;
      constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads;
      constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kIterations; ++i) {
        int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread;
        int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
        cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
            args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true);
      }
    } else if (need_preload) {
      if (tid < weight_threads + local_scale_threads + code_scale_threads) {
        constexpr int start_thread_id = weight_threads + local_scale_threads;
        constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads;
        constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kIterations; ++i) {
          int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float);
          cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
              args->code_scale_ptr + offset, g_code_scale_ptr + offset, true);
        }
      } else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) {
        constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads;
        constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads;
        constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kIterations; ++i) {
          int offset = ((tid - start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float);
          cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
              args->code_zp_ptr + offset, g_code_zp_ptr + offset, true);
        }
      } else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) {
        if (g_super_scale_ptr) {
          constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads;
          constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads;
          constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;

          CUTLASS_PRAGMA_UNROLL
          for (int i = 0; i < kIterations; ++i) {
            int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T);
            cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
                args->super_scale_ptr + offset, g_super_scale_ptr + offset, true);
          }
        }
      }
    }
  }

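  // Column-at-a-time decode path. Each stored byte is first mapped back to an
  // integer code through the per-column affine transform round(q * code_scale +
  // code_zp); four weights are then extracted from that code with 3-bit shifts
  // and the 6-bit kWeightMask, re-centered by kBZP, and scaled by the 4-bit
  // local scale and the per-column super scale.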
  __device__ void Compute(const Arguments &args, T *out_ptr,
                          const int64_t block_start_row) {
    int32_t shift_bits[4] = {9, 6, 3, 0};

    int tid = threadIdx.x;

    #pragma unroll
    for (int col = tid; col < TileColumns; col += NumThreads) {
      ScaleComputeT super_scale =
          static_cast<ScaleComputeT>(args.super_scale_ptr[col]);
      ScaleComputeT code_scale =
          static_cast<ScaleComputeT>(args.code_scale_ptr[col]);
      ScaleComputeT code_zp = static_cast<ScaleComputeT>(args.code_zp_ptr[col]);

      #pragma unroll
      for (int group_id = 0; group_id < TileRows / 64; ++group_id) {
        int local_scale_offset = (group_id / 2) * TileColumns + col;
        int32_t local_scale =
            static_cast<int32_t>(args.local_scale_ptr[local_scale_offset]);

        ScaleComputeT zipped_value[16];

        #pragma unroll
        for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
          int zipped_offset = (group_id * 16 + zipped_row) * TileColumns + col;
          zipped_value[zipped_row] =
              static_cast<ScaleComputeT>(args.weight_ptr[zipped_offset]);
        }

        int local_scale_shift = ((block_start_row / 64 + group_id + 1) & 1) * 4;
        int32_t shifted_local_scale =
            (local_scale >> local_scale_shift) & kLocalScaleMask;
        ScaleComputeT scale =
            static_cast<ScaleComputeT>(shifted_local_scale) * super_scale;

        #pragma unroll
        for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
          int32_t decode_value =
              static_cast<int32_t>(floor(zipped_value[zipped_row] * code_scale + code_zp +
                                         static_cast<ScaleComputeT>(0.5)));

          int row = group_id * 64 + zipped_row * 4;

          #pragma unroll
          for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
            int32_t shift_bit = shift_bits[shift_bit_id];
            int32_t shifted_value = (decode_value >> shift_bit) & kWeightMask;

            ScaleComputeT value =
                static_cast<ScaleComputeT>(shifted_value - kBZP);
            out_ptr[(row + shift_bit_id) * TileColumns + col] =
                static_cast<T>(scale * value);
          }
        }
      }
    }
    __syncthreads();
  }

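  // Vectorized decode path used by TileDequanter::UnpackAndDequant: each thread
  // handles N consecutive columns per iteration (N = 2 or 4 depending on the
  // tile size), keeps the packed bytes in a two-entry buffer so the next row's
  // data is read while the current row is decoded, and writes the four
  // dequantized rows back with N-wide vector stores.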
  __device__ void ComputeVectorized(const Arguments &args, T *out_ptr,
                                    const int64_t block_start_row) {
    constexpr int kNumWeightsPerThread = TileRows * TileColumns / (4 * NumThreads);
    constexpr int N = (kNumWeightsPerThread >= 32) ? 4 : 2;
    constexpr int RowStride = NumThreads * N / TileColumns;
    constexpr int kNumIters = kNumWeightsPerThread / N;

    static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns.");

    constexpr ScaleComputeT decode_value_zp = static_cast<ScaleComputeT>(0.5);

    int tid = threadIdx.x;
    int begin_col_id = (tid * N) % TileColumns;
    int begin_row_id = (tid * N) / TileColumns;

    static_assert(TileRows <= 128, "TileRows is expected to be no more than 128.");

    UnzipArray<uint8_t, N> local_scales =
        *reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.local_scale_ptr + begin_col_id);

    UnzipArray<uint8_t, N> zipped_values[2];
    int zipped_offset = begin_row_id * TileColumns + begin_col_id;
    zipped_values[0] =
        *reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);

    UnzipArray<T, N> super_scales =
        *reinterpret_cast<const UnzipArray<T, N> *>(args.super_scale_ptr + begin_col_id);
    UnzipArray<float, N> code_scales =
        *reinterpret_cast<const UnzipArray<float, N> *>(args.code_scale_ptr + begin_col_id);
    UnzipArray<float, N> code_zps =
        *reinterpret_cast<const UnzipArray<float, N> *>(args.code_zp_ptr + begin_col_id);

    // Special-cased for TileRows = 64 (the local-scale shift depends only on block_start_row).
    int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4;
    UnzipArray<ScaleComputeT, N> scales;

    #pragma unroll
    for (int i = 0; i < N; ++i) {
      int32_t shifted_local_scale =
          (static_cast<int32_t>(local_scales[i]) >> local_scale_shift) & kLocalScaleMask;
      scales[i] =
          static_cast<ScaleComputeT>(shifted_local_scale) * static_cast<ScaleComputeT>(super_scales[i]);
    }

    #pragma unroll
    for (int iter_id = 0; iter_id < kNumIters; ++iter_id) {
      int zipped_row = begin_row_id + iter_id * RowStride;
      int row = zipped_row * 4;

      if (iter_id < kNumIters - 1) {
        int zipped_offset = (zipped_row + RowStride) * TileColumns + begin_col_id;
        zipped_values[(iter_id + 1) & 1] =
            *reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
      }

      UnzipArray<T, N> outs[4];

      #pragma unroll
      for (int i = 0; i < N; ++i) {
        int32_t decode_value =
            static_cast<int32_t>(floor(static_cast<ScaleComputeT>(zipped_values[iter_id & 1][i]) * code_scales[i]
                                       + code_zps[i] + decode_value_zp));

        ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
        decode_value >>= 3;
        ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
        decode_value >>= 3;
        ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
        decode_value >>= 3;
        ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
        outs[0][i] = static_cast<T>(scales[i] * value_0);
        outs[1][i] = static_cast<T>(scales[i] * value_1);
        outs[2][i] = static_cast<T>(scales[i] * value_2);
        outs[3][i] = static_cast<T>(scales[i] * value_3);
      }

      #pragma unroll
      for (int shift_bit_id = 0; shift_bit_id < 4; ++shift_bit_id) {
        UnzipArray<T, N> *tmp_out_ptr = reinterpret_cast<UnzipArray<T, N> *>(
            out_ptr + (row + shift_bit_id) * TileColumns + begin_col_id);
        *tmp_out_ptr = outs[shift_bit_id];
      }
    }
    __syncthreads();
  }
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass