Optimize the performance of moe_expert_ffn_wint2 (#2990)

* Change wint2 to ColumnMajor.

Change-Id: I6b44d02946a685f8fe24d9f2c7be258b51e16da2

* Unify default_wint2x_mma.

Change-Id: I9e77b0e8e6cecab01fedc0b24b536ee0a1a89ff7

* Change wint2 to ColumnMajorTileInterleave.

Change-Id: I593cbe36f991c0c5044989d65f0014087587c624

* Enable async copy for B.

Change-Id: Ia3ac37ad162a8cf3ccce4f268e81bd06c8ac3c46

* Add wint2x Dequantizer

* Remove TileDequanterB related codes.

Change-Id: Id8e65703b72a8984d367f584ff41b7726017fbb8

* Implement FastInterleavedAndBiasedNumericArrayConverter for wint2.

Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca

* Implement Wint2ParamsAccessor to load extra quant params from global memory.

Change-Id: Ic3750cd9b767df8893501820880c3342a4b47233

* Implement FastInterleavedAndBiasedNumericArrayConverter for wint2.

Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca

* Use async copy for local_scale.

Change-Id: Ib882ba41c3d2354bda4d25b40e2408ad3b2f7893

* Check and correct the load and dequantize of weights.

Change-Id: Ie8dca505b39987144964fe6407d465b3b5953790

* Change for performance tuning.

Change-Id: I1da026fb1d1533a9d70350c7ba23c27e896cfc29

* Optimize the global memory access size of local_scale reading.

Change-Id: I4cbe3a2ef5951723d415c2d3252ce912394beaf5

* Specialize mma_tensor_op for wint2 to enable fine-grained pipeline.

Change-Id: Icbb4d48f90a41136f42d6ffff42d68de32f408da

* Minor fix.

Change-Id: I14d4ac9d267ee05442a3b47f00c26bee13d79e6f

* optimizing dequant performance with LOP3

* optimizing dequant performance with LOP3

* Avoid redundant dequantization of local_scale and use bf16 as computing type.

Change-Id: I63239ebc8f8e4a92d6281af59840ba50600b4334

* Add Multiplier and remove some logs.

Change-Id: Ifa199d81e6aeb472d2247c63f85ef30213684bcd

* optimizing dequant performance with LOP3

* Use __byte_perm to implement int8 to float32 conversion for performance improvement.

* Use lop3 to optimize the dequantize of local_scale.

Change-Id: I6189759970cb5b8dcbef769724784b8a7533b63c

* Minor fix and remove some logs.

Change-Id: I6279ba9926d5041093b1c6aea200acf2e4c49d46

* Fix stages for test.

Change-Id: I6f7b7cac612ef2c678e9d49f5ffa60eb53d3ae29

* Fix stages for test and add clock64 to profile.

Change-Id: Iffaf7324beaa910ce9ee56f47ae289de98f1a267

* Use __byte_perm to replace shift-and-or operations for faster integer merging.

* Split the uint2b convert.

Change-Id: I78da672ce8968e21f685285140ba546a161521b4

* Optimize convert of unscale.

Change-Id: I6795da1cdf5e8ab38ddaa9836240921b5312913a

* Minor optimization.

Change-Id: I1800aec34c3f4621abb02658208108f54da44d88

* Optimize mma pipeline and refine codes.

Change-Id: Id3075cf7b88f2813a11ccd1d3b49c62c978f36b8

* Add missing support.

Change-Id: Id65b7bc2c25fbb1a5b232c6bc9fb8c9093f691a8

* Accelerate FP16 dequantization performance

* Support tile shape as Xx64x64.

Change-Id: Ib8fd37e1ba1d06f7d11f2956e7f1367b0a92bcac

* Remove debugging codes and minor optimization.

Change-Id: I6b79bd56a6e8dd823efc169967ecd3cc9a43baf4

* Fix offset bug.

Change-Id: Id7aeb91e99d6f51836f2aff22187b4f79607395e

* Fix typo.

Change-Id: I19dde93fc1c1f7e19605905c90dc46298e203952

* Restore some codes and remove some debugging logs.

Change-Id: I8d44daf82ad1c6f8174134d195e7b3fe9a3afdfb

---------

Co-authored-by: baoqiwen <baoqiwen@baidu.com>
This commit is contained in:
Yiqun Liu
2025-07-28 10:32:43 +08:00
committed by GitHub
parent fb410b5f4c
commit 8f426c1690
17 changed files with 2076 additions and 491 deletions

View File

@@ -133,10 +133,18 @@ public:
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint2b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<TypeA>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value; // 64
private:
static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint2b_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8
public:
// using Layout = layout::ColumnMajor;
// static constexpr int ElementsPerAccess = 16; // at least 4-bytes
using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint2b_t>::value; // 64
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename TypeA, typename Arch>

View File

@@ -18,14 +18,12 @@
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
@@ -378,38 +376,23 @@ template <
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
private:
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
using ThreadblockMma = typename Mma::ThreadblockMma;
};
template <
@@ -441,38 +424,23 @@ struct DefaultMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<half_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<half_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
private:
using Mma = DefaultWint2xMma<half_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, half_t,
LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
using ThreadblockMma = typename Mma::ThreadblockMma;
};
} // namespace threadblock

View File

@@ -19,7 +19,7 @@
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h"
namespace cutlass {
namespace gemm {
@@ -379,38 +379,23 @@ template <
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
private:
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator,
false, CacheOpA, CacheOpB>;
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
using ThreadblockMma = typename Mma::ThreadblockMma;
};
template <
@@ -442,38 +427,23 @@ struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmen
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<bfloat16_t>::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<bfloat16_t>::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
private:
using Mma = DefaultWint2xMma<bfloat16_t, LayoutA, kAlignmentA, uint2b_t, LayoutB, kAlignmentB,
ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, kStages, Operator,
false, CacheOpA, CacheOpB>;
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA>;
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB>;
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
using ThreadblockMma = typename Mma::ThreadblockMma;
};
} // namespace threadblock

View File

@@ -0,0 +1,182 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
/// Partial specialization:
///
/// A: row-major
/// B: uint2b_t, column-major
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix production operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Number of stages
int Stages,
/// Operation performed by MMA
typename Operator_,
/// Cache operation of operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
layout::RowMajor, uint2b_t, layout::ColumnMajor,
ElementC_, LayoutC_, arch::OpClassTensorOp, Stages,
Operator_, false, CacheOpA, CacheOpB> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using ElementA = ElementA_;
using LayoutA = layout::RowMajor;
using ElementB = uint2b_t;
using LayoutB = layout::ColumnMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
static int const kStages = Stages;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
/// Number of warps present
using WarpCount = GemmShape<Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK>;
// Divisility requirements
static_assert(
!(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size.");
/// Number of threads per warp
static int const kWarpSize = warp::WarpSize<arch::OpClassTensorOp>::value;
/// Size of a threadblock-scoped access
static int const kAccessSizeInBits = 128;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
/// Size of a threadblock-scoped access of B
static constexpr int kMaxThreadsForB =
(Shape::kK * Shape::kN * sizeof_bits<ElementB>::value) / kAccessSizeInBits;
static constexpr int kThreadsForB =
kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB;
/// Default Operator
using Operator = Operator_;
// Warp thread arrangement
static int const kWarpThreadArrangementContiguousA =
Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementA>::value);
static int const kWarpThreadArrangementStridedA =
kWarpSize / kWarpThreadArrangementContiguousA;
static int const kWarpThreadArrangementContiguousB =
Shape::kK / (kAccessSizeInBits / sizeof_bits<ElementB>::value);
static int const kWarpThreadArrangementStridedB =
kWarpSize / kWarpThreadArrangementContiguousB;
//
// Shared memory layouts
//
using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementA>::value, Shape::kK>;
// Shared memory layout
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementB>::value, Shape::kK>;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
kWarpThreadArrangementStridedA>,
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM, Shape::kK>, ElementA, SmemLayoutA, 0,
IteratorThreadMapA>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreadsForB,
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
kWarpThreadArrangementStridedB>,
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kK, Shape::kN>, ElementB, SmemLayoutB, 1,
IteratorThreadMapB>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
/// Policy used to define MmaPipelined
using MmaPolicy = MmaPolicy<MmaTensorOp, MatrixShape<0, 0>,
MatrixShape<0, 0>, WarpCount::kK>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,246 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_core.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <typename ThreadblockShape, typename ElementT, int GroupSize>
struct DefaultQuantParamsIterators {
private:
static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
static constexpr int kRows =
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize;
static constexpr int kColumns = ThreadblockShape::kN;
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment, kAlignment>;
public:
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor, 0,
IteratorThreadMap, kAlignment>;
using SmemIterator = Iterator;
};
template <typename ThreadblockShape, int GroupSize>
struct DefaultQuantParamsIterators<ThreadblockShape, uint4b_t, GroupSize> {
private:
static constexpr int kAlignment = 32 / sizeof_bits<uint4b_t>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
static constexpr int kRows =
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize);
static constexpr int kColumns =
(GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2;
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment, kAlignment>;
public:
using AccessType = cutlass::Array<uint4b_t, kAlignment>;
using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
MatrixShape<kRows, kColumns>, uint4b_t, layout::RowMajor,
0, IteratorThreadMap, AccessType>;
using SmemIterator = Iterator;
};
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
struct DefaultWint2xMma;
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
/// Operator performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
kStages, Operator, SharedMemoryClear>
{
public:
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint2b_t>::value,
"Element B must be uint2b_t");
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
using ElementSuperScale = ElementA;
using ElementLocalScale = uint4b_t;
using ElementCodeScaleZp = float;
static constexpr int kGroupSize = 64;
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
// Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
AccessTypeA>;
private:
static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int kRowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be disivle by kColumnsInterleaved");
static_assert(kRowsPerTile == MmaCore::Shape::kK, "");
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement;
static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), "");
using IteratorShapeB = MatrixShape<
MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>;
using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<IteratorShapeB::kRow, IteratorShapeB::kColumn>,
ThreadMapB::kThreads,
layout::PitchLinearShape<WarpArrangement::kContiguous * kColumnsInterleaved,
WarpArrangement::kStrided / kColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB,
AccessTypeB>;
private:
// Define iterators over tiles from extra quant params for B operand
using IteratorSuperScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementSuperScale, -1>::Iterator;
using SmemIteratorSuperScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementSuperScale, -1>::SmemIterator;
using IteratorLocalScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator;
using SmemIteratorLocalScale = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator;
using IteratorCodeScaleZp = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators<
ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
public:
using QuantParamsAccessor = Wint2ParamsAccessor<
ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale,
IteratorLocalScale, SmemIteratorLocalScale,
IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
typename MmaCore::Shape,
IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA,
IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB,
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy,
kStages, QuantParamsAccessor, SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -63,8 +63,8 @@ template <
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
/// Size of extra quantized params
typename QuantParamsShape>
class Wint2xMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
@@ -93,6 +93,14 @@ public:
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of warp-level GEMM oeprations per load for B
static constexpr int kWarpGemmIterationsPerLoadForB =
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
static constexpr int kWarpLoadIterationsForB =
kWarpGemmIterations / kWarpGemmIterationsPerLoadForB;
/// Number of stages
static int const kStages = Stages;
@@ -104,8 +112,6 @@ public:
using TensorRefB =
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
// using TensorRefZippedB = TensorRef<uint8_t, typename Operator::LayoutB>;
static_assert(kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
@@ -130,20 +136,11 @@ public:
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK + Policy::SmemPaddingB::kRow,
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
// w uint8; local_scale uint8;
constexpr static int kZippedRowsPerStages =
Shape::kK / 4 + (Shape::kK + 127) / 128;
// code_scale float; code_zp float; super_scale ElementB
constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
sizeof_bits<typename Operator::ElementB>::value / 8;
using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
/// Shape of all quant params in shared memory
using QuantParamsShapeB = QuantParamsShape;
public:
//
@@ -156,12 +153,8 @@ public:
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer for quanted B operand
AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
/// Buffer for unzip B operand
AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
operand_unzip_B;
/// Buffer for extra quant params of B operand
AlignedBuffer<uint8_t, QuantParamsShapeB::kCount> operand_quant_params_B;
public:
//
@@ -191,14 +184,6 @@ public:
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
CUTLASS_HOST_DEVICE
uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
CUTLASS_HOST_DEVICE
typename Operator::ElementB *operand_unzip_B_ptr() {
return operand_unzip_B.data();
}
};
protected:

View File

@@ -45,7 +45,8 @@
#include "cutlass_extensions/arch/memory_copy_sm80.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -86,15 +87,15 @@ template <
typename Policy_,
/// Number of stages,
int Stages,
/// Accessor for extra quantized params
typename QuantParamsAccessor_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = bool>
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
class Wint2xMmaMultistage :
public Wint2xMmaBase<Shape_, Policy_, Stages> {
public Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape> {
public:
///< Base class
using Base = Wint2xMmaBase<Shape_, Policy_, Stages>;
using Base = Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
@@ -107,8 +108,11 @@ public:
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
/// Accessor for extra quantized params
using QuantParamsAccessor = QuantParamsAccessor_;
using QuantArguments = typename QuantParamsAccessor::Arguments;
using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
@@ -129,6 +133,18 @@ public:
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
//using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout;
using LayoutScale = layout::RowMajor;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
using WarpDequantizer =
warp::MmaTensorOpWin2xDequantizer<Operator,
typename Base::WarpGemm,
Operand::kB,
typename WarpTransformedFragmentB::Element,
LayoutScale,
QuantParamsAccessor::kGroupSize>;
static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed");
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
@@ -174,18 +190,37 @@ public:
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
using FragmentSuperScale = typename WarpDequantizer::FragmentSuperScale;
using FragmentCodeScaleZp = typename WarpDequantizer::FragmentCodeScaleZp;
using FragmentLocalScale = typename WarpDequantizer::FragmentLocalScale;
/// Temporary accumulator to facilitate staged-accumulation
FragmentC tmp_accum_;
/// Pair of A fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentA warp_loaded_frag_A_[2];
WarpTransformedFragmentA warp_transformed_frag_A_[2];
WarpTransformedFragmentA warp_frag_A_[2];
/// Pair of B fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentB warp_loaded_frag_B_[2];
WarpTransformedFragmentB warp_transformed_frag_B_[2];
WarpLoadedFragmentB warp_loaded_frag_B_;
WarpTransformedFragmentB warp_frag_B_[2];
/// channel-wise quant params
FragmentCodeScaleZp warp_frag_code_scale_;
FragmentCodeScaleZp warp_frag_code_zp_;
FragmentSuperScale warp_frag_super_scale_;
/// group-wise quant params
FragmentLocalScale warp_frag_local_scale_;
};
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool IsTileInterleaveLayout =
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
private:
@@ -202,17 +237,18 @@ public:
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Accessor for extra quant params for B
QuantParamsAccessor quant_params_accessor_B_;
// Wint2 unzip operator
WarpDequantizer warp_dequantizer_;
/// Shared memory write stage index
int smem_write_stage_idx_;
/// Shared memory read stage index
int smem_read_stage_idx_;
uint8_t* column_wise_smem_ptr_B_;
uint8_t* smem_zipped_ptr_B_;
int smem_zipped_bytes_per_stage_B_;
public:
/// Construct from tensor references
@@ -226,10 +262,15 @@ public:
int warp_idx,
///< ID of each thread within a warp
int lane_idx
):
Base(shared_storage, thread_idx, warp_idx, lane_idx),
) : Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx),
warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(),
quant_params_accessor_B_.local_scale_ref(),
quant_params_accessor_B_.code_scale_ref(),
quant_params_accessor_B_.code_zp_ref(),
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0)
{
@@ -250,11 +291,6 @@ public:
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();
smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
}
/// Advance shared memory read-iterators to the next stage
@@ -266,28 +302,22 @@ public:
if (smem_read_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
// this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0});
smem_read_stage_idx_ = 0;
}
this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
}
/// Advance global memory read-iterators and shared memory write-iterators to the stage
template <typename TileDequanterB>
CUTLASS_DEVICE
void advance_smem_write_stage(
IteratorA &iterator_A,
IteratorB &iterator_B,
TileDequanterB &tile_dequanter_B)
void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B)
{
// Advance global iterators
iterator_A.add_tile_offset({0, 1});
//iterator_B.add_tile_offset({1, 0});
tile_dequanter_B.AddTileOffset({1, 0});
iterator_B.add_tile_offset({1, 0});
// Advance shared iterators
smem_iterator_A_.add_tile_offset({0, 1});
//smem_iterator_B_.add_tile_offset({1, 0});
smem_iterator_B_.add_tile_offset({1, 0});
// Increment shared memory write stage index
++smem_write_stage_idx_;
@@ -295,7 +325,7 @@ public:
if (smem_write_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
smem_iterator_A_.add_tile_offset({0, -Base::kStages});
//smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx_ = 0;
}
}
@@ -338,9 +368,14 @@ public:
}
}
template <bool GlobalToSharedB>
CUTLASS_DEVICE
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
if constexpr (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
return;
}
}
iterator_B.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
@@ -360,13 +395,14 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) ? iterator_B.valid() : false;
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, is_valid);
} else {
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, is_valid);
}
++iterator_B;
@@ -375,7 +411,6 @@ public:
++this->smem_iterator_B_;
}
}
__syncthreads();
}
CUTLASS_DEVICE
@@ -399,8 +434,6 @@ public:
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
@@ -411,9 +444,12 @@ public:
}
}
template <bool GlobalToSharedB, bool InitStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
return;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
@@ -433,35 +469,23 @@ public:
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
if (InitStage) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
} else {
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::copy<kSrcBytes, kCacheOpB, GlobalToSharedB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
}
++iterator_B;
}
++this->smem_iterator_B_;
}
__syncthreads();
}
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
template <typename TileDequanterB>
CUTLASS_DEVICE
void prologue(
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
TileDequanterB &tile_dequanter_B,
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
{
// Issue several complete stages
@@ -476,11 +500,18 @@ public:
copy_tiles_and_advance_per_stage_A(iterator_A);
// Async copy zipped B to shared memory.
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
copy_tiles_and_advance_per_stage_B(iterator_B);
// Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
if (stage == 0) {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<true>(mma_quant_args, stage);
} else {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
}
// Move to the next write stage
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
advance_smem_write_stage(iterator_A, iterator_B);
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
@@ -510,6 +541,10 @@ public:
++last_smem_iterator_A;
}
if (threadIdx.x >= IteratorB::ThreadMap::kThreads) {
return;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
@@ -542,57 +577,57 @@ public:
}
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
template <typename TileDequanterB>
CUTLASS_DEVICE
void mac_loop_iter(
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
FragmentC &accum, ///< [in|out] destination accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
int stage)
{
const int mma_stage = stage - Base::kStages + 1;
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
// CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k);
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(((warp_mma_k + 1) % Base::kWarpGemmIterations) / Base::kWarpLoadIterationsForB);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
++this->warp_tile_iterator_B_;
}
// load next-tile of group-wise local_scale from shared memory
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
}
// Load the next warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
// Unpack and dequant the first stage of B.
int unpack_stage = stage - Base::kStages + 2;
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, unpack_stage);
// Copy dequatized data to shared memory used by mma core.
copy_tiles_and_advance_per_stage_B<false, false>(iterator_B);
}
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B_;
// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
if (warp_mma_k > 0) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
}
// dequantizes next warp-tile
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2],
((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) : mma_stage) * Shape::kK,
(warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB);
// Execute the current warp-tile of MMA operations
if (Detail::kStagedAccumulation) {
if constexpr (Detail::kStagedAccumulation) {
warp_mma_(
pipe_state.tmp_accum_,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
pipe_state.tmp_accum_
);
@@ -604,22 +639,22 @@ public:
} else {
warp_mma_(
accum,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
accum
);
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
accum);
}
// Except for the last warp-tile, all warp-tiles issue their share of
// global->shared fragment copies
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
if (warp_mma_k == 0) {
tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
column_wise_smem_ptr_B_, stage);
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
}
}
@@ -628,9 +663,15 @@ public:
// - moves to the next global fetch stage
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
// Performs the last warp-tile's share of global->shared fragment copies
if constexpr (Detail::AsyncCopyIterationsPerStageA >= Base::kWarpGemmIterations) {
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
}
if constexpr (Detail::AsyncCopyIterationsPerStageB >= Base::kWarpGemmIterations) {
int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
}
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
@@ -639,69 +680,66 @@ public:
gmem_wait();
// Move to the next global fetch stage
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
advance_smem_write_stage(iterator_A, iterator_B);
quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args);
advance_smem_read_stage();
int byte_offset = quant_params_accessor_B_.advance_smem_read_stage();
warp_dequantizer_.add_pointer_offset(byte_offset);
// Disable global fetching when done with global fetch iterations
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
}
// The last warp-tile also converts the shared memory fragments used by
// the first warp-tile of the next iteration, if necessary (so we can
// immediately start issuing MMA instructions at the top of the loop )
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
iterator_B.clear_mask(gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
}
}
}
/// Perform the specified number of threadblock mainloop iterations of matrix
/// multiply-accumulate. Assumes prologue has been initiated.
template <typename TileDequanterB>
CUTLASS_DEVICE
void gemm_iters(
int gemm_k_iterations, ///< number of threadblock mainloop iterations
FragmentC &accum, ///< [in|out] accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B,
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args)
{
PipeState pipe_state;
// Unpack and dequant the first stage of B.
tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0);
// Disable global fetching if done with global fetch iterations
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
// Load first warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]);
++this->warp_tile_iterator_A_;
// Copy dequatized data to shared memory used by mma core.
copy_tiles_and_advance_per_stage_B<false, true>(iterator_B);
iterator_B.clear_mask(gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
// Load first warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
++this->warp_tile_iterator_B_;
// Transform, if necessary, the first warp-tile's shared memory fragments
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[0],
pipe_state.warp_transformed_frag_B_[0],
pipe_state.warp_loaded_frag_A_[0],
pipe_state.warp_loaded_frag_B_[0]);
warp_dequantizer_.load(pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_);
if (Detail::kStagedAccumulation) {
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
// Load first warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[0]);
++this->warp_tile_iterator_A_;
// Dequantize B to in register
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[0],
0,
0);
if constexpr (Detail::kStagedAccumulation) {
pipe_state.tmp_accum_.clear();
}
@@ -715,13 +753,13 @@ public:
accum,
iterator_A,
iterator_B,
tile_dequanter_B,
mma_quant_args,
gemm_k_iterations,
stage);
stage += 1;
}
if (Detail::kStagedAccumulation) {
if constexpr (Detail::kStagedAccumulation) {
plus<FragmentC> plus_accum;
accum = plus_accum(accum, pipe_state.tmp_accum_);
}
@@ -761,14 +799,12 @@ public:
else
{
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
//this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
}
smem_read_stage_idx_ = smem_write_stage_idx_;
}
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
template <typename TileDequanterB>
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
@@ -779,13 +815,13 @@ public:
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< pre-load and dequantize B to shared memory
TileDequanterB tile_dequanter_B,
///< iterators for extra quant params for B
QuantArguments mma_quant_args,
///< initial value of accumulator
FragmentC const &src_accum) {
// Prologue (start fetching iterations of global fragments into shared memory)
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations);
// Wait until we have at least one completed global fetch stage
gmem_wait();
@@ -794,7 +830,7 @@ public:
accum = src_accum;
// Perform the MAC-iterations
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args);
}
};

View File

@@ -0,0 +1,315 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/trace.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
/// Original data type
typename T,
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterators over super scales in global memory
typename IteratorSuperScale_,
/// Iterators over super scales in shared memory
typename SmemIteratorSuperScale_,
/// Iterators over local scales in global memory
typename IteratorLocalScale_,
/// Iterators over local scales in shared memory
typename SmemIteratorLocalScale_,
/// Iterators over code scales and zps in global memory
typename IteratorCodeScaleZp_,
/// Iterators over code scales and zps in shared memory
typename SmemIteratorCodeScaleZp_,
/// Number of stages,
int Stages_,
/// Group size for quantization
int GroupSize_>
class Wint2ParamsAccessor {
public:
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");
using ElementType = T;
using Shape = Shape_;
using IteratorSuperScale = IteratorSuperScale_;
using SmemIteratorSuperScale = SmemIteratorSuperScale_;
using IteratorLocalScale = IteratorLocalScale_;
using SmemIteratorLocalScale = SmemIteratorLocalScale_;
using IteratorCodeScaleZp = IteratorCodeScaleZp_;
using SmemIteratorCodeScaleZp = SmemIteratorCodeScaleZp_;
constexpr static int kStages = Stages_;
constexpr static int kGroupSize = GroupSize_;
using ElementSuperScale = typename IteratorSuperScale::Element;
using LayoutSuperScale = typename IteratorSuperScale::Layout;
/// local_scale uint4 and group-wise
using ElementLocalScale = typename IteratorLocalScale::Element;
using LayoutLocalScale = typename IteratorLocalScale::Layout;
static_assert(platform::is_same<ElementLocalScale, uint4b_t>::value,
"local_scale's type must be uint4b_t.");
using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element;
using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout;
/// 2 uint4b_t values are stored in a single uint8_t
constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK;
constexpr static int kLocalScaleRows =
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * sizeof_bits<ElementLocalScale>::value / 8 / Shape::kN;
using SmemElement = uint8_t;
constexpr static int kSmemRows =
kLocalScaleRows * kStages + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2;
constexpr static int kSmemColumns = Shape::kN;
using QuantParamsShape = MatrixShape<kSmemRows, kSmemColumns>;
constexpr static int kSuperScaleSmemOffset = 0;
constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale);
constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
/// TensorRef type for loading element from a tensor
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, LayoutSuperScale>;
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, LayoutLocalScale>;
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, LayoutCodeScaleZp>;
struct Arguments {
IteratorSuperScale iterator_super_scale;
IteratorLocalScale iterator_local_scale;
IteratorCodeScaleZp iterator_code_scale;
IteratorCodeScaleZp iterator_code_zp;
int local_scale_pointer_offset;
CUTLASS_DEVICE
Arguments(IteratorSuperScale iterator_super_scale,
IteratorLocalScale iterator_local_scale,
IteratorCodeScaleZp iterator_code_scale,
IteratorCodeScaleZp iterator_code_zp,
int local_scale_pointer_offset)
: iterator_super_scale(iterator_super_scale),
iterator_local_scale(iterator_local_scale),
iterator_code_scale(iterator_code_scale),
iterator_code_zp(iterator_code_zp),
local_scale_pointer_offset(local_scale_pointer_offset) {}
};
private:
//
// Data members
//
/// Begin address of shared memory
uint8_t* smem_pointer_;
/// Iterator to write threadblock-scoped tile of super scale operand to shared memory
SmemIteratorSuperScale smem_iterator_super_scale_;
/// Iterator to write threadblock-scoped tile of local scale operand to shared memory
SmemIteratorLocalScale smem_iterator_local_scale_;
/// Iterator to write threadblock-scoped tile of code scale operand to shared memory
SmemIteratorCodeScaleZp smem_iterator_code_scale_;
/// Iterator to write threadblock-scoped tile of code zp operand to shared memory
SmemIteratorCodeScaleZp smem_iterator_code_zp_;
/// Shared memory write stage index
int smem_write_stage_idx_;
/// Shared memory read stage index
int smem_read_stage_idx_;
CUTLASS_DEVICE
ElementSuperScale* get_super_scale_smem_ptr() {
return reinterpret_cast<ElementSuperScale*>(smem_pointer_ + kSuperScaleSmemOffset);
}
CUTLASS_DEVICE
ElementLocalScale* get_local_scale_smem_ptr() {
return reinterpret_cast<ElementLocalScale*>(smem_pointer_ + kLocalScaleSmemOffset);
}
CUTLASS_DEVICE
ElementCodeScaleZp* get_code_scale_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeScaleSmemOffset);
}
CUTLASS_DEVICE
ElementCodeScaleZp* get_code_zp_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeZpSmemOffset);
}
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2ParamsAccessor(
///< prointer of shared memory
uint8_t* smem_pointer,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: smem_pointer_(smem_pointer),
smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn),
get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx),
smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn),
get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx),
smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0) {}
CUTLASS_DEVICE
SuperTensorRef super_scale_ref() {
return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)};
}
CUTLASS_DEVICE
LocalTensorRef local_scale_ref() {
return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)};
}
CUTLASS_DEVICE
CodeTensorRef code_scale_ref() {
return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}
CUTLASS_DEVICE
CodeTensorRef code_zp_ref() {
return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}
template <bool IsFirstStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) {
if constexpr (IsFirstStage) {
// Load channel-wise super_scale to shared memory, which only needs to be done once.
typename IteratorSuperScale::Fragment tb_frag_super_scale;
tb_frag_super_scale.clear();
quant_args.iterator_super_scale.load(tb_frag_super_scale);
this->smem_iterator_super_scale_.store(tb_frag_super_scale);
// Load channel-wise code_scale to shared memory, which only needs to be done once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_scale;
tb_frag_code_scale.clear();
quant_args.iterator_code_scale.load(tb_frag_code_scale);
this->smem_iterator_code_scale_.store(tb_frag_code_scale);
// Load channel-wise code_zp to shared memory, which only needs to be done once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_zp;
tb_frag_code_zp.clear();
quant_args.iterator_code_zp.load(tb_frag_code_zp);
this->smem_iterator_code_zp_.store(tb_frag_code_zp);
}
if ((stage % kStagesPerLocalScaleLoad) == 0) {
// Load group-wise local_scale to shared memory, which only needs to be done at each stage.
// Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages.
using AccessType = typename IteratorLocalScale::AccessType;
cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits<AccessType>::value == 128)
? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always;
quant_args.iterator_local_scale.set_iteration_index(0);
this->smem_iterator_local_scale_.set_iteration_index(0);
// Async Copy for local_scale
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) {
AccessType *dst_ptr =
reinterpret_cast<AccessType *>(this->smem_iterator_local_scale_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) {
auto gmem_ptr = quant_args.iterator_local_scale.get();
int const kSrcBytes =
sizeof_bits<typename IteratorLocalScale::Element>::value *
IteratorLocalScale::ThreadMap::kElementsPerAccess /
IteratorLocalScale::kAccessesPerVector / 8;
cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
}
++quant_args.iterator_local_scale;
}
++this->smem_iterator_local_scale_;
}
}
CUTLASS_DEVICE
void advance_smem_write_stage(Arguments &quant_args) {
if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
// Advance global iterators
quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset);
// Advance shared iterators
int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset);
}
// Increment shared memory write stage index
++smem_write_stage_idx_;
if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(pointer_offset);
smem_write_stage_idx_ = 0;
}
}
CUTLASS_DEVICE
int advance_smem_read_stage() {
int byte_offset = 0;
++smem_read_stage_idx_;
if (smem_read_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
byte_offset = kLocalScaleRows * kSmemColumns;
}
if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
smem_read_stage_idx_ = 0;
byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns;
}
return byte_offset;
}
CUTLASS_DEVICE
int clear_mask(Arguments &quant_args, bool cond) {
quant_args.iterator_local_scale.clear_mask(cond);
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -1,130 +0,0 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/gemm_coord.h"
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <typename ElementT, typename ScaleElementT, int Rows, int Columns,
int Stages, int NumThreads, WintQuantMethod Method>
struct TileDequanter {
using WeightQuantTraits = WintQuantTraits<ElementT, Method>;
using MmaElementT = typename WeightQuantTraits::MmaWeightType;
using QuantArguments = typename WeightQuantTraits::Arguments;
using UnzipAndDequantFunctor =
UnzipAndDequantFunctor<MmaElementT, Method, Rows, Columns, NumThreads>;
static constexpr bool kUseSharedMemory = true;
static constexpr int kRows = Rows;
static constexpr int kColumns = Columns;
static constexpr int kStages = Stages;
MmaElementT *out_smem_ptr{nullptr};
char *pointer{nullptr};
int64_t ldm{0};
cutlass::MatrixCoord tb_offset;
cutlass::MatrixCoord extent;
ScaleElementT *super_scale_ptr{nullptr};
cutlass::MatrixCoord tb_offset_scale;
QuantArguments quant_args;
int64_t block_start_rows[kStages];
bool need_preload{true};
UnzipAndDequantFunctor unzip_functor;
CUTLASS_DEVICE
TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm,
const cutlass::MatrixCoord &extent,
const cutlass::MatrixCoord &tb_offset,
ScaleElementT *super_scale_ptr,
const cutlass::MatrixCoord &tb_offset_scale,
const QuantArguments &quant_args)
: out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent),
tb_offset(tb_offset), super_scale_ptr(super_scale_ptr),
tb_offset_scale(tb_offset_scale), quant_args(quant_args) {}
CUTLASS_DEVICE
MmaElementT *GetOutPtr() { return out_smem_ptr; }
CUTLASS_DEVICE
void AddTileOffset(const cutlass::MatrixCoord &tile_offset) {
tb_offset.row() += tile_offset.row() * kRows;
tb_offset.column() += tile_offset.column() * kColumns;
tb_offset_scale.column() += tile_offset.column() * kColumns;
}
CUTLASS_DEVICE
void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row());
if (tb_offset.row() >= extent.row() ||
tb_offset.column() >= extent.column()) {
return;
}
block_start_rows[stage % kStages] = tb_offset.row();
using ZippedT = typename WeightQuantTraits::WeightType;
ZippedT *in_ptr = reinterpret_cast<ZippedT *>(pointer) + zipped_row * ldm +
tb_offset.column();
ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column();
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
const uint8_t *local_scale_ptr = quant_args.local_scale_ptr +
(tb_offset.row() / 128) * ldm +
tb_offset_scale.column();
const float *code_scale_ptr =
quant_args.code_scale_ptr + tb_offset_scale.column();
const float *code_zp_ptr =
quant_args.code_zp_ptr + tb_offset_scale.column();
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr,
scale_ptr, &args, ldm, need_preload);
need_preload = false;
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
CUTLASS_DEVICE
void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) {
int64_t block_start_row = block_start_rows[stage % kStages];
if (block_start_row >= extent.row()) {
return;
}
if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) {
typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr);
unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row);
} else {
// CUTLASS_TRACE_DEVICE("Not Supported!");
}
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@@ -41,12 +41,9 @@
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
namespace cutlass
{
namespace gemm
{
namespace warp
{
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -81,7 +78,7 @@ private:
// Shape for computing the FP16s
using ComputeInstructionShape = InstructionShape_;
// Chosen so we get K=16 for int8 and K=32 for int4.
// Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2.
static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;
// Shape for loading the narrow data type from shared memory

View File

@@ -58,15 +58,12 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace warp
{
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
@@ -297,6 +294,235 @@ public:
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer.
/// Specialization for B of uint2b_t.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Instruction shape to override shared memory iterators with
typename SharedMemoryInstructionShape_,
/// Number of partitions along K dimension
int PartitionsK_,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
class MmaTensorOpComputeBWithF16<
Shape_,
ElementA_,
LayoutA_,
uint2b_t,
LayoutB_,
ElementC_,
LayoutC_,
Policy_,
SharedMemoryInstructionShape_,
PartitionsK_,
AccumulatorsInRowMajor>
{
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = uint2b_t;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Indicates math operator
using MathOperator = typename ArchMmaOperator::Operator;
/// Architecture tag from underlying instruction
using ArchTag = typename ArchMmaOperator::ArchTag;
static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
|| (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
&& platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
&& ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");
static_assert(platform::is_same<ElementA, half_t>::value
|| (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Shape of underlying instruction
using InstructionShape = typename ArchMmaOperator::Shape;
/// Instruction shape to override shared memory iterators with
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
static_assert(
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
static_assert(
SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");
static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
public:
/// Iterates over the A operand in memory
using IteratorA
= MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
kThreadCount, kPartitionsK>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB =
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements / kExpansionFactor>;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
/// Number of mma operations performed
using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
ArchMmaOperator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaTensorOpComputeBWithF16() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C) const
{
using MmaOperandA = typename ArchMmaOperator::FragmentA;
using MmaOperandB = typename ArchMmaOperator::FragmentB;
using MmaOperandC = typename ArchMmaOperator::FragmentC;
D = C;
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
// Serpentine visitation order maximizing reuse of Rb
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n],
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n],
ptr_D[m_serpentine + n * MmaIterations::kRow]);
}
}
}
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Serpentine visitation order maximizing reuse of Ra
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m)
{
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n)
{
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
if (AccumulatorsInRowMajor)
{ // matrix B is reordered
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine],
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
}
else
{
mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine],
ptr_D[m + n_serpentine * MmaIterations::kRow]);
}
}
}
#else
assert(0);
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp

View File

@@ -0,0 +1,442 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations
targeting Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/functional.h"
#include "cutlass/platform/platform.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
namespace cutlass {
namespace gemm {
namespace warp {
namespace detail {
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<bfloat16_t> {
using Type = __nv_bfloat16;
using DualType = __nv_bfloat162;
};
template <>
struct DataTypeTraits<half_t> {
using Type = __half;
using DualType = __half2;
};
template <typename T, int N, typename Enable = void>
struct LocalScaleConverter {
using FragmentSource = Array<uint8_t, N>;
using FragmentResult = Array<T, N>;
CUTLASS_DEVICE
static void Apply(FragmentSource const& local_scale_frag,
FragmentResult const& super_scale_frag,
FragmentResult& scale_frag,
int shift_bit) {
constexpr uint32_t kLocalScaleMask = 0xf;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
int32_t shifted_value = (static_cast<int32_t>(local_scale_frag[i]) >> shift_bit) & kLocalScaleMask;
scale_frag[i] = static_cast<T>(shifted_value) * super_scale_frag[i];
}
}
};
template <int N>
struct LocalScaleConverter<half_t, N, typename platform::enable_if<N % 4 == 0>::type> {
using FragmentSource = Array<uint8_t, N>;
using FragmentResult = Array<half_t, N>;
CUTLASS_DEVICE
static void Apply(FragmentSource const& local_scale_frag,
FragmentResult const& super_scale_frag,
FragmentResult& scale_frag,
int shift_bit) {
constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
constexpr uint32_t MASK = 0x000f000f;
// 2^10 = 1024
constexpr uint32_t I4s_TO_FP16s_MAGIC_NUM = 0x64006400;
// -2^10 = -1024
constexpr uint32_t FP16_BIAS = 0xE400E400;
// 1.0
constexpr uint32_t FP16_ONE = 0x3C003C00;
__half2* scale_ptr = reinterpret_cast<__half2 *>(&scale_frag);
__half2 const* super_scale_ptr = reinterpret_cast<__half2 const*>(&super_scale_frag);
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
int i4s = local_scale_ptr[i] >> shift_bit;
// unpack: 0, 1
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_FP16s_MAGIC_NUM);
// unpack: 2, 3
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_FP16s_MAGIC_NUM);
__half2 scale0 = __hfma2(*reinterpret_cast<__half2*>(&unpack0),
*reinterpret_cast<const __half2*>(&FP16_ONE),
*reinterpret_cast<const __half2*>(&FP16_BIAS));
__half2 scale1 = __hfma2(*reinterpret_cast<__half2*>(&unpack1),
*reinterpret_cast<const __half2*>(&FP16_ONE),
*reinterpret_cast<const __half2*>(&FP16_BIAS));
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
}
}
};
template <int N>
struct LocalScaleConverter<bfloat16_t, N, typename platform::enable_if<N % 4 == 0>::type> {
using FragmentSource = Array<uint8_t, N>;
using FragmentResult = Array<bfloat16_t, N>;
CUTLASS_DEVICE
static void Apply(FragmentSource const& local_scale_frag,
FragmentResult const& super_scale_frag,
FragmentResult& scale_frag,
int shift_bit) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
constexpr uint32_t MASK = 0x000F000F;
constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
constexpr uint32_t BF16_BIAS = 0xC300C300;
constexpr uint32_t BF16_ONE = 0x3F803F80;
__nv_bfloat162* scale_ptr = reinterpret_cast<__nv_bfloat162 *>(&scale_frag);
__nv_bfloat162 const* super_scale_ptr = reinterpret_cast<__nv_bfloat162 const*>(&super_scale_frag);
uint32_t const* local_scale_ptr = reinterpret_cast<uint32_t const*>(&local_scale_frag);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
int i4s = local_scale_ptr[i] >> shift_bit;
// unpack: 0, 1
int32_t low = __byte_perm(i4s, i4s, 0xF1F0);
int32_t unpack0 = lop3<immLut>(low, MASK, I4s_TO_BF16s_MAGIC_NUM);
// unpack: 2, 3
int32_t high = __byte_perm(i4s, i4s, 0xF3F2);
int32_t unpack1 = lop3<immLut>(high, MASK, I4s_TO_BF16s_MAGIC_NUM);
nv_bfloat162 scale0 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack0),
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
nv_bfloat162 scale1 = __hfma2(*reinterpret_cast<nv_bfloat162*>(&unpack1),
*reinterpret_cast<const nv_bfloat162*>(&BF16_ONE),
*reinterpret_cast<const nv_bfloat162*>(&BF16_BIAS));
scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]);
scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]);
}
#else
// Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
// happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
// numerous conversion instructions in GEMM main loop.
arch::device_breakpoint();
#endif
}
};
} // namespace detail
////////////////////////////////////////////////////////////////////////////////
template <
/// Matrix multiply operator
typename MmaOperator_,
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Operand identity
Operand Operand,
/// Data type of Scale elements
typename ElementOperand_,
/// Layout of operand
typename Layout_,
/// Group size for quantization
int GroupSize_,
///
typename Enable = void>
class MmaTensorOpWin2xDequantizer {
//static_assert(false, "Not Supported!");
};
////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_,
/// Data type of Scale elements
typename ElementOperand_,
/// Group size for quantization
int GroupSize_>
class MmaTensorOpWin2xDequantizer<
MmaOperator_,
Shape_,
Operand::kB,
ElementOperand_,
layout::RowMajor,
GroupSize_>
//typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
// && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
{
public:
static_assert(platform::is_same<ElementOperand_, half_t>::value || platform::is_same<ElementOperand_, bfloat16_t>::value,
"T must be fp16 or bf16");
/// Mma Operator
using MmaOperator = MmaOperator_;
// The architecture specific mma ooperator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
/// Warp mma shape
using Shape = Shape_;
/// Type of mma operand
using ElementOperand = ElementOperand_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// Group size for quantization
static constexpr int kGroupSize = GroupSize_;
/// Type of input
using ElementB = typename MmaOperator::FragmentB::Element;
static_assert(platform::is_same<ElementB, uint2b_t>::value, "ElementB must be uint2b_t");
/// Type of the scales
using ElementLocalScale = uint4b_t;
using ElementSuperScale = ElementOperand;
using ElementCodeScaleZp = float;
// Fragment to hold scale data to apply to B before mma
// We need 1 fp16 per matrix iteration in the N dimension
static constexpr int kWarpIterationsAlongN = MmaOperator::MmaIterations::kColumn;
// use uint8_t to save 2 4-bits local scales
using FragmentLocalScale = Array<uint8_t, kWarpIterationsAlongN>;
using FragmentSuperScale = Array<ElementSuperScale, kWarpIterationsAlongN>;
using FragmentCodeScaleZp = Array<ElementCodeScaleZp, kWarpIterationsAlongN>;
/// Fragment to hold B data before Mma
using FragmentInput = Array<ElementB, MmaOperator::FragmentB::kElements>;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
static constexpr int kNumPacks = sizeof_bits<uint8_t>::value / sizeof_bits<ElementB>::value;
static constexpr int kUnpackFactor = MmaOperator::FragmentB::kElements / (kWarpIterationsAlongN * kNumPacks);
static constexpr int kUnpackInterval = kExpansionFactor / kUnpackFactor;
/// Unpack 4 uint2b_t values compreseed in a uint8_t to floating points.
using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter<
ElementOperand, ElementB, MmaOperator::FragmentB::kElements / kUnpackFactor>;
using FragmentInputUnpack = typename Uint2Converter::result_type;
/// Fragment to hold internal scales before Mma
using FragmentScale = Array<ElementOperand, FragmentLocalScale::kElements>;
/// Fragment of dequantized B
using FragmentOutput = Array<ElementOperand, MmaOperator::FragmentB::kElements / kExpansionFactor>;
/// TensorRef type for loading element from a tensor
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, Layout>;
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, Layout>;
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, Layout>;
private:
//
// Data members
//
uint8_t* pointer_local_scale_;
ElementCodeScaleZp* pointer_code_scale_;
ElementCodeScaleZp* pointer_code_zp_;
ElementSuperScale* pointer_super_scale_;
//FragmentInputUnpack unpacked_frag_;
FragmentScale scale_frag_;
public:
CUTLASS_DEVICE
MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale,
LocalTensorRef smem_local_scale,
CodeTensorRef smem_code_scale,
CodeTensorRef smem_code_zp,
int warp_idx_n,
int lane_idx) {
int warp_offset = warp_idx_n * Shape::kN;
int quad = lane_idx / 4;
int thread_offset = warp_offset + quad;
pointer_super_scale_ = smem_super_scale.data() + thread_offset;
pointer_code_scale_ = smem_code_scale.data() + thread_offset;
pointer_code_zp_ = smem_code_zp.data() + thread_offset;
pointer_local_scale_ = reinterpret_cast<uint8_t *>(smem_local_scale.data()) + thread_offset;
}
/// Channel-wise params, need to load just once
CUTLASS_DEVICE
void load(FragmentCodeScaleZp& code_scale_frag,
FragmentCodeScaleZp& code_zp_frag,
FragmentSuperScale& super_scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
super_scale_frag[mma_n_iter] = pointer_super_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
code_scale_frag[mma_n_iter] = pointer_code_scale_[mma_n_iter * InstructionShape::kN];
code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * InstructionShape::kN];
}
}
/// Group-wise params, need to load multiple times
CUTLASS_DEVICE
void load(FragmentLocalScale& local_scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict
}
}
CUTLASS_DEVICE
void dequantize(const FragmentLocalScale& local_scale_frag,
const FragmentCodeScaleZp& code_scale_frag,
const FragmentCodeScaleZp& code_zp_frag,
const FragmentSuperScale& super_scale_frag,
const FragmentInput& input_frag,
FragmentOutput& output_frag,
int tb_offset_k,
int warp_k_compute_offset) {
if constexpr (kUnpackInterval != 1) {
// unsupport now
arch::device_breakpoint();
}
typename Uint2Converter::source_type source_frag;
int in_offset = warp_k_compute_offset * kUnpackInterval;
uint8_t const* ptr_input = reinterpret_cast<uint8_t const*>(&input_frag);
uint8_t* ptr_source = reinterpret_cast<uint8_t *>(&source_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) {
ptr_source[mma_n_iter] = ptr_input[mma_n_iter * kUnpackFactor + in_offset];
}
FragmentInputUnpack unpacked_frag = Uint2Converter::convert(source_frag, code_scale_frag, code_zp_frag);
// dequantize local_scale
if (warp_k_compute_offset == 0) {
using LocalScaleConverter = detail::LocalScaleConverter<ElementOperand, FragmentLocalScale::kElements>;
// special for TileRows = 64
int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4;
LocalScaleConverter::Apply(local_scale_frag, super_scale_frag, scale_frag_, local_scale_shift);
}
// unscale
// After applying LOP3 optimizations for performance, the B operand requires data rearrangement.
// reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15]
const int kWarpIterationsAlongK = FragmentOutput::kElements / kWarpIterationsAlongN;
using Type = typename detail::DataTypeTraits<ElementOperand>::Type;
using DualType = typename detail::DataTypeTraits<ElementOperand>::DualType;
Type* output_ptr = reinterpret_cast<Type *>(&output_frag);
DualType const* unpacked_ptr = reinterpret_cast<DualType const*>(&unpacked_frag);
DualType const* scale_ptr = reinterpret_cast<DualType const*>(&scale_frag_);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; mma_n_iter += 2) {
int mapped_idx_base = (mma_n_iter / 2) * kWarpIterationsAlongK;
DualType scalex2 = scale_ptr[mma_n_iter / 2];
CUTLASS_PRAGMA_UNROLL
for (int mma_k_iter = 0; mma_k_iter < kWarpIterationsAlongK; ++mma_k_iter) {
DualType unpacked_valuex2 = unpacked_ptr[mapped_idx_base + mma_k_iter];
DualType scaled_value = __hmul2(unpacked_valuex2, scalex2);
output_ptr[mma_n_iter * kWarpIterationsAlongK + mma_k_iter] = scaled_value.x;
output_ptr[(mma_n_iter + 1) * kWarpIterationsAlongK + mma_k_iter] = scaled_value.y;
}
}
}
/// Add an offset to pointer in units of elements.
/// Only group-wise params needs.
CUTLASS_DEVICE
void add_pointer_offset(int64_t const& offset) {
pointer_local_scale_ += offset;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////

View File

@@ -39,18 +39,25 @@
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"
#include "cutlass/trace.h"
namespace cutlass
{
namespace cutlass {
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter
{
};
struct FastInterleavedAndBiasedNumericArrayConverter;
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
@@ -440,6 +447,329 @@ struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint2b_t, 16>
{
using result_type = Array<half_t, 16>;
using source_type = Array<uint2b_t, 16>;
using ScaleComputeT = float;
using code_type = Array<ScaleComputeT, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
ScaleComputeT new_code_zp = code_zp + 0.5f;
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert_impl(int32_t* decode_value)
{
result_type result;
static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
static constexpr uint32_t MASK = 0x003F003F;
// 2^10 = 1024
static constexpr uint32_t EX = 0x64006400;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);
h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
h[3] = lop3<immLut>(q0, MASK, EX);
h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
h[7] = lop3<immLut>(q1, MASK, EX);
// 1024 + 32 = 1056
static constexpr uint32_t SUB = 0x64206420;
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
return convert(s, code_scale, code_zp);
}
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint2b_t, 16>
{
using result_type = Array<bfloat16_t, 16>;
using source_type = Array<uint2b_t, 16>;
using ScaleComputeT = float;
using code_type = Array<ScaleComputeT, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
ScaleComputeT new_code_zp = code_zp + 0.5f;
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
// 2^23 = 8388608
static constexpr uint32_t FP32_BASE = 0x4B000000;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
int32_t decode_value[4];
decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));
return convert_impl(decode_value);
}
CUTLASS_DEVICE
static result_type convert_impl(int32_t* decode_value)
{
result_type result;
static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
static constexpr uint32_t MASK = 0x003F003F;
// 2^7 = 128
static constexpr uint32_t EX = 0x43004300;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);
h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
h[3] = lop3<immLut>(q0, MASK, EX);
h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
h[7] = lop3<immLut>(q1, MASK, EX);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16))
// 128 + 32 = 160
static constexpr uint32_t SUB = 0x43204320;
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
#else
// 1.0
static constexpr uint32_t MUL = 0x3F803F80;
// -160
static constexpr uint32_t ADD = 0xC320C320;
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(h[4]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(MUL), "r"(ADD));
asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(MUL), "r"(ADD));
#endif
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
{
return convert(s, code_scale, code_zp);
}
};
template <typename T, int N>
struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, N>
{
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");
static constexpr int kVecWidth = 16;
static_assert(!(N % kVecWidth), "N must be multiple of 16.");
using result_type = Array<T, N>;
using source_type = Array<uint2b_t, N>;
using code_type = Array<float, N / kVecWidth>;
CUTLASS_DEVICE
static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, kVecWidth>;
using vec_source = Array<scalar_source_type, kVecWidth>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / kVecWidth; ++i)
{
result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]);
}
return result;
}
CUTLASS_DEVICE
static result_type convert(source_type const& source, Array<float, N / 4> const& code_scale, Array<float, N / 4> const& code_zp)
{
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>;
result_type result;
using vec_result = typename Converter::result_type;
using vec_source = typename Converter::source_type;
using vec_code = typename Converter::code_type;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
vec_code const* code_scale_ptr = reinterpret_cast<vec_code const*>(&code_scale);
vec_code const* code_zp_ptr = reinterpret_cast<vec_code const*>(&code_zp);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / kVecWidth; ++i)
{
result_ptr[i] = Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s, code_type const& code_scale, code_type const& code_zp)
{
return convert(s, code_scale, code_zp);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass

View File

@@ -125,10 +125,13 @@ struct WintQuantTraits<ElementT, WintQuantMethod::kWeightOnlyInt2> {
static constexpr int32_t kNumPackedValues = 4;
static constexpr int32_t kPackedSize = 16;
using LocalScaleType = uint4b_t;
using CodeScaleZpType = float;
struct Arguments {
const uint8_t *local_scale_ptr; // quanted 4-bits
const float *code_scale_ptr;
const float *code_zp_ptr;
uint8_t *local_scale_ptr; // quanted 4-bits
float *code_scale_ptr;
float *code_zp_ptr;
};
CUTLASS_DEVICE

View File

@@ -43,7 +43,6 @@
#include "cutlass/trace.h"
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
#include "cutlass_extensions/tile_interleaved_layout.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -775,17 +774,54 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
template <WintQuantMethod QuantMethod, typename dummy>
struct KernelRunner<QuantMethod, true, dummy> {
using WeightQuantTraits = WintQuantTraits<ElementA, QuantMethod>;
using QuantArguments = typename WeightQuantTraits::Arguments;
using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments;
CUTLASS_DEVICE
static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) {
QuantArguments quant_args;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128;
quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n;
quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n;
}
return quant_args;
static MmaQuantArguments prepare_quant_args(
Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset,
int64_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) {
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2};
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale(
Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n),
weight_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2);
int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128;
uint4b_t *local_scale_ptr = reinterpret_cast<uint4b_t *>(params.local_scale + offset_in_bytes);
typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale(
Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2),
local_scale_ptr,
{(gemm_k + 127) / 128, gemm_n * 2},
thread_idx,
tb_offset_local_scale);
float* code_scale_ptr = params.code_scale + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_scale_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
float* code_zp_ptr = params.code_zp + problem_idx * gemm_n;
typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp(
Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n),
code_zp_ptr,
{1, gemm_n},
thread_idx,
tb_offset_scale);
MmaQuantArguments mma_quant_args(
iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset);
return mma_quant_args;
}
CUTLASS_DEVICE
@@ -814,9 +850,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// LayoutB should be RowMajor
using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;
//
// Problem visitor.
//
@@ -843,12 +876,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
int(cta_idx % grid_shape.n()) * Mma::Shape::kN, // NOLINT
0);
// begin address offset for weight_scale.
ElementScale* weight_scale_ptr =
params.weight_scales ? params.weight_scales + problem_idx * problem_size.n() : nullptr;
// the begin threadblock_offset of scale, which holds the same column id with C, but with no row id
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Load element pointers. Exchange pointers and strides if working on
// the transpose
int64_t rows_to_jump = 0;
@@ -866,42 +893,20 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
// Compute initial location in logical coordinates
// the begin threadblock_offset of A, which holds the same row id with C
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_A{threadblock_offset.m(), 0};
// begin address offset for B for current problem_idx, totally num_experts problems
char* byte_ptr_B = ((char*)params.ptr_B) + // NOLINT
problem_idx * bytes_per_expert_matrix; // NOLINT
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B =
platform::is_same<layout::RowMajor, LayoutB>::value
? gemm_n
: gemm_k * kInterleave;
typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;
// the begin threadblock_offset of B, which holds the same column id with C
cutlass::MatrixCoord tb_offset_B{0,
threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};
MmaElementB* smem_unzip_B_ptr = nullptr;
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
}
QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
byte_ptr_B,
ldm_B,
extent_B,
tb_offset_B,
weight_scale_ptr,
tb_offset_scale,
quant_args);
MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();
// Compute position within threadblock
int thread_idx = threadIdx.x;
@@ -914,20 +919,21 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
tb_offset_A);
typename Mma::IteratorB iterator_B(
LayoutB(TileDequanterB::kUseSharedMemory ? ldm_B_shared : ldm_B),
LayoutB(ldm_B),
ptr_B,
TileDequanterB::kUseSharedMemory ? extent_B_shared : extent_B,
extent_B,
thread_idx,
TileDequanterB::kUseSharedMemory ? cutlass::make_Coord(0, 0) : tb_offset_B);
tb_offset_B);
MmaQuantArguments mma_quant_args = prepare_quant_args(
params, threadblock_offset, problem_idx, gemm_k, gemm_n, thread_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
@@ -950,7 +956,7 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
accumulators,
iterator_A,
iterator_B,
tile_dequanter_B,
mma_quant_args,
accumulators);
//

View File

@@ -205,7 +205,7 @@ void generic_moe_gemm_kernelLauncher(const T* A,
threadblock_count,
epilogue_op,
reinterpret_cast<const ElementType*>(A),
reinterpret_cast<const CutlassMmaWeightType*>(B),
reinterpret_cast<const CutlassMmaKernelType*>(B),
reinterpret_cast<const ElementType*>(weight_scales),
reinterpret_cast<const ElementType*>(biases),
reinterpret_cast<ElementType*>(C),

View File

@@ -49,12 +49,13 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
typename WeightOnlyTraits::Arguments up_gate_proj_quant_args;
typename WeightOnlyTraits::Arguments down_proj_quant_args;
if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
up_gate_proj_quant_args.local_scale_ptr = up_gate_proj_local_scale->data<uint8_t>();
up_gate_proj_quant_args.code_scale_ptr = up_gate_proj_code_scale->data<float>();
up_gate_proj_quant_args.code_zp_ptr = up_gate_proj_code_zp->data<float>();
down_proj_quant_args.local_scale_ptr = down_proj_local_scale->data<uint8_t>();
down_proj_quant_args.code_scale_ptr = down_proj_code_scale->data<float>();
down_proj_quant_args.code_zp_ptr = down_proj_code_zp->data<float>();
up_gate_proj_quant_args.local_scale_ptr = const_cast<uint8_t*>(up_gate_proj_local_scale->data<uint8_t>());
up_gate_proj_quant_args.code_scale_ptr = const_cast<float*>(up_gate_proj_code_scale->data<float>());
up_gate_proj_quant_args.code_zp_ptr = const_cast<float*>(up_gate_proj_code_zp->data<float>());
down_proj_quant_args.local_scale_ptr = const_cast<uint8_t*>(down_proj_local_scale->data<uint8_t>());
down_proj_quant_args.code_scale_ptr = const_cast<float*>(down_proj_code_scale->data<float>());
down_proj_quant_args.code_zp_ptr = const_cast<float*>(down_proj_code_zp->data<float>());
}
auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();