diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h index 40f128b7a..8f61c6d9c 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -133,10 +133,18 @@ public: template struct LayoutDetailsB= 75>::type> { - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::RowMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; // 64 + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; // 8 + +public: + // using Layout = layout::ColumnMajor; + // static constexpr int ElementsPerAccess = 16; // at least 4-bytes + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; // 64 + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; }; template diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h index bc395d04d..b50d66380 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma.h @@ -18,14 +18,12 @@ #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" #include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ +namespace cutlass { +namespace gemm { +namespace threadblock { //////////////////////////////////////////////////////////////////////////////// @@ -378,38 +376,23 @@ template < struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? 
cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -441,38 +424,23 @@ struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h index 5d2c31170..300261c3f 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -19,7 +19,7 @@ #include "cutlass/gemm/threadblock/default_mma.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" #include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_wint2x_mma.h" namespace cutlass { namespace gemm { @@ -379,38 +379,23 @@ template < struct DefaultMma { - static 
cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; template < @@ -442,38 +427,23 @@ struct DefaultMma { - static cutlass::arch::CacheOperation::Kind const CacheOpA = - ((sizeof_bits::value * kAlignmentA) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = - ((sizeof_bits::value * kAlignmentB) == 128) ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; +private: + using Mma = DefaultWint2xMma; +public: // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; + using MmaCore = typename Mma::MmaCore; // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA>; + using IteratorA = typename Mma::IteratorA; // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB>; + using IteratorB = typename Mma::IteratorB; // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage; + using ThreadblockMma = typename Mma::ThreadblockMma; }; } // namespace threadblock diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h new file mode 100644 index 000000000..e2bc640ba --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_mma_core.h @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Partial specialization: +/// +/// A: row-major +/// B: uint2b_t, column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = uint2b_t; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = 
warp::WarpSize::value; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access of B + static constexpr int kMaxThreadsForB = + (Shape::kK * Shape::kN * sizeof_bits::value) / kAccessSizeInBits; + static constexpr int kThreadsForB = + kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + // Shared memory layout + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreadsForB, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h new file mode 100644 index 000000000..1782330de --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h @@ -0,0 +1,246 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/gemm/threadblock/default_mma_core.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultQuantParamsIterators { +private: + static constexpr int kAlignment = 128 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + + static constexpr int kRows = + (GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize; + static constexpr int kColumns = ThreadblockShape::kN; + + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, kAlignment>; + +public: + using Iterator = cutlass::transform::threadblock::PredicatedTileIterator< + MatrixShape, ElementT, layout::RowMajor, 0, + IteratorThreadMap, kAlignment>; + using SmemIterator = Iterator; +}; + +template +struct DefaultQuantParamsIterators { +private: + static constexpr int kAlignment = 32 / sizeof_bits::value; + static_assert((ThreadblockShape::kN % kAlignment) == 0, ""); + + static constexpr int kRows = + (GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize); + static constexpr int kColumns = + (GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2; + + using IteratorThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kColumns / kAlignment, kAlignment>; + +public: + using AccessType = cutlass::Array; + using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator< + MatrixShape, uint4b_t, layout::RowMajor, + 0, IteratorThreadMap, AccessType>; + + using SmemIterator = Iterator; +}; + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> +struct DefaultWint2xMma; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + 
int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DefaultWint2xMma +{ +public: + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value, + "Element B must be uint2b_t"); + + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + using ElementSuperScale = ElementA; + using ElementLocalScale = uint4b_t; + using ElementCodeScaleZp = float; + + static constexpr int kGroupSize = 64; + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + +private: + static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int kRowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved), "ThreadblockShape must be disivle by kColumnsInterleaved"); + static_assert(kRowsPerTile == MmaCore::Shape::kK, ""); + + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement; + static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), ""); + + using IteratorShapeB = MatrixShape< + MmaCore::Shape::kK * kColumnsInterleaved, MmaCore::Shape::kN / kColumnsInterleaved>; + using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + ThreadMapB::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB, + AccessTypeB>; + +private: + // Define iterators over tiles from extra quant params for B operand + using IteratorSuperScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementSuperScale, -1>::Iterator; + using 
SmemIteratorSuperScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementSuperScale, -1>::SmemIterator; + + using IteratorLocalScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator; + using SmemIteratorLocalScale = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator; + + using IteratorCodeScaleZp = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementCodeScaleZp, -1>::Iterator; + using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators< + ThreadblockShape, ElementCodeScaleZp, -1>::Iterator; + +public: + using QuantParamsAccessor = Wint2ParamsAccessor< + ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale, + IteratorLocalScale, SmemIteratorLocalScale, + IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage< + typename MmaCore::Shape, + IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, + IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB, + ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, + kStages, QuantParamsAccessor, SharedMemoryClear>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h index 6dd55b647..4b7d3ac06 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h @@ -63,8 +63,8 @@ template < typename Policy_, /// Number of stages, int Stages, - /// Used for partial specialization - typename Enable = bool> + /// Size of extra quantized params + typename QuantParamsShape> class Wint2xMmaBase { public: ///< Size of the Gemm problem - concept: gemm::GemmShape<> @@ -93,6 +93,14 @@ public: static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + /// Number of warp-level GEMM oeprations per load for B + static constexpr int kWarpGemmIterationsPerLoadForB = + Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), ""); + + static constexpr int kWarpLoadIterationsForB = + kWarpGemmIterations / kWarpGemmIterationsPerLoadForB; + /// Number of stages static int const kStages = Stages; @@ -104,8 +112,6 @@ public: using TensorRefB = TensorRef; - // using TensorRefZippedB = TensorRef; - static_assert(kWarpGemmIterations > 1, "The pipelined structure requires at least two warp-level " "GEMM operations."); @@ -130,20 +136,11 @@ public: Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; /// Shape of the B matrix operand in shared memory - using ShapeB = MatrixShape; - // w uint8; local_scale uint8; - constexpr static int kZippedRowsPerStages = - Shape::kK / 4 + (Shape::kK + 127) / 128; - - // code_scale float; code_zp float; super_scale ElementB - constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) + - sizeof_bits::value / 8; - - using ZippedShapeB = MatrixShape; - - using NopaddingShapeB = MatrixShape; + /// Shape of all quant params in shared memory + using QuantParamsShapeB = QuantParamsShape; public: // @@ -156,12 +153,8 @@ public: /// Buffer for B operand AlignedBuffer operand_B; - /// Buffer for quanted B 
operand - AlignedBuffer operand_zipped_B; - - /// Buffer for unzip B operand - AlignedBuffer - operand_unzip_B; + /// Buffer for extra quant params of B operand + AlignedBuffer operand_quant_params_B; public: // @@ -191,14 +184,6 @@ public: TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } - - CUTLASS_HOST_DEVICE - uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); } - - CUTLASS_HOST_DEVICE - typename Operator::ElementB *operand_unzip_B_ptr() { - return operand_unzip_B.data(); - } }; protected: diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h index 9531b01a7..dd26cf68e 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h @@ -45,7 +45,8 @@ #include "cutlass_extensions/arch/memory_copy_sm80.h" #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" +#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,15 +87,15 @@ template < typename Policy_, /// Number of stages, int Stages, + /// Accessor for extra quantized params + typename QuantParamsAccessor_, /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = bool> + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone> class Wint2xMmaMultistage : - public Wint2xMmaBase { + public Wint2xMmaBase { public: ///< Base class - using Base = Wint2xMmaBase; + using Base = Wint2xMmaBase; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using Shape = Shape_; ///< Iterates over tiles of A operand in global memory @@ -107,8 +108,11 @@ public: using LayoutC = LayoutC_; ///< Policy describing tuning details using Policy = Policy_; + /// Accessor for extra quantized params + using QuantParamsAccessor = QuantParamsAccessor_; + using QuantArguments = typename QuantParamsAccessor::Arguments; - using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB; + static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK; using SmemIteratorA = SmemIteratorA_; using SmemIteratorB = SmemIteratorB_; @@ -129,6 +133,18 @@ public: /// Minimum architecture is Sm80 to support cp.async using ArchTag = arch::Sm80; + //using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout; + using LayoutScale = layout::RowMajor; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using WarpDequantizer = + warp::MmaTensorOpWin2xDequantizer; + static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed"); + /// Complex transform on A operand static ComplexTransform const kTransformA = Operator::kTransformA; @@ -174,18 +190,37 @@ public: using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using FragmentSuperScale = typename WarpDequantizer::FragmentSuperScale; + using FragmentCodeScaleZp = typename WarpDequantizer::FragmentCodeScaleZp; + using FragmentLocalScale = typename 
WarpDequantizer::FragmentLocalScale; + /// Temporary accumulator to facilitate staged-accumulation FragmentC tmp_accum_; /// Pair of A fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentA warp_loaded_frag_A_[2]; - WarpTransformedFragmentA warp_transformed_frag_A_[2]; + WarpTransformedFragmentA warp_frag_A_[2]; /// Pair of B fragments used to overlap shared memory loads and math instructions - WarpLoadedFragmentB warp_loaded_frag_B_[2]; - WarpTransformedFragmentB warp_transformed_frag_B_[2]; + WarpLoadedFragmentB warp_loaded_frag_B_; + WarpTransformedFragmentB warp_frag_B_[2]; + + /// channel-wise quant params + FragmentCodeScaleZp warp_frag_code_scale_; + FragmentCodeScaleZp warp_frag_code_zp_; + FragmentSuperScale warp_frag_super_scale_; + + /// group-wise quant params + FragmentLocalScale warp_frag_local_scale_; }; + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool IsTileInterleaveLayout = + layout::IsColumnMajorTileInterleave::value; + static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); private: @@ -202,17 +237,18 @@ public: /// Iterator to write threadblock-scoped tile of B operand to shared memory SmemIteratorB smem_iterator_B_; + /// Accessor for extra quant params for B + QuantParamsAccessor quant_params_accessor_B_; + + // Wint2 unzip operator + WarpDequantizer warp_dequantizer_; + /// Shared memory write stage index int smem_write_stage_idx_; /// Shared memory read stage index int smem_read_stage_idx_; - uint8_t* column_wise_smem_ptr_B_; - - uint8_t* smem_zipped_ptr_B_; - int smem_zipped_bytes_per_stage_B_; - public: /// Construct from tensor references @@ -226,10 +262,15 @@ public: int warp_idx, ///< ID of each thread within a warp int lane_idx - ): - Base(shared_storage, thread_idx, warp_idx, lane_idx), + ) : Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx), + warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(), + quant_params_accessor_B_.local_scale_ref(), + quant_params_accessor_B_.code_scale_ref(), + quant_params_accessor_B_.code_zp_ref(), + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), smem_write_stage_idx_(0), smem_read_stage_idx_(0) { @@ -250,11 +291,6 @@ public: {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); this->warp_tile_iterator_B_.add_tile_offset( {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); - - column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr(); - - smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn; - smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn; } /// Advance shared memory read-iterators to the next stage @@ -266,28 +302,22 @@ public: if (smem_read_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - // this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * 
Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0}); smem_read_stage_idx_ = 0; } - this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); } /// Advance global memory read-iterators and shared memory write-iterators to the stage - template CUTLASS_DEVICE - void advance_smem_write_stage( - IteratorA &iterator_A, - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) + void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B) { // Advance global iterators iterator_A.add_tile_offset({0, 1}); - //iterator_B.add_tile_offset({1, 0}); - tile_dequanter_B.AddTileOffset({1, 0}); + iterator_B.add_tile_offset({1, 0}); // Advance shared iterators smem_iterator_A_.add_tile_offset({0, 1}); - //smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_B_.add_tile_offset({1, 0}); // Increment shared memory write stage index ++smem_write_stage_idx_; @@ -295,7 +325,7 @@ public: if (smem_write_stage_idx_ == Base::kStages) { // Wrap back around to the 'start' of the circular buffer in shared memory smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - //smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); smem_write_stage_idx_ = 0; } } @@ -338,9 +368,14 @@ public: } } - template CUTLASS_DEVICE void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) { + if constexpr (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + if (threadIdx.x >= IteratorB::ThreadMap::kThreads) { + return; + } + } + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); this->smem_iterator_B_.set_iteration_index(group_start_B); @@ -360,13 +395,14 @@ public: CUTLASS_PRAGMA_UNROLL for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B.get(); + bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) ? iterator_B.valid() : false; if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, is_valid); } else { - cutlass::arch::copy( - dst_ptr + v, gmem_ptr, iterator_B.valid()); + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, is_valid); } ++iterator_B; @@ -375,7 +411,6 @@ public: ++this->smem_iterator_B_; } } - __syncthreads(); } CUTLASS_DEVICE @@ -399,8 +434,6 @@ public: IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - int src_bytes = (iterator_A.valid() ? 
kSrcBytes : 0); - cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_A.get(), iterator_A.valid()); @@ -411,9 +444,12 @@ public: } } - template CUTLASS_DEVICE void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) { + if (threadIdx.x >= IteratorB::ThreadMap::kThreads) { + return; + } + iterator_B.set_iteration_index(0); this->smem_iterator_B_.set_iteration_index(0); @@ -433,35 +469,23 @@ public: IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - if (InitStage) { - cutlass::arch::copy_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - } else { - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { - cutlass::arch::copy_zfill( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } else { - cutlass::arch::copy( - dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - } + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); ++iterator_B; } ++this->smem_iterator_B_; } - __syncthreads(); } /// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching /// the global fragments needed by the first kStages-1 threadblock mainloop iterations - template CUTLASS_DEVICE void prologue( IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, + QuantArguments &mma_quant_args, ///< iterators for extra quant params for B int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining { // Issue several complete stages @@ -476,11 +500,18 @@ public: copy_tiles_and_advance_per_stage_A(iterator_A); // Async copy zipped B to shared memory. - tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); + copy_tiles_and_advance_per_stage_B(iterator_B); + + // Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale. + if (stage == 0) { + quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + } else { + quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); + } // Move to the next write stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); + quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args); // Defines the boundary of a stage of cp.async. 
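// cutlass::arch::cp_async_fence() commits every cp.async issued so far for this stage
// (the A tile, the interleaved 2-bit B tile, and this stage's quant params) into one
// commit group; the gmem_wait() calls pair with it and block until at least one
// committed group has fully landed in shared memory.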
cutlass::arch::cp_async_fence(); @@ -510,6 +541,10 @@ public: ++last_smem_iterator_A; } + if (threadIdx.x >= IteratorB::ThreadMap::kThreads) { + return; + } + /// Iterator to write threadblock-scoped tile of B operand to shared memory SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); typename IteratorB::AccessType zero_B; @@ -542,57 +577,57 @@ public: } /// Perform a threadblock mainloop iteration of matrix multiply-accumulate - template CUTLASS_DEVICE void mac_loop_iter( PipeState &pipe_state, ///< [in|out] loop-carried pipeline state FragmentC &accum, ///< [in|out] destination accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory - TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand - int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining + QuantArguments &mma_quant_args, ///< iterators for extra quant params for B + int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining int stage) { + const int mma_stage = stage - Base::kStages + 1; + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration CUTLASS_PRAGMA_UNROLL for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { - // CUTLASS_TRACE_DEVICE(" [MMa] stage=%d, warp_mma_k=%d", stage, warp_mma_k); + + int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB; + + if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) { + // Load the next warp-tile's B fragment from shared memory + this->warp_tile_iterator_B_.set_kgroup_index(((warp_mma_k + 1) % Base::kWarpGemmIterations) / Base::kWarpLoadIterationsForB); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + } + + // load next-tile of group-wise local_scale from shared memory + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + warp_dequantizer_.load(pipe_state.warp_frag_local_scale_); + } // Load the next warp-tile's A fragment from shared memory this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - // Unpack and dequant the first stage of B. - int unpack_stage = stage - Base::kStages + 2; - tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_ + (unpack_stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, unpack_stage); - - // Copy dequatized data to shared memory used by mma core. 
- copy_tiles_and_advance_per_stage_B(iterator_B); - } - - // Load the next warp-tile's B fragment from shared memory - this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_B_; - - // Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary - if (warp_mma_k > 0) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], - pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]); - } + // dequantizes next warp-tile + warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, + pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_, + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2], + ((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) : mma_stage) * Shape::kK, + (warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB); // Execute the current warp-tile of MMA operations - if (Detail::kStagedAccumulation) { + if constexpr (Detail::kStagedAccumulation) { warp_mma_( pipe_state.tmp_accum_, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_[warp_mma_k % 2], pipe_state.tmp_accum_ ); @@ -604,22 +639,22 @@ public: } else { warp_mma_( accum, - pipe_state.warp_transformed_frag_A_[warp_mma_k % 2], - pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], - accum - ); + pipe_state.warp_frag_A_[warp_mma_k % 2], + pipe_state.warp_frag_B_[warp_mma_k % 2], + accum); } // Except for the last warp-tile, all warp-tiles issue their share of // global->shared fragment copies if (warp_mma_k < Base::kWarpGemmIterations - 1) { int group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); if (warp_mma_k == 0) { - tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_, - column_wise_smem_ptr_B_, stage); + quant_params_accessor_B_.copy_tiles_and_advance_per_stage(mma_quant_args, stage); } } @@ -628,9 +663,15 @@ public: // - moves to the next global fetch stage if (warp_mma_k + 2 == Base::kWarpGemmIterations) { // Performs the last warp-tile's share of global->shared fragment copies - int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + if constexpr (Detail::AsyncCopyIterationsPerStageA >= Base::kWarpGemmIterations) { + int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + } - copy_tiles_and_advance_A(iterator_A, group_start_iteration_A); + if constexpr (Detail::AsyncCopyIterationsPerStageB >= Base::kWarpGemmIterations) { + int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + copy_tiles_and_advance_B(iterator_B, group_start_iteration_B); + } // Inserts a memory fence between stages of cp.async instructions. 
cutlass::arch::cp_async_fence(); @@ -639,69 +680,66 @@ public: gmem_wait(); // Move to the next global fetch stage - advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B); + advance_smem_write_stage(iterator_A, iterator_B); + quant_params_accessor_B_.advance_smem_write_stage(mma_quant_args); + advance_smem_read_stage(); + int byte_offset = quant_params_accessor_B_.advance_smem_read_stage(); + warp_dequantizer_.add_pointer_offset(byte_offset); // Disable global fetching when done with global fetch iterations --gemm_k_iterations; iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); - } - - // The last warp-tile also converts the shared memory fragments used by - // the first warp-tile of the next iteration, if necessary (so we can - // immediately start issuing MMA instructions at the top of the loop ) - if (warp_mma_k + 1 == Base::kWarpGemmIterations) { - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2], - pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]); + iterator_B.clear_mask(gemm_k_iterations == 0); + quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0); } } } /// Perform the specified number of threadblock mainloop iterations of matrix /// multiply-accumulate. Assumes prologue has been initiated. - template CUTLASS_DEVICE void gemm_iters( int gemm_k_iterations, ///< number of threadblock mainloop iterations FragmentC &accum, ///< [in|out] accumulator tile IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory - IteratorB &iterator_B, - TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + QuantArguments &mma_quant_args) { PipeState pipe_state; - // Unpack and dequant the first stage of B. - tile_dequanter_B.UnpackAndDequant(smem_zipped_ptr_B_, column_wise_smem_ptr_B_, 0); - // Disable global fetching if done with global fetch iterations iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1)); - - // Load first warp-tile's A fragment from shared memory - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); - ++this->warp_tile_iterator_A_; - - // Copy dequatized data to shared memory used by mma core. 
- copy_tiles_and_advance_per_stage_B(iterator_B); + iterator_B.clear_mask(gemm_k_iterations == 0); + quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0); // Load first warp-tile's B fragment from shared memory this->warp_tile_iterator_B_.set_kgroup_index(0); - this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); ++this->warp_tile_iterator_B_; - // Transform, if necessary, the first warp-tile's shared memory fragments - warp_mma_.transform( - pipe_state.warp_transformed_frag_A_[0], - pipe_state.warp_transformed_frag_B_[0], - pipe_state.warp_loaded_frag_A_[0], - pipe_state.warp_loaded_frag_B_[0]); + warp_dequantizer_.load(pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_); - if (Detail::kStagedAccumulation) { + warp_dequantizer_.load(pipe_state.warp_frag_local_scale_); + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + // Dequantize B to in register + warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_, + pipe_state.warp_frag_code_scale_, + pipe_state.warp_frag_code_zp_, + pipe_state.warp_frag_super_scale_, + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_frag_B_[0], + 0, + 0); + + if constexpr (Detail::kStagedAccumulation) { pipe_state.tmp_accum_.clear(); } @@ -715,13 +753,13 @@ public: accum, iterator_A, iterator_B, - tile_dequanter_B, + mma_quant_args, gemm_k_iterations, stage); stage += 1; } - if (Detail::kStagedAccumulation) { + if constexpr (Detail::kStagedAccumulation) { plus plus_accum; accum = plus_accum(accum, pipe_state.tmp_accum_); } @@ -761,14 +799,12 @@ public: else { this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)}); - //this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); - this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0}); + this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0}); } smem_read_stage_idx_ = smem_write_stage_idx_; } /// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory. 
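/// Call sequence: prologue() issues the first kStages-1 global->shared copy stages for A,
/// the interleaved 2-bit B tiles and their quant params; gmem_wait() blocks until the first
/// stage has landed; gemm_iters() then runs the mainloop, with warp_dequantizer_ decoding
/// B fragments in registers between the shared-memory loads and the warp-level MMAs.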
- template CUTLASS_DEVICE void operator()( ///< problem size of GEMM @@ -779,13 +815,13 @@ public: IteratorA iterator_A, ///< iterator over B operand in global memory IteratorB iterator_B, - ///< pre-load and dequantize B to shared memory - TileDequanterB tile_dequanter_B, + ///< iterators for extra quant params for B + QuantArguments mma_quant_args, ///< initial value of accumulator FragmentC const &src_accum) { // Prologue (start fetching iterations of global fragments into shared memory) - prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations); + prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations); // Wait until we have at least one completed global fetch stage gmem_wait(); @@ -794,7 +830,7 @@ public: accum = src_accum; // Perform the MAC-iterations - gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args); } }; diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h new file mode 100644 index 000000000..c6eb2750c --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h @@ -0,0 +1,315 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
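The new wint2x_params_accessor.h below centralizes how the extra 2-bit-quantization parameters of B (channel-wise super_scale, code_scale and code_zp, plus group-wise uint4 local_scale) are staged through shared memory. As a reading aid, here is a minimal, self-contained sketch of the byte layout it carves out per threadblock tile; the tile sizes, group size and stage count in the sketch are illustrative assumptions, not values taken from this patch.

#include <cstdio>

int main() {
  // Illustrative assumptions (the real values come from the ThreadblockShape,
  // GroupSize and Stages template arguments of Wint2ParamsAccessor).
  constexpr int kTileN = 128;      // ThreadblockShape::kN
  constexpr int kTileK = 64;       // ThreadblockShape::kK
  constexpr int kGroupSize = 64;   // group size of the wint2 local_scale
  constexpr int kStages = 3;       // software-pipeline depth

  // Channel-wise parameters: one value per column of the B tile.
  constexpr int kSuperScaleBytes = kTileN * 2;  // fp16/bf16 super_scale (ElementA)
  constexpr int kCodeScaleBytes  = kTileN * 4;  // float code_scale
  constexpr int kCodeZpBytes     = kTileN * 4;  // float code_zp

  // Group-wise local_scale: two uint4 values are packed per byte, so one
  // kTileN-byte row covers 2 * kGroupSize rows of K; one row is buffered per
  // pipeline stage, and with these sizes a fresh row is fetched every two
  // stages (kStagesPerLocalScaleLoad = 2 * kGroupSize / kTileK).
  constexpr int kLocalScaleRowsPerStage = (kTileK + 2 * kGroupSize - 1) / (2 * kGroupSize);
  constexpr int kLocalScaleBytes = kLocalScaleRowsPerStage * kTileN * kStages;

  // Offsets in the order the accessor stores them:
  // super_scale | code_scale | code_zp | local_scale (kStages slots).
  constexpr int kCodeScaleOffset  = kSuperScaleBytes;
  constexpr int kCodeZpOffset     = kCodeScaleOffset + kCodeScaleBytes;
  constexpr int kLocalScaleOffset = kCodeZpOffset + kCodeZpBytes;
  constexpr int kTotalBytes       = kLocalScaleOffset + kLocalScaleBytes;

  std::printf("code_scale @ %d B, code_zp @ %d B, local_scale @ %d B, total %d B per tile\n",
              kCodeScaleOffset, kCodeZpOffset, kLocalScaleOffset, kTotalBytes);
  return 0;
}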
+ +#pragma once + +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/trace.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Original data type + typename T, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterators over super scales in global memory + typename IteratorSuperScale_, + /// Iterators over super scales in shared memory + typename SmemIteratorSuperScale_, + /// Iterators over local scales in global memory + typename IteratorLocalScale_, + /// Iterators over local scales in shared memory + typename SmemIteratorLocalScale_, + /// Iterators over code scales and zps in global memory + typename IteratorCodeScaleZp_, + /// Iterators over code scales and zps in shared memory + typename SmemIteratorCodeScaleZp_, + /// Number of stages, + int Stages_, + /// Group size for quantization + int GroupSize_> +class Wint2ParamsAccessor { +public: + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + using ElementType = T; + using Shape = Shape_; + + using IteratorSuperScale = IteratorSuperScale_; + using SmemIteratorSuperScale = SmemIteratorSuperScale_; + + using IteratorLocalScale = IteratorLocalScale_; + using SmemIteratorLocalScale = SmemIteratorLocalScale_; + + using IteratorCodeScaleZp = IteratorCodeScaleZp_; + using SmemIteratorCodeScaleZp = SmemIteratorCodeScaleZp_; + + constexpr static int kStages = Stages_; + constexpr static int kGroupSize = GroupSize_; + + using ElementSuperScale = typename IteratorSuperScale::Element; + using LayoutSuperScale = typename IteratorSuperScale::Layout; + + /// local_scale uint4 and group-wise + using ElementLocalScale = typename IteratorLocalScale::Element; + using LayoutLocalScale = typename IteratorLocalScale::Layout; + static_assert(platform::is_same::value, + "local_scale's type must be uint4b_t."); + + using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element; + using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout; + + /// 2 uint4b_t values are stored in a single uint8_t + constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK; + constexpr static int kLocalScaleRows = + IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * sizeof_bits::value / 8 / Shape::kN; + + using SmemElement = uint8_t; + constexpr static int kSmemRows = + kLocalScaleRows * kStages + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2; + constexpr static int kSmemColumns = Shape::kN; + + using QuantParamsShape = MatrixShape; + + constexpr static int kSuperScaleSmemOffset = 0; + constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale); + constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp); + + /// TensorRef type for loading element from a tensor + using SuperTensorRef = cutlass::TensorRef; + using LocalTensorRef = cutlass::TensorRef; + using CodeTensorRef = cutlass::TensorRef; + + struct Arguments { + IteratorSuperScale iterator_super_scale; + IteratorLocalScale iterator_local_scale; + IteratorCodeScaleZp iterator_code_scale; + IteratorCodeScaleZp iterator_code_zp; + + int local_scale_pointer_offset; + + CUTLASS_DEVICE + Arguments(IteratorSuperScale iterator_super_scale, + IteratorLocalScale 
iterator_local_scale, + IteratorCodeScaleZp iterator_code_scale, + IteratorCodeScaleZp iterator_code_zp, + int local_scale_pointer_offset) + : iterator_super_scale(iterator_super_scale), + iterator_local_scale(iterator_local_scale), + iterator_code_scale(iterator_code_scale), + iterator_code_zp(iterator_code_zp), + local_scale_pointer_offset(local_scale_pointer_offset) {} + }; + +private: + // + // Data members + // + + /// Begin address of shared memory + uint8_t* smem_pointer_; + + /// Iterator to write threadblock-scoped tile of super scale operand to shared memory + SmemIteratorSuperScale smem_iterator_super_scale_; + /// Iterator to write threadblock-scoped tile of local scale operand to shared memory + SmemIteratorLocalScale smem_iterator_local_scale_; + /// Iterator to write threadblock-scoped tile of code scale operand to shared memory + SmemIteratorCodeScaleZp smem_iterator_code_scale_; + /// Iterator to write threadblock-scoped tile of code zp operand to shared memory + SmemIteratorCodeScaleZp smem_iterator_code_zp_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + CUTLASS_DEVICE + ElementSuperScale* get_super_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kSuperScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementLocalScale* get_local_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kLocalScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementCodeScaleZp* get_code_scale_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kCodeScaleSmemOffset); + } + + CUTLASS_DEVICE + ElementCodeScaleZp* get_code_zp_smem_ptr() { + return reinterpret_cast(smem_pointer_ + kCodeZpSmemOffset); + } + +public: + /// Construct from tensor references + CUTLASS_DEVICE + Wint2ParamsAccessor( + ///< prointer of shared memory + uint8_t* smem_pointer, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : smem_pointer_(smem_pointer), + smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn), + get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx), + smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn), + get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx), + smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx), + smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn), + get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) {} + + CUTLASS_DEVICE + SuperTensorRef super_scale_ref() { + return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)}; + } + + CUTLASS_DEVICE + LocalTensorRef local_scale_ref() { + return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)}; + } + + CUTLASS_DEVICE + CodeTensorRef code_scale_ref() { + return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + } + + CUTLASS_DEVICE + CodeTensorRef code_zp_ref() { + return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)}; + } + + template + CUTLASS_DEVICE + void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) { + if constexpr (IsFirstStage) { + // Load channel-wise super_scale to 
shared memory, which only needs to be done once. + typename IteratorSuperScale::Fragment tb_frag_super_scale; + tb_frag_super_scale.clear(); + quant_args.iterator_super_scale.load(tb_frag_super_scale); + this->smem_iterator_super_scale_.store(tb_frag_super_scale); + + // Load channel-wise code_scale to shared memory, which only needs to be done once. + typename IteratorCodeScaleZp::Fragment tb_frag_code_scale; + tb_frag_code_scale.clear(); + quant_args.iterator_code_scale.load(tb_frag_code_scale); + this->smem_iterator_code_scale_.store(tb_frag_code_scale); + + // Load channel-wise code_zp to shared memory, which only needs to be done once. + typename IteratorCodeScaleZp::Fragment tb_frag_code_zp; + tb_frag_code_zp.clear(); + quant_args.iterator_code_zp.load(tb_frag_code_zp); + this->smem_iterator_code_zp_.store(tb_frag_code_zp); + } + + if ((stage % kStagesPerLocalScaleLoad) == 0) { + // Load group-wise local_scale to shared memory, which only needs to be done at each stage. + // Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages. + using AccessType = typename IteratorLocalScale::AccessType; + cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits::value == 128) + ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always; + + quant_args.iterator_local_scale.set_iteration_index(0); + this->smem_iterator_local_scale_.set_iteration_index(0); + + // Async Copy for local_scale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) { + AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_local_scale_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) { + auto gmem_ptr = quant_args.iterator_local_scale.get(); + + int const kSrcBytes = + sizeof_bits::value * + IteratorLocalScale::ThreadMap::kElementsPerAccess / + IteratorLocalScale::kAccessesPerVector / 8; + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid()); + } + ++quant_args.iterator_local_scale; + } + ++this->smem_iterator_local_scale_; + } + } + + CUTLASS_DEVICE + void advance_smem_write_stage(Arguments &quant_args) { + if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) { + // Advance global iterators + quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset); + + // Advance shared iterators + int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; + smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn; + smem_iterator_local_scale_.add_pointer_offset(pointer_offset); + smem_write_stage_idx_ = 0; + } + } + + CUTLASS_DEVICE + int advance_smem_read_stage() { + int byte_offset = 0; + + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ % kStagesPerLocalScaleLoad == 0) { + byte_offset = kLocalScaleRows * kSmemColumns; + } + + if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) { + smem_read_stage_idx_ = 0; + byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns; + } + + return byte_offset; + } + + CUTLASS_DEVICE + int clear_mask(Arguments &quant_args, bool 
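// Sketch of the local_scale staging cadence used above, with illustrative numbers that are not
// taken from this diff: assuming kGroupSize = 64 and a threadblock tile with Shape::kK = 64,
// kStagesPerLocalScaleLoad = 2 * 64 / 64 = 2, i.e. one packed uint8_t (two uint4b_t group
// scales) covers two consecutive K-stages. The cp_async copy above therefore only fires on
// even stage indices, and the write/read stage counters wrap at kStagesPerLocalScaleLoad * kStages.
static_assert(2 * 64 / 64 == 2, "one byte of packed local_scale spans two K-stages in this example");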
cond) { + quant_args.iterator_local_scale.clear_mask(cond); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h deleted file mode 100644 index cec6bcea0..000000000 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "cutlass/gemm_coord.h" -#include "cutlass/trace.h" - -#include "cutlass_extensions/gemm/threadblock/wint2x_unzip.h" - -namespace cutlass { -namespace gemm { -namespace threadblock { - -template -struct TileDequanter { - using WeightQuantTraits = WintQuantTraits; - using MmaElementT = typename WeightQuantTraits::MmaWeightType; - using QuantArguments = typename WeightQuantTraits::Arguments; - - using UnzipAndDequantFunctor = - UnzipAndDequantFunctor; - - static constexpr bool kUseSharedMemory = true; - - static constexpr int kRows = Rows; - static constexpr int kColumns = Columns; - static constexpr int kStages = Stages; - - MmaElementT *out_smem_ptr{nullptr}; - - char *pointer{nullptr}; - int64_t ldm{0}; - cutlass::MatrixCoord tb_offset; - cutlass::MatrixCoord extent; - - ScaleElementT *super_scale_ptr{nullptr}; - cutlass::MatrixCoord tb_offset_scale; - - QuantArguments quant_args; - - int64_t block_start_rows[kStages]; - bool need_preload{true}; - UnzipAndDequantFunctor unzip_functor; - - CUTLASS_DEVICE - TileDequanter(MmaElementT *out_smem_ptr, char *pointer, int64_t ldm, - const cutlass::MatrixCoord &extent, - const cutlass::MatrixCoord &tb_offset, - ScaleElementT *super_scale_ptr, - const cutlass::MatrixCoord &tb_offset_scale, - const QuantArguments &quant_args) - : out_smem_ptr(out_smem_ptr), pointer(pointer), ldm(ldm), extent(extent), - tb_offset(tb_offset), super_scale_ptr(super_scale_ptr), - tb_offset_scale(tb_offset_scale), quant_args(quant_args) {} - - CUTLASS_DEVICE - MmaElementT *GetOutPtr() { return out_smem_ptr; } - - CUTLASS_DEVICE - void AddTileOffset(const cutlass::MatrixCoord &tile_offset) { - tb_offset.row() += tile_offset.row() * kRows; - tb_offset.column() += tile_offset.column() * kColumns; - tb_offset_scale.column() += tile_offset.column() * kColumns; - } - - CUTLASS_DEVICE - void Load(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { - int zipped_row = WeightQuantTraits::CaclPackedDim(tb_offset.row()); - if (tb_offset.row() >= extent.row() || - tb_offset.column() >= extent.column()) { - return; - } - - block_start_rows[stage % kStages] = tb_offset.row(); - - using ZippedT = typename WeightQuantTraits::WeightType; - ZippedT *in_ptr = reinterpret_cast(pointer) + zipped_row * ldm + - tb_offset.column(); - ScaleElementT *scale_ptr = super_scale_ptr + tb_offset_scale.column(); - - if constexpr (Method == 
WintQuantMethod::kWeightOnlyInt2) { - const uint8_t *local_scale_ptr = quant_args.local_scale_ptr + - (tb_offset.row() / 128) * ldm + - tb_offset_scale.column(); - const float *code_scale_ptr = - quant_args.code_scale_ptr + tb_offset_scale.column(); - const float *code_zp_ptr = - quant_args.code_zp_ptr + tb_offset_scale.column(); - - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.LoadAsync(in_ptr, local_scale_ptr, code_scale_ptr, code_zp_ptr, - scale_ptr, &args, ldm, need_preload); - need_preload = false; - } else { - // CUTLASS_TRACE_DEVICE("Not Supported!"); - } - } - - CUTLASS_DEVICE - void UnpackAndDequant(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr, int stage) { - int64_t block_start_row = block_start_rows[stage % kStages]; - if (block_start_row >= extent.row()) { - return; - } - - if constexpr (Method == WintQuantMethod::kWeightOnlyInt2) { - typename UnzipAndDequantFunctor::Arguments args(zipped_smem_ptr, column_wise_smem_ptr); - unzip_functor.ComputeVectorized(args, out_smem_ptr, block_start_row); - } else { - // CUTLASS_TRACE_DEVICE("Not Supported!"); - } - } -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h index 350b247de..af4298df5 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -41,12 +41,9 @@ #include "cutlass_extensions/arch/mma.h" #include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,7 +78,7 @@ private: // Shape for computing the FP16s using ComputeInstructionShape = InstructionShape_; - // Chosen so we get K=16 for int8 and K=32 for int4. + // Chosen so we get K=16 for int8, K=32 for int4, K=64 for int2. static constexpr int LoadInstructionK = 128 / sizeof_bits::value; // Shape for loading the narrow data type from shared memory diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index 7c5088894..64136a975 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -58,15 +58,12 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ +namespace cutlass { +namespace gemm { +namespace warp { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer. 
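// Illustrative arithmetic for the uint2b_t case handled below (a sketch, assuming the usual
// 16x8x16 fp16/bf16 tensor-op instruction): the shared-memory load instruction is widened to
// K = 128 / sizeof_bits<uint2b_t>::value = 64 elements, while the compute instruction keeps
// K = 16, so each loaded (and dequantized) B fragment is consumed over
// kExpansionFactor = 64 / 16 = 4 compute iterations.
static_assert(128 / 2 == 64, "load-instruction K for 2-bit B");
static_assert(64 / 16 == 4, "kExpansionFactor for a K=16 compute instruction");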
template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, @@ -297,6 +294,235 @@ public: } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting Tensor Cores, for the case when A is floating point and B is quantized integer. +/// Specialization for B of uint2b_t. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +class MmaTensorOpComputeBWithF16< + Shape_, + ElementA_, + LayoutA_, + uint2b_t, + LayoutB_, + ElementC_, + LayoutC_, + Policy_, + SharedMemoryInstructionShape_, + PartitionsK_, + AccumulatorsInRowMajor> +{ +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = uint2b_t; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value + && platform::is_same::value) + || (platform::is_same::value + && platform::is_same::value + && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + + static_assert(platform::is_same::value + || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert( + SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A 
operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + /// Iterates over the A operand in memory + using IteratorA + = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, + LayoutB, MatrixShape, Policy::OpDelta::kRow, + kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + +public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C) const + { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace warp diff --git a/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h new file mode 100644 index 000000000..4678b58e4 --- /dev/null +++ b/custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h @@ -0,0 +1,442 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +namespace detail { + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits { + using Type = __nv_bfloat16; + using DualType = __nv_bfloat162; +}; + +template <> +struct DataTypeTraits { + using Type = __half; + using DualType = __half2; +}; + +template +struct LocalScaleConverter { + using FragmentSource = Array; + using FragmentResult = Array; + + CUTLASS_DEVICE + static void Apply(FragmentSource const& local_scale_frag, + FragmentResult const& super_scale_frag, + FragmentResult& scale_frag, + int shift_bit) { + constexpr uint32_t kLocalScaleMask = 0xf; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + int32_t shifted_value = (static_cast(local_scale_frag[i]) >> shift_bit) & kLocalScaleMask; + scale_frag[i] = static_cast(shifted_value) * super_scale_frag[i]; + } + } +}; + +template +struct LocalScaleConverter::type> { + using FragmentSource = Array; + using FragmentResult = Array; + + CUTLASS_DEVICE + static void Apply(FragmentSource const& local_scale_frag, + FragmentResult const& super_scale_frag, + FragmentResult& scale_frag, + int shift_bit) { + constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + constexpr uint32_t MASK = 0x000f000f; + // 2^10 = 1024 + constexpr uint32_t I4s_TO_FP16s_MAGIC_NUM = 0x64006400; + + // -2^10 = -1024 + constexpr uint32_t FP16_BIAS = 0xE400E400; + // 1.0 + constexpr uint32_t FP16_ONE = 0x3C003C00; + + __half2* scale_ptr = reinterpret_cast<__half2 *>(&scale_frag); + __half2 const* super_scale_ptr = reinterpret_cast<__half2 const*>(&super_scale_frag); + + uint32_t const* local_scale_ptr = reinterpret_cast(&local_scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + int i4s = local_scale_ptr[i] >> shift_bit; + + // unpack: 0, 1 + int32_t low = __byte_perm(i4s, i4s, 0xF1F0); + int32_t unpack0 = lop3(low, MASK, I4s_TO_FP16s_MAGIC_NUM); + // unpack: 2, 3 + int32_t high = __byte_perm(i4s, i4s, 0xF3F2); + int32_t unpack1 = lop3(high, MASK, I4s_TO_FP16s_MAGIC_NUM); + + __half2 scale0 = __hfma2(*reinterpret_cast<__half2*>(&unpack0), + *reinterpret_cast(&FP16_ONE), + *reinterpret_cast(&FP16_BIAS)); + __half2 scale1 = __hfma2(*reinterpret_cast<__half2*>(&unpack1), + *reinterpret_cast(&FP16_ONE), + *reinterpret_cast(&FP16_BIAS)); + + scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]); + scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]); + } + } +}; + +template +struct LocalScaleConverter::type> { + using FragmentSource = Array; + using FragmentResult = Array; + + CUTLASS_DEVICE + static void Apply(FragmentSource const& local_scale_frag, + FragmentResult const& super_scale_frag, + FragmentResult& scale_frag, + int shift_bit) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; + constexpr uint32_t MASK = 0x000F000F; + constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + constexpr uint32_t BF16_BIAS = 0xC300C300; + 
constexpr uint32_t BF16_ONE = 0x3F803F80; + + __nv_bfloat162* scale_ptr = reinterpret_cast<__nv_bfloat162 *>(&scale_frag); + __nv_bfloat162 const* super_scale_ptr = reinterpret_cast<__nv_bfloat162 const*>(&super_scale_frag); + + uint32_t const* local_scale_ptr = reinterpret_cast(&local_scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + int i4s = local_scale_ptr[i] >> shift_bit; + + // unpack: 0, 1 + int32_t low = __byte_perm(i4s, i4s, 0xF1F0); + int32_t unpack0 = lop3(low, MASK, I4s_TO_BF16s_MAGIC_NUM); + // unpack: 2, 3 + int32_t high = __byte_perm(i4s, i4s, 0xF3F2); + int32_t unpack1 = lop3(high, MASK, I4s_TO_BF16s_MAGIC_NUM); + + nv_bfloat162 scale0 = __hfma2(*reinterpret_cast(&unpack0), + *reinterpret_cast(&BF16_ONE), + *reinterpret_cast(&BF16_BIAS)); + nv_bfloat162 scale1 = __hfma2(*reinterpret_cast(&unpack1), + *reinterpret_cast(&BF16_ONE), + *reinterpret_cast(&BF16_BIAS)); + + scale_ptr[2 * i] = __hmul2(scale0, super_scale_ptr[2 * i]); + scale_ptr[2 * i + 1] = __hmul2(scale1, super_scale_ptr[2 * i + 1]); + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } +}; + +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename ElementOperand_, + /// Layout of operand + typename Layout_, + /// Group size for quantization + int GroupSize_, + /// + typename Enable = void> +class MmaTensorOpWin2xDequantizer { + //static_assert(false, "Not Supported!"); +}; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// Data type of Scale elements + typename ElementOperand_, + /// Group size for quantization + int GroupSize_> +class MmaTensorOpWin2xDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + ElementOperand_, + layout::RowMajor, + GroupSize_> + //typename platform::enable_if= 80 + // && platform::is_same::value>::type> +{ +public: + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Warp mma shape + using Shape = Shape_; + + /// Type of mma operand + using ElementOperand = ElementOperand_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// Group size for quantization + static constexpr int kGroupSize = GroupSize_; + + /// Type of input + using ElementB = typename MmaOperator::FragmentB::Element; + static_assert(platform::is_same::value, "ElementB must be uint2b_t"); + + /// Type of the scales + using ElementLocalScale = uint4b_t; + using ElementSuperScale = ElementOperand; + using ElementCodeScaleZp = float; + + 
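    // Scalar reference (illustrative helper, not part of this diff) for what the vectorized
    // detail::LocalScaleConverter specializations above compute per element: pick one of the
    // two 4-bit group scales packed in a byte and multiply by the channel-wise super scale.
    CUTLASS_DEVICE
    static float dequant_one_local_scale(uint8_t packed, int shift_bit /* 0 or 4 */, float super_scale) {
        int nibble = (packed >> shift_bit) & 0xF;   // kLocalScaleMask in the generic path above
        return static_cast<float>(nibble) * super_scale;
    }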
// Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kWarpIterationsAlongN = MmaOperator::MmaIterations::kColumn; + + // use uint8_t to save 2 4-bits local scales + using FragmentLocalScale = Array; + using FragmentSuperScale = Array; + using FragmentCodeScaleZp = Array; + + /// Fragment to hold B data before Mma + using FragmentInput = Array; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + static constexpr int kNumPacks = sizeof_bits::value / sizeof_bits::value; + static constexpr int kUnpackFactor = MmaOperator::FragmentB::kElements / (kWarpIterationsAlongN * kNumPacks); + static constexpr int kUnpackInterval = kExpansionFactor / kUnpackFactor; + + /// Unpack 4 uint2b_t values compreseed in a uint8_t to floating points. + using Uint2Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementOperand, ElementB, MmaOperator::FragmentB::kElements / kUnpackFactor>; + using FragmentInputUnpack = typename Uint2Converter::result_type; + + /// Fragment to hold internal scales before Mma + using FragmentScale = Array; + + /// Fragment of dequantized B + using FragmentOutput = Array; + + /// TensorRef type for loading element from a tensor + using SuperTensorRef = cutlass::TensorRef; + using LocalTensorRef = cutlass::TensorRef; + using CodeTensorRef = cutlass::TensorRef; + +private: + // + // Data members + // + + uint8_t* pointer_local_scale_; + ElementCodeScaleZp* pointer_code_scale_; + ElementCodeScaleZp* pointer_code_zp_; + ElementSuperScale* pointer_super_scale_; + + //FragmentInputUnpack unpacked_frag_; + FragmentScale scale_frag_; + +public: + CUTLASS_DEVICE + MmaTensorOpWin2xDequantizer(SuperTensorRef smem_super_scale, + LocalTensorRef smem_local_scale, + CodeTensorRef smem_code_scale, + CodeTensorRef smem_code_zp, + int warp_idx_n, + int lane_idx) { + int warp_offset = warp_idx_n * Shape::kN; + int quad = lane_idx / 4; + int thread_offset = warp_offset + quad; + pointer_super_scale_ = smem_super_scale.data() + thread_offset; + pointer_code_scale_ = smem_code_scale.data() + thread_offset; + pointer_code_zp_ = smem_code_zp.data() + thread_offset; + pointer_local_scale_ = reinterpret_cast(smem_local_scale.data()) + thread_offset; + } + + /// Channel-wise params, need to load just once + CUTLASS_DEVICE + void load(FragmentCodeScaleZp& code_scale_frag, + FragmentCodeScaleZp& code_zp_frag, + FragmentSuperScale& super_scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { + super_scale_frag[mma_n_iter] = pointer_super_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict + code_scale_frag[mma_n_iter] = pointer_code_scale_[mma_n_iter * InstructionShape::kN]; + code_zp_frag[mma_n_iter] = pointer_code_zp_[mma_n_iter * InstructionShape::kN]; + } + } + + /// Group-wise params, need to load multiple times + CUTLASS_DEVICE + void load(FragmentLocalScale& local_scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { + local_scale_frag[mma_n_iter] = pointer_local_scale_[mma_n_iter * InstructionShape::kN]; // bank conflict + } + } + + CUTLASS_DEVICE + void dequantize(const FragmentLocalScale& local_scale_frag, + const FragmentCodeScaleZp& code_scale_frag, + const FragmentCodeScaleZp& code_zp_frag, + const FragmentSuperScale& super_scale_frag, 
+ const FragmentInput& input_frag, + FragmentOutput& output_frag, + int tb_offset_k, + int warp_k_compute_offset) { + if constexpr (kUnpackInterval != 1) { + // unsupport now + arch::device_breakpoint(); + } + + typename Uint2Converter::source_type source_frag; + + int in_offset = warp_k_compute_offset * kUnpackInterval; + + uint8_t const* ptr_input = reinterpret_cast(&input_frag); + uint8_t* ptr_source = reinterpret_cast(&source_frag); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; ++mma_n_iter) { + ptr_source[mma_n_iter] = ptr_input[mma_n_iter * kUnpackFactor + in_offset]; + } + FragmentInputUnpack unpacked_frag = Uint2Converter::convert(source_frag, code_scale_frag, code_zp_frag); + + // dequantize local_scale + if (warp_k_compute_offset == 0) { + using LocalScaleConverter = detail::LocalScaleConverter; + + // special for TileRows = 64 + int local_scale_shift = (((tb_offset_k / kGroupSize) + 1) & 1) * 4; + LocalScaleConverter::Apply(local_scale_frag, super_scale_frag, scale_frag_, local_scale_shift); + } + + // unscale + // After applying LOP3 optimizations for performance, the B operand requires data rearrangement. + // reorder: [0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15] + const int kWarpIterationsAlongK = FragmentOutput::kElements / kWarpIterationsAlongN; + + using Type = typename detail::DataTypeTraits::Type; + using DualType = typename detail::DataTypeTraits::DualType; + + Type* output_ptr = reinterpret_cast(&output_frag); + DualType const* unpacked_ptr = reinterpret_cast(&unpacked_frag); + DualType const* scale_ptr = reinterpret_cast(&scale_frag_); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < kWarpIterationsAlongN; mma_n_iter += 2) { + int mapped_idx_base = (mma_n_iter / 2) * kWarpIterationsAlongK; + + DualType scalex2 = scale_ptr[mma_n_iter / 2]; + + CUTLASS_PRAGMA_UNROLL + for (int mma_k_iter = 0; mma_k_iter < kWarpIterationsAlongK; ++mma_k_iter) { + DualType unpacked_valuex2 = unpacked_ptr[mapped_idx_base + mma_k_iter]; + DualType scaled_value = __hmul2(unpacked_valuex2, scalex2); + output_ptr[mma_n_iter * kWarpIterationsAlongK + mma_k_iter] = scaled_value.x; + output_ptr[(mma_n_iter + 1) * kWarpIterationsAlongK + mma_k_iter] = scaled_value.y; + } + } + } + + /// Add an offset to pointer in units of elements. + /// Only group-wise params needs. 
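    // Sketch of the nibble selection performed in dequantize() above (the helper name is
    // illustrative): each byte of local_scale packs the 4-bit scales of two consecutive
    // K-groups, and the parity of the group index picks the nibble -- with this encoding the
    // even group reads the high nibble and the odd group reads the low nibble.
    CUTLASS_DEVICE
    static int local_scale_shift_for(int tb_offset_k, int group_size) {
        int group_idx = tb_offset_k / group_size;
        return ((group_idx + 1) & 1) * 4;   // group 0 -> shift 4, group 1 -> shift 0, group 2 -> shift 4, ...
    }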
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + pointer_local_scale_ += offset; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h index 44ba79680..e7e17657b 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h +++ b/custom_ops/gpu_ops/cutlass_extensions/interleaved_numeric_conversion.h @@ -39,18 +39,25 @@ #include "cutlass/array.h" #include "cutlass/half.h" #include "cutlass/numeric_types.h" +#include "cutlass/trace.h" -namespace cutlass -{ +namespace cutlass { + +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low // bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. // This converter will uninterleave the data and subtract the bias while converting to the result type. template -struct FastInterleavedAndBiasedNumericArrayConverter -{ -}; +struct FastInterleavedAndBiasedNumericArrayConverter; template <> struct FastInterleavedAndBiasedNumericArrayConverter @@ -440,6 +447,329 @@ struct FastInterleavedAndBiasedNumericArrayConverter } }; +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + using ScaleComputeT = float; + using code_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp) + { + uint32_t const i8s = reinterpret_cast(source); + + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + + int32_t decode_value[4]; + ScaleComputeT new_code_zp = code_zp + 0.5f; + + decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); + decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); + decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); + decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], 
code_scale, new_code_zp)); + + return convert_impl(decode_value); + } + + CUTLASS_DEVICE + static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp) + { + uint32_t const i8s = reinterpret_cast(source); + + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + + int32_t decode_value[4]; + + decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); + decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); + decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); + decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); + + return convert_impl(decode_value); + } + + CUTLASS_DEVICE + static result_type convert_impl(int32_t* decode_value) + { + result_type result; + static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; + + static constexpr uint32_t MASK = 0x003F003F; + // 2^10 = 1024 + static constexpr uint32_t EX = 0x64006400; + + uint32_t* h = reinterpret_cast(&result); + + int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); + int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); + + h[0] = lop3(q0 >> 9, MASK, EX); + h[1] = lop3(q0 >> 6, MASK, EX); + h[2] = lop3(q0 >> 3, MASK, EX); + h[3] = lop3(q0, MASK, EX); + + h[4] = lop3(q1 >> 9, MASK, EX); + h[5] = lop3(q1 >> 6, MASK, EX); + h[6] = lop3(q1 >> 3, MASK, EX); + h[7] = lop3(q1, MASK, EX); + + // 1024 + 32 = 1056 + static constexpr uint32_t SUB = 0x64206420; + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); + + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp) + { + return convert(s, code_scale, code_zp); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + using ScaleComputeT = float; + using code_type = Array; + + CUTLASS_DEVICE + static result_type 
convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp) + { + uint32_t const i8s = reinterpret_cast(source); + + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + + int32_t decode_value[4]; + ScaleComputeT new_code_zp = code_zp + 0.5f; + + decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); + decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); + decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); + decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp)); + + return convert_impl(decode_value); + } + + CUTLASS_DEVICE + static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp) + { + uint32_t const i8s = reinterpret_cast(source); + + // 2^23 = 8388608 + static constexpr uint32_t FP32_BASE = 0x4B000000; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); + + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); + asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); + + int32_t decode_value[4]; + + decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); + decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); + decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); + decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); + + return convert_impl(decode_value); + } + + CUTLASS_DEVICE + static result_type convert_impl(int32_t* decode_value) + { + result_type result; + + static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; + static constexpr uint32_t MASK = 0x003F003F; + // 2^7 = 128 + static constexpr uint32_t EX = 0x43004300; + + uint32_t* h = 
reinterpret_cast(&result); + + int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); + int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); + + h[0] = lop3(q0 >> 9, MASK, EX); + h[1] = lop3(q0 >> 6, MASK, EX); + h[2] = lop3(q0 >> 3, MASK, EX); + h[3] = lop3(q0, MASK, EX); + + h[4] = lop3(q1 >> 9, MASK, EX); + h[5] = lop3(q1 >> 6, MASK, EX); + h[6] = lop3(q1 >> 3, MASK, EX); + h[7] = lop3(q1, MASK, EX); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16)) + // 128 + 32 = 160 + static constexpr uint32_t SUB = 0x43204320; + + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); + + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); + asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); +#else + // 1.0 + static constexpr uint32_t MUL = 0x3F803F80; + // -160 + static constexpr uint32_t ADD = 0xC320C320; + + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MUL), "r"(ADD)); + + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(h[4]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(MUL), "r"(ADD)); + asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(MUL), "r"(ADD)); +#endif + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp) + { + return convert(s, code_scale, code_zp); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static_assert(platform::is_same::value || platform::is_same::value, + "T must be fp16 or bf16"); + + static constexpr int kVecWidth = 16; + static_assert(!(N % kVecWidth), "N must be multiple of 16."); + + using result_type = Array; + using source_type = Array; + using code_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / kVecWidth; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]); + } + + return result; + } + + CUTLASS_DEVICE + static result_type convert(source_type const& source, Array const& code_scale, Array const& code_zp) + { + using scalar_result_type = typename result_type::Element; + 
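        // Scalar reading of the decode stage shared by the fp16 and bf16 converters above
        // (a worked equation, not new code): the __byte_perm with 0x4B000000 followed by the
        // sub.f32 turns each packed byte b into float(b) exactly, after which the per-column
        // affine code book is applied and truncated with a +0.5 bias, i.e.
        //     decode_value = floor(float(b) * code_scale + code_zp + 0.5)
        // which is round-to-nearest for the non-negative products produced here. Each decoded
        // value then yields four weights via the >>9 / >>6 / >>3 / >>0 shifts with a 6-bit mask,
        // followed by a bias subtraction (1056 = 1024 + 32 for fp16, 160 = 128 + 32 for bf16).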
using scalar_source_type = typename source_type::Element; + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + result_type result; + using vec_result = typename Converter::result_type; + using vec_source = typename Converter::source_type; + using vec_code = typename Converter::code_type; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + vec_code const* code_scale_ptr = reinterpret_cast(&code_scale); + vec_code const* code_zp_ptr = reinterpret_cast(&code_zp); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / kVecWidth; ++i) + { + result_ptr[i] = Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s, code_type const& code_scale, code_type const& code_zp) + { + return convert(s, code_scale, code_zp); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h index 9e1c6c463..fa2881069 100644 --- a/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h +++ b/custom_ops/gpu_ops/cutlass_extensions/wint_type_traits.h @@ -125,10 +125,13 @@ struct WintQuantTraits { static constexpr int32_t kNumPackedValues = 4; static constexpr int32_t kPackedSize = 16; + using LocalScaleType = uint4b_t; + using CodeScaleZpType = float; + struct Arguments { - const uint8_t *local_scale_ptr; // quanted 4-bits - const float *code_scale_ptr; - const float *code_zp_ptr; + uint8_t *local_scale_ptr; // quanted 4-bits + float *code_scale_ptr; + float *code_zp_ptr; }; CUTLASS_DEVICE diff --git a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h index 356f30596..54a144974 100644 --- a/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h +++ b/custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h @@ -43,7 +43,6 @@ #include "cutlass/trace.h" #include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h" #include "cutlass_extensions/tile_interleaved_layout.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -775,17 +774,54 @@ struct Wint2xMoeFCGemm : public MoeFCGemm struct KernelRunner { using WeightQuantTraits = WintQuantTraits; - using QuantArguments = typename WeightQuantTraits::Arguments; + using MmaQuantArguments = typename Mma::QuantParamsAccessor::Arguments; CUTLASS_DEVICE - static QuantArguments get_quant_args(Params const& params, int32_t problem_idx, const int64_t gemm_k, const int64_t gemm_n) { - QuantArguments quant_args; - if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { - quant_args.local_scale_ptr = params.local_scale + problem_idx * gemm_k * gemm_n / 128; - quant_args.code_scale_ptr = params.code_scale + problem_idx * gemm_n; - quant_args.code_zp_ptr = params.code_zp + problem_idx * gemm_n; - } - return quant_args; + static MmaQuantArguments prepare_quant_args( + Params const& params, cutlass::gemm::GemmCoord const& threadblock_offset, + int64_t problem_idx, const int32_t gemm_k, const int32_t gemm_n, const int thread_idx) { + // the begin threadblock_offset of scale, which holds the same column id with C, but with no row id + cutlass::MatrixCoord 
tb_offset_scale{0, threadblock_offset.n()}; + cutlass::MatrixCoord tb_offset_local_scale{0, threadblock_offset.n() * 2}; + + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorSuperScale iterator_super_scale( + Mma::QuantParamsAccessor::LayoutSuperScale(gemm_n), + weight_scale_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + int local_scale_pointer_offset = ((ThreadblockShape::kK + 127) / 128) * (gemm_n * 2); + int64_t offset_in_bytes = problem_idx * gemm_k * gemm_n / 128; + uint4b_t *local_scale_ptr = reinterpret_cast(params.local_scale + offset_in_bytes); + + typename Mma::QuantParamsAccessor::IteratorLocalScale iterator_local_scale( + Mma::QuantParamsAccessor::LayoutLocalScale(gemm_n * 2), + local_scale_ptr, + {(gemm_k + 127) / 128, gemm_n * 2}, + thread_idx, + tb_offset_local_scale); + + float* code_scale_ptr = params.code_scale + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_scale( + Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), + code_scale_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + float* code_zp_ptr = params.code_zp + problem_idx * gemm_n; + typename Mma::QuantParamsAccessor::IteratorCodeScaleZp iterator_code_zp( + Mma::QuantParamsAccessor::LayoutCodeScaleZp(gemm_n), + code_zp_ptr, + {1, gemm_n}, + thread_idx, + tb_offset_scale); + + MmaQuantArguments mma_quant_args( + iterator_super_scale, iterator_local_scale, iterator_code_scale, iterator_code_zp, local_scale_pointer_offset); + return mma_quant_args; } CUTLASS_DEVICE @@ -814,9 +850,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm= 1, "B must be row major/col major OR col major interleaved."); - // LayoutB should be RowMajor - using TileDequanterB = cutlass::gemm::threadblock::TileDequanter; - // // Problem visitor. // @@ -843,12 +876,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(byte_ptr_B); typename LayoutB::LongIndex ldm_B = platform::is_same::value ? 
gemm_n : gemm_k * kInterleave; - typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns; // the begin threadblock_offset of B, which holds the same column id with C - cutlass::MatrixCoord tb_offset_B{0, - threadblock_offset.n() / kInterleave}; - + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave}; - cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns}; - - MmaElementB* smem_unzip_B_ptr = nullptr; - if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) { - smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr(); - } - QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n); - TileDequanterB tile_dequanter_B(smem_unzip_B_ptr, - byte_ptr_B, - ldm_B, - extent_B, - tb_offset_B, - weight_scale_ptr, - tb_offset_scale, - quant_args); - MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr(); // Compute position within threadblock int thread_idx = threadIdx.x; @@ -914,20 +919,21 @@ struct Wint2xMoeFCGemm : public MoeFCGemm(A), - reinterpret_cast(B), + reinterpret_cast(B), reinterpret_cast(weight_scales), reinterpret_cast(biases), reinterpret_cast(C), diff --git a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu index 5a68c9e2f..f3e51bfcf 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu @@ -49,12 +49,13 @@ void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input, typename WeightOnlyTraits::Arguments up_gate_proj_quant_args; typename WeightOnlyTraits::Arguments down_proj_quant_args; if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) { - up_gate_proj_quant_args.local_scale_ptr = up_gate_proj_local_scale->data(); - up_gate_proj_quant_args.code_scale_ptr = up_gate_proj_code_scale->data(); - up_gate_proj_quant_args.code_zp_ptr = up_gate_proj_code_zp->data(); - down_proj_quant_args.local_scale_ptr = down_proj_local_scale->data(); - down_proj_quant_args.code_scale_ptr = down_proj_code_scale->data(); - down_proj_quant_args.code_zp_ptr = down_proj_code_zp->data(); + up_gate_proj_quant_args.local_scale_ptr = const_cast(up_gate_proj_local_scale->data()); + up_gate_proj_quant_args.code_scale_ptr = const_cast(up_gate_proj_code_scale->data()); + up_gate_proj_quant_args.code_zp_ptr = const_cast(up_gate_proj_code_zp->data()); + + down_proj_quant_args.local_scale_ptr = const_cast(down_proj_local_scale->data()); + down_proj_quant_args.code_scale_ptr = const_cast(down_proj_code_scale->data()); + down_proj_quant_args.code_zp_ptr = const_cast(down_proj_code_zp->data()); } auto moe_gemm_runner = MoeGemmRunner();
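Taken together, the pieces in this diff implement the following per-weight arithmetic. The helper below is a hypothetical, scalar-only sketch (not code from this PR) that folds the numeric converter, the LocalScaleConverter and MmaTensorOpWin2xDequantizer::dequantize() into one expression for the fp16 path; the lane/shift mapping and bias follow the vectorized code above.

// Hypothetical scalar sketch of the wint2.x dequantization chain (fp16 path).
__device__ inline float dequant_wint2_scalar(uint8_t packed_code,              // one byte of quantized B
                                             int lane,                         // 0..3, position within the byte's 4 outputs
                                             float code_scale, float code_zp,  // channel-wise code book
                                             uint8_t packed_local_scale, int nibble_shift,  // 0 or 4
                                             float super_scale) {              // channel-wise scale
    // 1) affine code-book decode of the packed byte (rounded), as in the converters above
    int decoded = __float2int_rd(fmaf(static_cast<float>(packed_code), code_scale, code_zp + 0.5f));
    // 2) extract one value: shifts 9/6/3/0 with a 6-bit mask, minus the 32 bias folded into SUB
    int w = ((decoded >> (9 - 3 * lane)) & 0x3F) - 32;
    // 3) group-wise 4-bit local scale combined with the channel-wise super scale
    float scale = static_cast<float>((packed_local_scale >> nibble_shift) & 0xF) * super_scale;
    return static_cast<float>(w) * scale;
}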