Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[LLM] First commit the llm deployment code
352
custom_ops/gpu_ops/cutlass_extensions/arch/copy_red_global.hpp
Normal file
@@ -0,0 +1,352 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <cute/config.hpp>

#include <cute/arch/util.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/numeric/numeric_types.hpp>

// Config

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10))
#define CUTE_ARCH_RED_F16_SM70_ENABLED
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
#define CUTE_ARCH_RED_VEC_SM90_ENABLED
#define CUTE_ARCH_RED_BF16_SM90_ENABLED
#endif

namespace cute
{

//////////////////////////////////
// Wrapper around CUDA's atomicAdd
//////////////////////////////////

template <class T>
struct TypedAtomicAdd
{
    using SRegisters = T[1];
    using DRegisters = T[1];

    CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst)
    {
        atomicAdd(&dst, src);
    }
};

template <class T>
struct Copy_Traits<TypedAtomicAdd<T>>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;

    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;

    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////
// F16 ADD PTX
//////////////////////////////////

struct SM70_RED_ADD_NOFTZ_F16
{
    using SRegisters = uint16_t[1];
    using DRegisters = uint16_t[1];

    CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
        asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;

    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _16>>;

    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _16>>;

    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

struct SM70_RED_ADD_NOFTZ_F16x2
{
    using SRegisters = uint32_t[1];
    using DRegisters = uint32_t[1];

    CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
        asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;

    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _32>>;

    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _32>>;

    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

struct SM90_RED_ADD_NOFTZ_F16x2_V2
{
    using SRegisters = uint32_t[2];
    using DRegisters = uint64_t[1];

    CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
        asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;

    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _64>>;

    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _64>>;

    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

struct SM90_RED_ADD_NOFTZ_F16x2_V4
{
    using SRegisters = uint32_t[4];
    using DRegisters = uint128_t[1];

    CUTE_HOST_DEVICE static void copy(
        uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
        asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
            "r"(src2), "r"(src3));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;

    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _128>>;

    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _128>>;

    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////
// BF16 ADD PTX
//////////////////////////////////

struct SM90_RED_ADD_NOFTZ_BF16
{
    using SRegisters = uint16_t[1];
    using DRegisters = uint16_t[1];

    CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
        asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;

    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _16>>;

    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _16>>;

    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////

struct SM90_RED_ADD_NOFTZ_BF16x2
{
    using SRegisters = uint32_t[1];
    using DRegisters = uint32_t[1];

    CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
        asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;

    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _32>>;

    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _32>>;

    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////

struct SM90_RED_ADD_NOFTZ_BF16x2_V2
{
    using SRegisters = uint32_t[2];
    using DRegisters = uint64_t[1];

    CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
        asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;

    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _64>>;

    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _64>>;

    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////

struct SM90_RED_ADD_NOFTZ_BF16x2_V4
{
    using SRegisters = uint32_t[4];
    using DRegisters = uint128_t[1];

    CUTE_HOST_DEVICE static void copy(
        uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
        asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
            "r"(src2), "r"(src3));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;

    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _128>>;

    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _128>>;

    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////

} // end namespace cute
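Editorial aside (not part of the commit): for readers unfamiliar with CuTe copy atoms, the sketch below shows how one of these reduction instructions could be wired up. The function and tensor names are illustrative assumptions, not code from this repository.

// Illustrative sketch: one thread folds two packed bf16 values into a
// global accumulator through the SM90 reduction atom defined above.
#include <cute/tensor.hpp>

__device__ void reduce_two_bf16(uint32_t packed_src, uint32_t* gmem_dst)
{
    using namespace cute;
    // The Copy_Traits above declare a single thread moving 32 bits per instruction.
    auto atom = Copy_Atom<SM90_RED_ADD_NOFTZ_BF16x2, uint32_t>{};
    auto src  = make_tensor(make_rmem_ptr(&packed_src), Layout<_1>{});
    auto dst  = make_tensor(make_gmem_ptr(gmem_dst), Layout<_1>{});
    copy(atom, src, dst); // emits red.global.add.noftz.bf16x2
}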
120
custom_ops/gpu_ops/cutlass_extensions/arch/mma.h
Normal file
@@ -0,0 +1,120 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Templates exposing architecture support for multiply-add operations
*/

#pragma once
#include "cutlass_extensions/weight_only_quant_op.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace arch
{

// Tag which triggers an MMA that dequantizes interleaved operand B to the type of A before the multiply-add
struct OpMultiplyAddDequantizeInterleavedBToA;

/*
  Below we have extra tags to signal what kind of dequantization we want to do
  (per col, scale only fine grained, finegrained with zero). This still lets us use
  the existing template infrastructure (incl. that in CUTLASS). However, we
  split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
  with the quantization op before instantiating the GEMM pieces.

  Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
  code we need to duplicate.
*/
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;

// The default just forwards the original operator
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
struct TagOperator
{
    using TaggedOperator = MmaOp;
};

// Specializations below attach more information to the operator
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
{
    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
};

template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
{
    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
};

template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
{
    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
};

// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original
// operator + the extra information. If no extra info was tagged, the dequant op defaults to
// per-column scaling.
template <typename TaggedMmaOp>
struct DetagOperator
{
    using Operator = TaggedMmaOp;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
{
    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
{
    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
{
    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
};

} // namespace arch
} // namespace cutlass
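Editorial aside (not part of the commit): a minimal sketch of the tag/detag round trip these traits provide, assuming WeightOnlyQuantOp lives in namespace cutlass, as the unqualified uses above suggest.

// Illustrative only: tagging attaches the quantization mode to the MMA op;
// detagging recovers both pieces at compile time.
#include <type_traits>

using Tagged = cutlass::arch::TagOperator<
    cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA,
    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>::TaggedOperator;

using Detagged = cutlass::arch::DetagOperator<Tagged>;

static_assert(std::is_same_v<Detagged::Operator,
                  cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>,
    "Detagging recovers the original MMA operator");
static_assert(Detagged::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY,
    "...and the quantization mode that was tagged on");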
85
custom_ops/gpu_ops/cutlass_extensions/compute_occupancy.h
Normal file
@@ -0,0 +1,85 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cuda_runtime_api.h>

#include "cutlass/device_kernel.h"
#include "common/cudaUtils.h"

namespace cutlass_extensions
{

template <typename GemmKernel, bool enable_cutlass_3x = false>
inline int compute_occupancy_for_kernel()
{
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    if (smem_size > (48 << 10))
    {
        cudaFuncAttributes attr;
        int device = 0;
        int max_smem_per_block = 0;
        PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&device));
        PADDLE_ENFORCE_GPU_SUCCESS(
            cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
        if constexpr (enable_cutlass_3x)
        {
            PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
        }
        else
        {
            PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
        }
        if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
        {
            // This should mean that
            // cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
            // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
            // configuration.
            return 0;
        }

        if constexpr (enable_cutlass_3x)
        {
            PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncSetAttribute(
                cutlass::device_kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        else
        {
            PADDLE_ENFORCE_GPU_SUCCESS(cudaFuncSetAttribute(
                cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
    }

    int max_active_blocks = -1;
    if constexpr (enable_cutlass_3x)
    {
        PADDLE_ENFORCE_GPU_SUCCESS(
            cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel<GemmKernel>,
                128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size));
    }
    else
    {
        PADDLE_ENFORCE_GPU_SUCCESS(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
    }

    return max_active_blocks;
}

} // namespace cutlass_extensions
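Editorial aside (not part of the commit): a hedged usage sketch of this helper; GemmKernel stands in for a concrete CUTLASS 2.x kernel instantiation (pass enable_cutlass_3x = true for a 3.x kernel so the query targets cutlass::device_kernel instead).

// Illustrative only: rank a candidate tile configuration by occupancy.
template <typename GemmKernel>
bool config_is_viable()
{
    int occupancy = cutlass_extensions::compute_occupancy_for_kernel<GemmKernel>();
    // Zero means this tile's shared-memory demand cannot fit on the device;
    // the tactic-selection heuristic should skip the config, not launch it.
    return occupancy > 0;
}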
@@ -0,0 +1,105 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Fused activation functors used by epilogues (optimized tanh-based GELU).
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace epilogue
{
namespace thread
{

/////////////////////////////////////////////////////////////////////////////////////////////////

__forceinline__ __device__ float copysignf_pos(float a, float b)
{
    float r;
    r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
    return r;
}

__forceinline__ __device__ float tanh_opt(float x)
{
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
    float const exp_val = -1.f * fabs(2 * x);
    return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else
    return fast_tanh(x);
#endif
}

/////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct GELU_taylor<float>
{
    static bool const kIsHeavy = true;

    CUTLASS_DEVICE
    float operator()(float const& z) const
    {
        float k0 = float(0.7978845608028654);
        float k1 = float(0.044715);

        return float(cutlass::constants::half<float>() * z
            * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
    }

    using Params = LinearCombinationGenericParams<float>;

    CUTLASS_DEVICE
    float operator()(float const& scalar, Params const& params_) const
    {
        return this->operator()(scalar);
    }
};

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
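Editorial aside (not part of the commit): the specialization above evaluates the standard tanh approximation of GELU,

$$\mathrm{GELU}(z) \approx \frac{z}{2}\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\big(z + 0.044715\,z^{3}\big)\right)\right),$$

with $k_0 = \sqrt{2/\pi} \approx 0.7978845608$ and $k_1 = 0.044715$. `tanh_opt` routes to CUTLASS's `fast_tanh` on newer targets and otherwise falls back to the `__expf`-based identity $\tanh(x) = \operatorname{sign}(x)\,\dfrac{1 - e^{-2|x|}}{1 + e^{-2|x|}}$, which is exactly what the `copysignf_pos` branch computes.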
@@ -0,0 +1,350 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.

  original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/numeric_conversion.h"
#include "common/quantization.h"

namespace cutlass
{
namespace epilogue
{
namespace threadblock
{

template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
    typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol
{
public:
    using ThreadblockShape = ThreadblockShape_;
    static int const kThreadCount = ThreadCount;

    using ScaleTileIterator = ScaleTileIterator_;
    using OutputTileIterator = OutputTileIterator_;
    using ElementwiseFunctor = ElementwiseFunctor_;

    static int const kIterations = OutputTileIterator::kIterations;
    static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

    using ElementOutput = typename OutputTileIterator::Element;
    using LayoutOutput = cutlass::layout::RowMajor;
    using ElementAccumulator = ElementAccumulator_;

    using AlphaScaleElementType = typename ScaleTileIterator::Element;

    using ElementCompute = ElementCompute_;
    using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
    using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
    using OutputVector = Array<ElementOutput, kElementsPerAccess>;

    static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
    static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);

    /// Argument structure
    struct Arguments
    {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        Arguments()
            : batch_stride_alpha(0)
            , batch_stride_C(0)
            , batch_stride_D(0)
        {
        }

        Arguments(typename ElementwiseFunctor::Params elementwise_)
            : elementwise(elementwise_)
            , batch_stride_alpha(0)
            , batch_stride_C(0)
            , batch_stride_D(0)
        {
        }

        Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_,
            int64_t batch_stride_C_, int64_t batch_stride_D_)
            : elementwise(elementwise_)
            , batch_stride_alpha(batch_stride_alpha_)
            , batch_stride_C(batch_stride_C_)
            , batch_stride_D(batch_stride_D_)
        {
        }
    };

    struct Params
    {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        CUTLASS_HOST_DEVICE
        Params() {}

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args)
            : elementwise(args.elementwise)
            , batch_stride_alpha(args.batch_stride_alpha)
            , batch_stride_C(args.batch_stride_C)
            , batch_stride_D(args.batch_stride_D)
        {
        }
    };

    /// Shared storage
    struct SharedStorage
    {
    };

private:
    Params const& params_;
    SharedStorage& shared_storage_;
    MatrixCoord extent_;
    MatrixCoord extent_real_;
    ElementwiseFunctor elementwise_;

    bool const per_token_quant_;
    bool const per_channel_quant_;

    AlphaScaleElementType* ptr_alpha_row_;
    AlphaScaleElementType* ptr_alpha_col_;
    ScaleTileIterator iterator_alpha_col_;
    OutputTileIterator iterator_C_;
    OutputTileIterator iterator_D_;

    AlphaScaleElementType element_alpha_row_ = 1.0f;
    AlphaScaleElementType element_alpha_col_ = 1.0f;
    typename ScaleTileIterator::Fragment fragment_alpha_col_;
    typename OutputTileIterator::Fragment fragment_C_;
    typename OutputTileIterator::Fragment fragment_D_;

    ElementAccumulator beta_;

    int column_offset_;

    MatrixCoord thread_offset_;

public:
    CUTLASS_DEVICE
    EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
        cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
        typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C,
        typename OutputTileIterator::Params params_D, common::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row,
        AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
        typename OutputTileIterator::Element* ptr_D,
        cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0,
        cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
        : params_(params)
        , shared_storage_(shared_storage)
        , extent_(problem_size)
        , elementwise_(params.elementwise)
        , per_token_quant_(quant_option.hasPerTokenScaling())
        , per_channel_quant_(quant_option.hasPerChannelScaling())
        , ptr_alpha_row_(ptr_alpha_row)
        , ptr_alpha_col_(ptr_alpha_col)
        , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset)
        , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset)
        , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset)
        , extent_real_(problem_size_real)
    {
        beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);

        if (beta_ == ElementAccumulator())
        {
            iterator_C_.clear_mask();
        }

        if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr))
        {
            element_alpha_col_ = *ptr_alpha_col_;
        }

        if (!per_token_quant_ && (ptr_alpha_row_ != nullptr))
        {
            element_alpha_row_ = *ptr_alpha_row_;
        }
    }

    /// Helper to indicate split-K behavior
    CUTLASS_DEVICE
    void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
        int split_k_slices)                 ///< Total number of split-K slices
    {
    }

    /// Called to set the batch index
    CUTLASS_DEVICE
    void set_batch_index(int batch_idx)
    {
        iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
        iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
        iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
    }

    /// Called at the start of the epilogue just before iterating over accumulator slices
    CUTLASS_DEVICE
    void begin_epilogue()
    {
        if (per_channel_quant_)
        {
            iterator_alpha_col_.load(fragment_alpha_col_);
        }
    }

    /// Called at the start of one step before starting accumulator exchange
    CUTLASS_DEVICE
    void begin_step(int step_idx)
    {
        fragment_D_.clear();
        fragment_C_.clear();

        if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling)
        {
            iterator_C_.load(fragment_C_);
            ++iterator_C_;
        }
    }

    /// Called at the start of a row
    CUTLASS_DEVICE
    void begin_row(int row_idx)
    {
        // Load alpha_row in begin_row only when per-token (row) scaling is used
        if (per_token_quant_)
        {
            int thread_offset_row
                = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();

            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
                element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
        }
    }

    /// Called after accumulators have been exchanged for each accumulator vector
    CUTLASS_DEVICE
    void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
    {
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;

        ComputeFragment result = source_converter(accum);
        if (per_channel_quant_)
        {
            ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
            result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
        }
        else
        {
            result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
        }

        // Convert to the output
        NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
        OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
        output = output_converter(result);
    }

    /// Called at the end of a row
    CUTLASS_DEVICE
    void end_row(int row_idx) {}

    /// Called after all accumulator elements have been visited
    CUTLASS_DEVICE
    void end_step(int step_idx)
    {
        iterator_D_.store(fragment_D_);
        ++iterator_D_;
    }

    /// Called after all steps have been completed
    CUTLASS_DEVICE
    void end_epilogue() {}

private:
    CUTLASS_DEVICE
    ComputeFragment per_token_channel_scale_accumulator_(
        ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row)
    {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i)
        {
            result[i] = accum[i] * (scale_col[i] * scale_row);
        }

        return result;
    }

    CUTLASS_DEVICE
    ComputeFragment per_token_scale_accumulator_(
        ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row)
    {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i)
        {
            result[i] = accum[i] * (scale_col * scale_row);
        }

        return result;
    }
};

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
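Editorial aside (not part of the commit): in equation form, with $A_{ij}$ the INT32 accumulator tile, $\alpha^{\mathrm{row}}_{i}$ the per-token (row) scale, and $\alpha^{\mathrm{col}}_{j}$ the per-channel (column) scale, the visitor's dequantization step computes

$$D_{ij} = \operatorname{convert}\!\big(\alpha^{\mathrm{row}}_{i}\,\alpha^{\mathrm{col}}_{j}\,A_{ij}\big),$$

and when a dimension is quantized per-tensor, the corresponding $\alpha$ collapses to the single scalar loaded once in the constructor.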
@@ -0,0 +1,282 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory to match canonical
  tensor layouts in global memory. Epilogues support conversion and reduction operations.

  original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cutlass/platform/platform.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"

#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"

#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"

#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"

#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"

#include "cutlass/layout/permute.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace epilogue
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

namespace detail
{

/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
    ThreadMap>
{
    using WarpTileIterator
        = cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;

    using SharedLoadIterator
        = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;

    static int const kFragmentsPerIteration = 2;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
    >
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8>
{
public:
    using ThreadMap = ThreadMap_;
    using Shape = typename ThreadMap::Shape;

    using Element = int32_t;

    using Layout = layout::RowMajor;
    using TensorRef = TensorRef<Element, Layout>;
    using ConstTensorRef = typename TensorRef::ConstTensorRef;

    using Index = typename Layout::Index;
    using LongIndex = typename Layout::LongIndex;
    using TensorCoord = MatrixCoord;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;

    static int const kThreads = ThreadMap::kThreads;

    /// Fragment object
    using Fragment = Array<Element,
        ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
            * ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;

    /// Memory access size
    using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;

    /// Vector type used for SMEM loads
    using LoadType = AlignedArray<Element, const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
        const_min(16, kAlignment)>;

    static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;

private:
    //
    // Data members
    //

    /// Byte-level pointer
    LoadType const* pointers_[kLoadsPerAccess];

    /// Stride along adjacent rows in units of LoadType
    int stride_;

public:
    //
    // Methods
    //

    /// Constructor
    CUTLASS_DEVICE
    SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
        : stride_((ref.stride(0) / LoadType::kElements))
    {
        TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);

        // Initialize pointers
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kLoadsPerAccess; ++i)
        {
            pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());

            int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
            int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;

            col_idx += (bank_offset + i) % kLoadsPerAccess;

            pointers_[i] += thread_offset.row() * stride_ + col_idx;
        }
    }

    /// Adds a pointer offset in units of Element
    CUTLASS_HOST_DEVICE
    void add_pointer_offset(LongIndex pointer_offset)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kLoadsPerAccess; ++i)
        {
            pointers_[i] += pointer_offset / LoadType::kElements;
        }
    }

    CUTLASS_DEVICE
    void add_tile_offset(TensorCoord const& offset)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kLoadsPerAccess; ++i)
        {
            pointers_[i]
                += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
        }
    }

    /// Loads a fragment from memory
    CUTLASS_DEVICE
    void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
    {
        CUTLASS_PRAGMA_UNROLL
        for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster)
        {
            CUTLASS_PRAGMA_UNROLL
            for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group)
            {
                CUTLASS_PRAGMA_UNROLL
                for (int row = 0; row < ThreadMap::Iterations::kRow; ++row)
                {
                    int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_
                        + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_
                        + pointer_offset / LoadType::kElements;

                    int frag_row_idx
                        = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));

                    LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);

                    CUTLASS_PRAGMA_UNROLL
                    for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column)
                    {
                        int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;

                        CUTLASS_PRAGMA_UNROLL
                        for (int v = 0; v < kLoadsPerAccess; ++v)
                        {
                            int vector_idx
                                = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);

                            LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;

                            frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
                        }
                    }
                }
            }
        }
    }

    /// Loads a fragment
    CUTLASS_DEVICE
    void load(Fragment& frag) const
    {
        load_with_pointer_offset(frag, 0);
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
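Editorial aside (not part of the commit): a small host-side sketch of the bank-offset swizzle used in the constructor above. The parameter values (kElementsPerAccess = 8, a 16-byte LoadType, hence kLoadsPerAccess = 2) are plausible for this int32 specialization but are assumptions for illustration.

#include <cstdio>

// Illustrative only: replays the constructor's bank-offset arithmetic.
// Threads whose access falls in an odd 128-byte window have their two
// half-accesses rotated, so paired loads land in different bank groups.
int main()
{
    const int kLoadsPerAccess = 2;     // assumed: AccessType / LoadType elements
    const int kElementsPerAccess = 8;  // assumed ThreadMap value
    const int kLoadTypeBytes = 16;     // assumed: 4 x int32_t

    for (int thread_col = 0; thread_col < 64; thread_col += 8)
    {
        for (int i = 0; i < kLoadsPerAccess; ++i)
        {
            int col_idx = (thread_col / kElementsPerAccess) * kLoadsPerAccess;
            int bank_offset = (col_idx * kLoadTypeBytes / 128) % kLoadsPerAccess;
            col_idx += (bank_offset + i) % kLoadsPerAccess;
            printf("thread_col=%2d load=%d -> col_idx=%d\n", thread_col, i, col_idx);
        }
    }
    return 0;
}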
139
custom_ops/gpu_ops/cutlass_extensions/epilogue_helpers.h
Normal file
@@ -0,0 +1,139 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
* @file epilogue_helpers.h
|
||||
*
|
||||
* This file includes types for the epilogues. The empty structs exist so we can signal to template
|
||||
* code the type of epilogue we want to run, and let the underlying code specify the details such as
|
||||
* element types, accumulator type and elements per vector access.
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_silu.h"
|
||||
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
|
||||
|
||||
// #include "cutlass/epilogue/fusion/operations.hpp"
|
||||
|
||||
namespace cutlass_extensions
|
||||
{
|
||||
|
||||
struct EpilogueOpBiasSilu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBiasReLU
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBiasFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpBias
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultSilu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultReLU
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefaultFtGelu
|
||||
{
|
||||
};
|
||||
|
||||
struct EpilogueOpDefault
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
|
||||
struct Epilogue
|
||||
{
|
||||
    static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag");
};

constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu>
{
    using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
        ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU>
{
    using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
        ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu>
{
    using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
        ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode,
        cutlass::FloatRoundStyle::round_to_nearest, true>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias>
{
    using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
        ElementAccumulator, BiasScaleMode>;
};

constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu>
{
    using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
        ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU>
{
    using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
        ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
{
    using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
        ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
        cutlass::FloatRoundStyle::round_to_nearest, true>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault>
{
    using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
        ElementAccumulator, DefaultScaleMode>;
};
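
// Illustrative sketch (not part of the original file): a GEMM launcher resolves
// these traits at compile time to select the fused epilogue functor. The element
// type and vector width below are assumptions for the example only.
//
//   using BiasEpilogue =
//       typename Epilogue<cutlass::half_t, /*ElementsPerVectorAccess=*/8, float, EpilogueOpBias>::Op;
//   // BiasEpilogue is LinearCombination<half_t, 8, float, float, NoBetaScaling>,
//   // i.e. D = alpha * accumulator + bias, with no beta scaling of the source.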

} // namespace cutlass_extensions
@@ -0,0 +1,138 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Adapted from
// https://github.com/vllm-project/vllm/blob/v0.7.2/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp

#pragma once

#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"

#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  int ScaleGranularityM
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<
      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
                                                                   KernelPtrArrayTmaWarpSpecializedCooperative,
                                                                   KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
                "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with the FP8 block-scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
                                                          KernelTmaWarpSpecializedCooperative,
                                                          KernelPtrArrayTmaWarpSpecializedCooperative,
                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative,
      Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
  >;
};
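
// Illustrative usage sketch (not part of the original file): element types,
// shapes, alignments, and the schedule's enclosing namespace are assumptions
// made only for this example.
//
//   using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
//       cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
//       cutlass::float_e4m3_t, cutlass::layout::RowMajor,    16,
//       cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,
//       float,
//       cute::Shape<cute::_128, cute::_128, cute::_128>,   // TileShape_MNK
//       cute::Shape<cute::_1, cute::_1, cute::_1>,         // ClusterShape_MNK
//       cutlass::gemm::collective::StageCountAuto,
//       cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<1>
//   >::CollectiveOp;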

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,196 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp

/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cute/algorithm/clear.hpp"
#include "cute/tensor.hpp"

//////////////////////////////////////////////////////////////////////////////
///////////////////////////////////FP8 Accumulation///////////////////////////
//////////////////////////////////////////////////////////////////////////////
/// This class provides an API to promote (add) or scale (multiply_add) the
/// results from the tensor core accumulators into the main accumulators once
/// the number of issued MMAs reaches the user-specified promotion interval;
/// after that the tensor core accumulators are zeroed.
//////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

template <
    class EngineAccum,
    class LayoutAccum>
struct GmmaFP8AccumulationWithScale {
  using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
  using ElementAccumulator = typename EngineAccum::value_type;

  static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static");
  static_assert(is_rmem<TensorAccum>::value, "Accumulator tensor must be rmem resident.");

private:
  TensorAccum& accum_;
  TensorAccum accum_temp_;

  uint32_t accum_promotion_interval_;          // defines the max num of executed MMAs after which accum should be promoted.
  uint32_t mma_count_per_mainloop_iteration_;  // num of MMAs per k_tile of mainloop
  uint32_t mma_count_;                         // current executed MMAs
  uint32_t reset_accum_flag_;                  // accum needs to be zeroed or not.

  // promote or `add` the partial accumulators to main accumulator (FADD).
  CUTLASS_DEVICE
  void promote_core() {
    warpgroup_wait<0>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(accum_); ++i) {
      accum_(i) += accum_temp_(i);
    }
  }

  // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
  template <
      class EngineScale,
      class LayoutScale>
  CUTLASS_DEVICE
  void scale_core(const cute::Tensor<EngineScale, LayoutScale>& scale) {
    using TensorScale = cute::Tensor<EngineScale, LayoutScale>;

    static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
    static_assert(is_rmem<TensorScale>::value, "Scale tensor must be rmem resident.");

    static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");

    warpgroup_wait<0>();
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size(accum_); ++i) {
      accum_(i) += accum_temp_(i) * scale(i);
    }
  }

public:
  CUTLASS_DEVICE
  GmmaFP8AccumulationWithScale(
      TensorAccum& accum,
      uint32_t accum_promotion_interval,
      uint32_t mma_count_per_mainloop_iteration)
      : accum_(accum),
        accum_promotion_interval_(accum_promotion_interval),
        mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
        mma_count_(0),
        reset_accum_flag_(0)
  {
    accum_temp_ = cute::make_fragment_like(accum);
  }

  //
  // Methods (Common)
  //

  CUTLASS_DEVICE
  TensorAccum& operator()() {
    return accum_temp_;
  }

  /// prepare the MMA accumulators when initialization or zeroing is required.
  CUTLASS_DEVICE
  bool prepare_if_needed() {
    return reset_accum_flag_;
  }

  //
  // Methods (for FADD version)
  //

  /// promote (add) the results from the MMA accumulators to main accumulator if needed.
  CUTLASS_DEVICE
  void promote_if_needed() {
    mma_count_ += mma_count_per_mainloop_iteration_;
    // lane 0 decides and broadcasts the flag so the whole warp agrees.
    reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
    if (reset_accum_flag_) {
      promote_core();
      mma_count_ = 0;
    }
  }

  /// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
  CUTLASS_DEVICE
  void promote_residue_if_needed() {
    if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
      promote_core();
    }
  }

  //
  // Methods (for FFMA version)
  //

  /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
  template <
      class EngineScale,
      class LayoutScale>
  CUTLASS_DEVICE
  void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale>& scale) {
    mma_count_ += mma_count_per_mainloop_iteration_;
    // lane 0 decides and broadcasts the flag so the whole warp agrees.
    reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
    if (reset_accum_flag_) {
      scale_core(scale);
      mma_count_ = 0;
    }
  }

  /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
  template <
      class EngineScale,
      class LayoutScale>
  CUTLASS_DEVICE
  void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale>& scale) {
    if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
      scale_core(scale);
    }
  }
};
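
// Illustrative sketch (not part of the original file; names assumed): a
// warp-specialized consumer typically drives this helper as below. The real
// usage is in the blockwise-scaling mainloop later in this commit.
//
//   GmmaFP8AccumulationWithScale accumulation(accum, promotion_interval, mmas_per_k_tile);
//   for (/* each k tile */) {
//     if (accumulation.prepare_if_needed()) { /* zero the temp accumulators */ }
//     cute::gemm(tiled_mma, tCrA, tCrB, accumulation());  // MMA into the temp fragment
//     accumulation.scale_if_needed(scale_frag);           // FFMA into the main accumulator
//   }
//   accumulation.scale_residue_if_needed(scale_frag);     // flush leftover partial sums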

} // namespace cutlass::gemm::collective
@@ -0,0 +1,744 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Adapted from https://github.com/vllm-project/vllm/blob/v0.7.2/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
// Adapted (heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp

/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/trace.h"
#include "cutlass/numeric_types.h"

#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm80.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {
using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////

// WarpSpecialized Mainloop
template <
  int Stages,
  class ClusterShape,
  class KernelSchedule,
  int ScaleGranularityM_,
  class TileShape_,
  class ElementA_,
  class StrideA_,
  class ElementB_,
  class StrideB_,
  class TiledMma_,
  class GmemTiledCopyA_,
  class SmemLayoutAtomA_,
  class SmemCopyAtomA_,
  class TransformA_,
  class GmemTiledCopyB_,
  class SmemLayoutAtomB_,
  class SmemCopyAtomB_,
  class TransformB_>
struct CollectiveMma<
    MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
    TileShape_,
    ElementA_,
    StrideA_,
    ElementB_,
    StrideB_,
    TiledMma_,
    GmemTiledCopyA_,
    SmemLayoutAtomA_,
    SmemCopyAtomA_,
    TransformA_,
    GmemTiledCopyB_,
    SmemLayoutAtomB_,
    SmemCopyAtomB_,
    TransformB_>
{
  //
  // Type Aliases
  //
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
  using TileShape = TileShape_;
  using ElementA = ElementA_;
  using StrideA = StrideA_;
  using ElementB = ElementB_;
  using StrideB = StrideB_;
  using TiledMma = TiledMma_;
  using ElementAccumulator = typename TiledMma::ValTypeC;
  using ElementBlockScale = ElementAccumulator;
  using GmemTiledCopyA = GmemTiledCopyA_;
  using GmemTiledCopyB = GmemTiledCopyB_;
  using SmemLayoutAtomA = SmemLayoutAtomA_;
  using SmemLayoutAtomB = SmemLayoutAtomB_;
  using SmemCopyAtomA = SmemCopyAtomA_;
  using SmemCopyAtomB = SmemCopyAtomB_;
  using TransformA = TransformA_;
  using TransformB = TransformB_;
  using ArchTag = typename DispatchPolicy::ArchTag;

  using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
  using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
  using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
  using PipelineParams = typename MainloopPipeline::Params;

  // Producer thread events per CTA: one thread issues the TMA loads for the
  // operand tiles and one warp (32 threads) copies the scales, 33 in total.
  static constexpr int NumProducerThreadEvents = 33;

  static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
  static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
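
  // Worked example (illustrative, not in the original file): with a 128x128x128
  // TileShape and ScaleGranularityM_ == 64, ScaleMsPerTile == 128 / 64 == 2,
  // i.e. two scale values cover the M extent of each tile. ScaleGranularityM_ == 0
  // degenerates to one scale per tile (ScaleMsPerTile == 1).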

  static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

  static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
  static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
  static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

  static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");

  // Tile along modes in a way that maximizes the TMA box size.
  using SmemLayoutA = decltype(tile_to_shape(
      SmemLayoutAtomA{},
      make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
  using SmemLayoutB = decltype(tile_to_shape(
      SmemLayoutAtomB{},
      make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
      cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));

  // Block scaling gmem-to-smem copy atom
  using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
  using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;

  // Block scaling smem layout
  using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
  using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.

  static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more.");
  static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
                cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
                "MMA atom must source both A and B operand from smem_desc for this mainloop.");
  static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
                "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
  static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
                "ElementAccumulator and ElementBlockScale should be the same datatype");

  struct SharedStorage
  {
    struct TensorStorage : cute::aligned_struct<128> {
      cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;       // mxk
      cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;       // nxk
      cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A;      // ScaleMsPerTile x k
      cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B;      // 1xk
    } tensors;

    using PipelineStorage = typename MainloopPipeline::SharedStorage;
    PipelineStorage pipeline;
  };
  using TensorStorage = typename SharedStorage::TensorStorage;
  using PipelineStorage = typename SharedStorage::PipelineStorage;

  // Host side kernel arguments
  struct Arguments {
    ElementA const* ptr_A;
    StrideA dA;
    ElementB const* ptr_B;
    StrideB dB;
    ElementBlockScale const* ptr_scale_A;
    ElementBlockScale const* ptr_scale_B;
  };

  // Device side kernel params
  struct Params {
    // Assumption: StrideA is congruent with Problem_MK
    using TMA_A = decltype(make_tma_copy_A_sm90(
        GmemTiledCopyA{},
        make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
        SmemLayoutA{}(_,_,0),
        TileShape{},
        ClusterShape{}));
    // Assumption: StrideB is congruent with Problem_NK
    using TMA_B = decltype(make_tma_copy_B_sm90(
        GmemTiledCopyB{},
        make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
        SmemLayoutB{}(_,_,0),
        TileShape{},
        ClusterShape{}));
    TMA_A tma_load_a;
    TMA_B tma_load_b;
    uint32_t tma_transaction_bytes = TmaTransactionBytes;
    uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
    uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
    // Block scaling factors for A and B
    ElementBlockScale const* ptr_scale_A;
    ElementBlockScale const* ptr_scale_B;
  };

  //
  // Methods
  //

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    (void) workspace;

    // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
    auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);

    Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
    Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
    typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
        GmemTiledCopyA{},
        tensor_a,
        SmemLayoutA{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{});
    typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
        GmemTiledCopyB{},
        tensor_b,
        SmemLayoutB{}(_,_,cute::Int<0>{}),
        TileShape{},
        ClusterShape{});
    uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
    uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
    uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;

    return {
      tma_load_a,
      tma_load_b,
      transaction_bytes,
      transaction_bytes_mk,
      transaction_bytes_nk,
      args.ptr_scale_A,
      args.ptr_scale_B
    };
  }

  template<class ProblemShape>
  static bool
  can_implement(
      ProblemShape const& problem_shape,
      [[maybe_unused]] Arguments const& args) {
    constexpr int tma_alignment_bits = 128;
    auto problem_shape_MNKL = append<4>(problem_shape, 1);
    auto [M,N,K,L] = problem_shape_MNKL;

    bool implementable = true;
    constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
    constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
    implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});

    if (!implementable) {
      CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
    }
    return implementable;
  }

  static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
  static constexpr int K_PIPE_MMAS = 1;
  static constexpr uint32_t TmaTransactionBytesMK =
      cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
  static constexpr uint32_t TmaTransactionBytesNK =
      cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
  static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
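
  // Worked example (illustrative, not in the original file): for a 128x128x128
  // tile with 8-bit (FP8) operands, each pipeline stage moves
  //   TmaTransactionBytesMK = 128 * 128 * 8 / 8 = 16384 bytes of A and
  //   TmaTransactionBytesNK = 128 * 128 * 8 / 8 = 16384 bytes of B,
  // so the TMA barrier expects TmaTransactionBytes = 32768 bytes per stage.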

  /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& mainloop_params)
  {
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
  }

  /// Set up the data needed by this collective for load and mma.
  /// Returns a tuple of tensors. The collective and the kernel layer have the contract that the
  /// returned tuple must contain at least two elements, with the first two elements being:
  /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
  /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
  template <class ProblemShape_MNKL>
  CUTLASS_DEVICE auto
  load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
    using X = Underscore;
    // Separate out problem shape for convenience
    auto [M,N,K,L] = problem_shape_MNKL;

    // TMA requires special handling of strides to deal with coord codomain mapping
    // Represent the full tensors -- get these from TMA
    Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L));  // (m,k,l)
    Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L));  // (n,k,l)

    // Make tiled views, defer the slice
    Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{});  // (BLK_M,BLK_K,m,k,l)
    Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{});  // (BLK_N,BLK_K,n,k,l)

    constexpr auto scales_m = Int<ScaleMsPerTile>{};
    auto tM = get<2>(gA_mkl.shape());
    auto tN = get<2>(gB_nkl.shape());
    auto tK = get<3>(gA_mkl.shape());

    // Make the tiled views of scale tensors
    auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L);            // (scale_m,k,l)
    auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
    auto scaleB_shape = make_shape(tN, tK, L);                               // (n,k,l)
    auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});

    // Note that mScaleA_mkl and mScaleB_nkl are already block-tiled along `m` on the host,
    // so the global-memory (`g`) views are the same tensors as mScaleA_mkl and mScaleB_nkl.
    Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout);  // (scale_m,k,l)
    Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout);  // (n,k,l)

    return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
  }
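
  // Worked example (illustrative, not in the original file): for M = 1024,
  // K = 4096, L = 1, a 128x128x128 tile and ScaleGranularityM = 128, the A-scale
  // view has shape (1024 / 128, 4096 / 128, 1) = (8, 32, 1), one scale per
  // 128-row group of A per k-tile, while the B-scale view holds one scale per
  // (n-tile, k-tile) pair.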

  /// Perform a collective-scoped matrix multiply-accumulate
  /// Producer Perspective
  template <
    class TensorA, class TensorB,
    class TensorScaleA, class TensorScaleB,
    class KTileIterator, class BlockCoord
  >
  CUTLASS_DEVICE void
  load(
      Params const& mainloop_params,
      MainloopPipeline pipeline,
      PipelineState smem_pipe_write,
      cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
      BlockCoord const& blk_coord,
      KTileIterator k_tile_iter, int k_tile_count,
      int thread_idx,
      uint32_t block_rank_in_cluster,
      TensorStorage& shared_tensors) {
    int lane_predicate = cute::elect_one_sync();

    // Block scaling: TMA loads for the operand tiles and cp.async for the scales
    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{});                   // (BLK_M,BLK_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{});                   // (BLK_N,BLK_K,PIPE)
    Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{});  // (ScaleMsPerTile,k)
    Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{});  // (k)

    //
    // Prepare the TMA loads for A and B
    //

    constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
    uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

    Tensor gA_mkl = get<0>(load_inputs);
    Tensor gB_nkl = get<1>(load_inputs);

    auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
    auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);

    // Partition the inputs based on the current block coordinates.
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
    Tensor gA = gA_mkl(_,_,m_coord,_,l_coord);  // (BLK_M,BLK_K,k)
    Tensor gB = gB_nkl(_,_,n_coord,_,l_coord);  // (BLK_N,BLK_K,k)

    // Block scaling: load_scale has scaling tensors in global memory which are not tiled
    Tensor mScaleA_mkl = get<2>(load_inputs);
    Tensor mScaleB_nkl = get<3>(load_inputs);
    auto scales_m = get<0>(mScaleA_mkl.shape());

    Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());

    Tensor gScaleA = local_tile(
        mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
        make_coord(m_coord,_,l_coord));  // (ScaleMsPerTile,k,1)
    Tensor cScaleA = local_tile(
        cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
        make_coord(m_coord,_,l_coord));
    Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord);  // (1,k,1)

    // TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
    TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
        Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{});  // (1,1,1)
    TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
        Layout<Shape<_1>>{}, Layout<Shape<_1>>{});           // (1,1,1)
    ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
    ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);

    Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
    Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
    Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);

    Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
    Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);

    // Applies the mapping from block_tma_a
    Tensor tAgA = block_tma_a.partition_S(gA);  // (TMA,TMA_M,TMA_K,k)
    Tensor tAsA = block_tma_a.partition_D(sA);  // (TMA,TMA_M,TMA_K,PIPE)

    Tensor tBgB = block_tma_b.partition_S(gB);  // (TMA,TMA_N,TMA_K,k)
    Tensor tBsB = block_tma_b.partition_D(sB);  // (TMA,TMA_N,TMA_K,PIPE)

    uint16_t mcast_mask_a = 0;
    uint16_t mcast_mask_b = 0;

    // Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
    // Maps the tile -> block, value
    if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};  // (m,n) -> block_id
      for (int n = 0; n < size<1>(block_layout); ++n) {
        mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
      }
    }

    if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
      auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{};  // (m,n) -> block_id
      for (int m = 0; m < size<0>(block_layout); ++m) {
        mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
      }
    }

    // Allocate predicate tensors for a_scales (we cannot guarantee that all
    // scales are valid because there may be partial tiles along M)
    Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
    #pragma unroll
    for (int i = 0; i < size(tApA_ScaleA); ++i) {
      tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
    }

    // Mainloop
    CUTLASS_PRAGMA_NO_UNROLL
    for ( ; k_tile_count > 0; --k_tile_count) {
      // LOCK smem_pipe_write for _writing_
      pipeline.producer_acquire(smem_pipe_write);

      //
      // Copy gmem to smem for *k_tile_iter
      //
      int write_stage = smem_pipe_write.index();
      using BarrierType = typename MainloopPipeline::ProducerBarrierType;
      BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

      // Copy operands A and B from global memory to shared memory
      if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
      if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));

      // Copy scale tensors from global memory to shared memory
      copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
      copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
      pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);

      ++k_tile_iter;

      // Advance smem_pipe_write
      ++smem_pipe_write;
    }
  }

  /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
  CUTLASS_DEVICE void
  load_tail(
      MainloopPipeline pipeline,
      PipelineState smem_pipe_write) {
    int lane_predicate = cute::elect_one_sync();

    // Issue the epilogue waits
    if (lane_predicate) {
      /* This helps avoid early exit of blocks in Cluster
       * Waits for all stages to either be released (all
       * Consumer UNLOCKs), or if the stage was never used
       * then would just be acquired since the phase was
       * still inverted from make_producer_start_state
       */
      pipeline.producer_tail(smem_pipe_write);
    }
  }

  /// Perform a collective-scoped matrix multiply-accumulate
  /// Consumer Perspective
  template <
    class FrgTensorC
  >
  CUTLASS_DEVICE void
  mma(MainloopPipeline pipeline,
      PipelineState smem_pipe_read,
      FrgTensorC& accum,
      int k_tile_count,
      int thread_idx,
      TensorStorage& shared_tensors,
      Params const& mainloop_params) {

    static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
    static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
    static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
    static_assert(cute::is_void_v<SmemCopyAtomA>,
        "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
    static_assert(cute::is_void_v<SmemCopyAtomB>,
        "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");

    Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{});  // (BLK_M,BLK_K,PIPE)
    Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{});  // (BLK_N,BLK_K,PIPE)

    // Block scaling
    Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
        Layout<
          Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
          Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
        >{});  // ((ScaleGranularityM,ScaleMsPerTile),n,k)
    Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{});  // (k)

    //
    // Define C accumulators and A/B partitioning
    //

    // Layout of warp group to thread mapping

    static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
                  stride<0>(typename TiledMma::BLayout{}) == 0 and
                  size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
                  size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
                  "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");

    constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
    Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
                                                  Int<NumThreadsPerWarpGroup>{});

    int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

    TiledMma tiled_mma;
    auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));

    // `thread_mma` above is correct when partitioning A and B, but not when partitioning C.
    Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC);  // (MMA,MMA_M,MMA_N,PIPE)

    Tensor tCsA = thread_mma.partition_A(sA);  // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCsB = thread_mma.partition_B(sB);  // (MMA,MMA_N,MMA_K,PIPE)

    // Allocate "fragments/descriptors"
    Tensor tCrA = thread_mma.make_fragment_A(tCsA);  // (MMA,MMA_M,MMA_K,PIPE)
    Tensor tCrB = thread_mma.make_fragment_B(tCsB);  // (MMA,MMA_N,MMA_K,PIPE)

    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum));               // M
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum));               // N
    CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));                // K
    CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB));                // PIPE
    CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA));  // PIPE
    CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB));  // PIPE

    //
    // PIPELINED MAIN LOOP
    //
    static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
        "ERROR : Incorrect number of MMAs in flight");

    // We release buffers to producer warps (DMA load) with some MMAs in flight
    PipelineState smem_pipe_release = smem_pipe_read;

    // Per block scale values for operand A and B

    using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout()));  // `make_layout_like` makes a compact layout.
    using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape()));  // an interface to traverse the underlying storage for the compact layout mentioned above

    Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{});  // (MMA,MMA_M,MMA_N)
    ElementBlockScale scale_b;

    // Prologue GMMAs
    int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);

    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

    GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA));
    warpgroup_fence_operand(accumulation());
    CUTLASS_PRAGMA_UNROLL
    for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
    {
      // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
      auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
      pipeline.consumer_wait(smem_pipe_read, barrier_token);

      if (accumulation.prepare_if_needed()) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
      }

      int read_stage = smem_pipe_read.index();

      // Load per block scale values from shared memory to registers.
      scale_b = sScaleB[read_stage];
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
        tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
      }
      if constexpr (ScaleMsPerTile == 1) {
        static_assert(size(RegLayoutScaleAEssential{}) == 1);
        // all lanes in the warp group hold the same `tCrScaleAViewAsC.data()[0]` when `ScaleMsPerTile == 1`.
        tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0);
      } else {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
          tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
        }
      }

      warpgroup_arrive();
      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        // (V,M,K) x (V,N,K) => (V,M,N)
        cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
      }
      warpgroup_commit_batch();

      // Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
      accumulation.scale_if_needed(tCrScaleAViewAsC);

      ++smem_pipe_read;
    }

    warpgroup_fence_operand(accumulation());
    // Mainloop GMMAs
    k_tile_count -= prologue_mma_count;

    CUTLASS_PRAGMA_NO_UNROLL
    for ( ; k_tile_count > 0; --k_tile_count)
    {
      // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
      auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
      pipeline.consumer_wait(smem_pipe_read, barrier_token);

      //
      // Compute on k_tile
      //

      int read_stage = smem_pipe_read.index();

      // Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
      scale_b = sScaleB[read_stage];
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
        tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
      }
      if constexpr (ScaleMsPerTile == 1) {
        static_assert(size(RegLayoutScaleAEssential{}) == 1);
        // all lanes in the warp group hold the same `tCrScaleAViewAsC.data()[0]` when `ScaleMsPerTile == 1`.
        tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0);
      } else {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
          tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
        }
      }

      if (accumulation.prepare_if_needed()) {
        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
      }

      warpgroup_fence_operand(accumulation());
      warpgroup_arrive();
      // Unroll the K mode manually to set scale D to 1
      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
        // (V,M,K) x (V,N,K) => (V,M,N)
        cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
        tiled_mma.accumulate_ = GMMA::ScaleOut::One;
      }
      warpgroup_commit_batch();

      /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
      warpgroup_wait<K_PIPE_MMAS>();
      warpgroup_fence_operand(accumulation());

      // Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
      accumulation.scale_if_needed(tCrScaleAViewAsC);

      pipeline.consumer_release(smem_pipe_release);  // UNLOCK smem_pipe_release, done _computing_ on it

      // Advance smem_pipe_read and smem_pipe_release
      ++smem_pipe_read;
      ++smem_pipe_release;
    }

    accumulation.scale_residue_if_needed(tCrScaleAViewAsC);

    warpgroup_fence_operand(accumulation());
  }

  /// Perform a Consumer Epilogue to release all buffers
  CUTLASS_DEVICE void
  mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
    // Prologue GMMAs
    int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
    k_tile_count -= prologue_mma_count;

    smem_pipe_release.advance(k_tile_count);

    // Wait on all GMMAs to complete
    warpgroup_wait<0>();

    for (int count = 0; count < prologue_mma_count; ++count) {
      pipeline.consumer_release(smem_pipe_release);  // UNLOCK smem_pipe_release, done _computing_ on it
      ++smem_pipe_release;
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,438 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
  \file
  \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
    batched array variants.
*/

#pragma once

// #include <limits>

#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"

#include "cutlass/trace.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace device
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/*
  This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
  It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
  and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.

  Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
  that feature at the moment.
*/

template <typename GemmKernel_>
class GemmUniversalBaseCompat
{
public:
    using GemmKernel = GemmKernel_;
    using ThreadblockShape = typename GemmKernel::Mma::Shape;

    using ElementA = typename GemmKernel::ElementA;
    using LayoutA = typename GemmKernel::LayoutA;
    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    static ComplexTransform const kTransformA = GemmKernel::kTransformA;

    using ElementB = typename GemmKernel::ElementB;
    using LayoutB = typename GemmKernel::LayoutB;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    static ComplexTransform const kTransformB = GemmKernel::kTransformB;

    using ElementC = typename GemmKernel::ElementC;
    using LayoutC = typename GemmKernel::LayoutC;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;

    using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;

    using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
    using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
    using Operator = typename GemmKernel::Operator;

    /// Argument structure
    using Arguments = typename GemmKernel::Arguments;

protected:
    /// Kernel parameters object
    typename GemmKernel::Params params_;

protected:
    /// Private helper to obtain the grid dimensions with fix-up for split-K
    static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args)
    {

        // Determine grid shape
        ThreadblockSwizzle threadblock_swizzle;

        grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
            args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);

        gemm_k_size = args.problem_size.k();

        if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
        {

            int const kAlignK
                = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

            gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

            if (gemm_k_size)
            {
                grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
            }
        }
    }
|
||||
|
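    // Worked example of the split-K fix-up above (illustrative numbers, not taken from this file):
    // for fp16 A and B, kAlignK = const_max(128 / 16, 128 / 16) = 8. With problem_size.k() = 4096
    // and batch_count = 3 (the serial split-K factor), ceil_div(4096, 3) = 1366 rounds up to
    // gemm_k_size = 1368, giving grid_tiled_shape.k() = ceil_div(4096, 1368) = 3 K-partitions.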
public:
    /// Constructs the GEMM.
    GemmUniversalBaseCompat() {}

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const& args)
    {
        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        ThreadblockSwizzle threadblock_swizzle;
        dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);

        if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax))
        {
            return Status::kErrorInvalidProblem;
        }

        return GemmKernel::can_implement(args);
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const& args)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");

        size_t workspace_bytes = 0;

        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        if (args.mode == GemmUniversalMode::kGemmSplitKParallel)
        {
            // Split-K parallel always requires a temporary workspace
            workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
        }
        else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
        {
            // Serial split-K only requires a temporary workspace if the number of partitions along the
            // GEMM K dimension is greater than one.
            workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
        }

        CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);

        workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);

        return workspace_bytes;
    }

    /// Computes the grid shape
    static dim3 get_grid_shape(Arguments const& args)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");

        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
        dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n"
            << " result = {" << result << "}");

        return result;
    }

    /// Computes the maximum number of active blocks per multiprocessor
    static int maximum_active_blocks(int smem_capacity = -1)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");

        int max_active_blocks = -1;
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");

        if (smem_size <= (48 << 10))
        {
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);

            if (result == cudaSuccess)
            {
                CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
                return max_active_blocks;
            }
        }
        else
        {
            // Query assuming zero shared memory then compute occupancy limit based on SMEM
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);

            if (result != cudaSuccess)
            {
                CUTLASS_TRACE_HOST(
                    " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));

                return -1;
            }

            if (smem_capacity < 0)
            {
                int device_idx = 0;
                result = cudaGetDevice(&device_idx);

                if (result != cudaSuccess)
                {
                    return -1;
                }

                cudaDeviceProp properties;
                result = cudaGetDeviceProperties(&properties, device_idx);

                if (result != cudaSuccess)
                {
                    return -1;
                }

                smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
            }

            int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);

            CUTLASS_TRACE_HOST(" occupancy: " << occupancy);

            return occupancy;
        }

        CUTLASS_TRACE_HOST(" returning internal error");

        return -1;
    }

    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
            << workspace << ", stream: " << (stream ? "non-null" : "null"));

        size_t workspace_bytes = get_workspace_size(args);

        CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);

        if (workspace_bytes)
        {
            if (!workspace)
            {
                CUTLASS_TRACE_HOST(" error: device workspace must not be null");

                return Status::kErrorWorkspaceNull;
            }

            if (args.mode == GemmUniversalMode::kGemm)
            {
                CUTLASS_TRACE_HOST(" clearing device workspace");
                cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);

                if (result != cudaSuccess)
                {
                    CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result));

                    return Status::kErrorInternal;
                }
            }
        }

        // Get CUDA grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        // Initialize the Params structure
        params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));

        // Specify shared memory capacity for kernel.
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        if (smem_size >= (48 << 10))
        {
            cudaError_t result
                = cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess)
            {
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }

    /// Lightweight update given a subset of arguments
    Status update(Arguments const& args, void* workspace = nullptr)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::update() - workspace: " << workspace);

        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace)
        {
            return Status::kErrorWorkspaceNull;
        }

        params_.update(args, workspace);

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");

        //
        // Configure grid and block dimensions
        //

        ThreadblockSwizzle threadblock_swizzle;

        dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
        dim3 block(GemmKernel::kThreadCount, 1, 1);

        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        //
        // Launch kernel
        //

        CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");

        // Launch
        cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

        //
        // Query for errors
        //
        cudaError_t result = cudaGetLastError();

        if (result != cudaSuccess)
        {
            CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
            return Status::kErrorInternal;
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr)
    {
        return run(stream);
    }

    /// Initializes and runs the kernel.
    Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
    {
        Status status = initialize(args, workspace, stream);

        if (status == Status::kSuccess)
        {
            status = run(stream);
        }

        return status;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
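// A minimal host-side usage sketch for GemmUniversalBaseCompat (illustrative only;
// `MyGemmKernel`, `make_args`, and `stream` are placeholders supplied by the caller):
//
//   using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<MyGemmKernel>;
//   Gemm::Arguments args = make_args(...);
//   if (Gemm::can_implement(args) == cutlass::Status::kSuccess) {
//       void* workspace = nullptr;
//       size_t workspace_bytes = Gemm::get_workspace_size(args);
//       if (workspace_bytes) { cudaMalloc(&workspace, workspace_bytes); }
//       Gemm gemm;
//       if (gemm.initialize(args, workspace, stream) == cutlass::Status::kSuccess) {
//           gemm.run(stream);
//       }
//   }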
@@ -0,0 +1,542 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
    \file
    \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/

#pragma once

#include <algorithm> // std::min / std::stable_sort (added; required by the code below)
#include <cstring>   // memcpy (added; required by reorder_array)
#include <limits>
#include <numeric>
#include <vector>

#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"

#include "cutlass/trace.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace device
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T_IN, typename T_OUT>
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
    int64_t* splitk_buffer_offsets)
{
    // in_tensor: [problem_idx, k_partition, hidden_size]
    // Note that different requests in in_tensor may have different hidden_size (= m * n),
    // so we need splitk_buffer_offsets.
    // out_tensor: problem_idx * [hidden_size]

    int const problem_idx = blockIdx.y;
    GemmCoord problem = problem_sizes[problem_idx];
    int const hidden_size = problem.m() * problem.n();
    const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
    T_OUT* out_tensor_ = out_tensor[problem_idx];

    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x)
    {
        float sum = 0.0f;
        for (int k_idx = 0; k_idx < splitk; k_idx++)
        {
            sum += (float) in_tensor_[k_idx * hidden_size + i];
        }
        out_tensor_[i] = (T_OUT) (sum);
    }
}
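
// A small worked example of the reduction layout consumed above (illustrative values only):
// with splitk = 2 and two problems whose hidden sizes (m * n) are 6 and 8, splitk_buffer_offsets
// is {0, 6}, so problem 1's partials begin at element 6 * 2 = 12 of in_tensor; each problem then
// sums its `splitk` partitions element-wise in fp32 before casting the result to T_OUT.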

/// GEMM Grouped
template <typename BaseKernel_>
class BaseSplitkGrouped
{
public:
    using BaseKernel = BaseKernel_;

    using ElementA = typename BaseKernel::ElementA;
    using LayoutA = typename BaseKernel::LayoutA;
    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    static ComplexTransform const kTransformA = BaseKernel::kTransformA;
    static int const kAlignmentA = BaseKernel::kAlignmentA;

    using ElementB = typename BaseKernel::ElementB;
    using LayoutB = typename BaseKernel::LayoutB;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    static ComplexTransform const kTransformB = BaseKernel::kTransformB;
    static int const kAlignmentB = BaseKernel::kAlignmentB;

    using ElementC = typename BaseKernel::ElementC;
    using LayoutC = typename BaseKernel::LayoutC;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;
    static int const kAlignmentC = BaseKernel::kAlignmentC;

    using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;

    using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
    using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;

    using Operator = typename BaseKernel::Operator;
    using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;

    using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
    using MathOperator = typename WarpMmaOperator::MathOperator;
    using OperatorClass = typename WarpMmaOperator::OperatorClass;
    using ArchTag = typename WarpMmaOperator::ArchTag;
    using ThreadblockShape = typename BaseKernel::Mma::Shape;
    using WarpShape = typename BaseKernel::WarpShape;
    using InstructionShape = typename BaseKernel::InstructionShape;
    static int const kStages = BaseKernel::Mma::kStages;

    /// Argument structure
    using Arguments = typename BaseKernel::Arguments;

    using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;

protected:
    /// Kernel parameters object
    typename BaseKernel::Params gemm_params_;

private:
    /// Gets the number of tiles across all problems in a group
    static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
    {
        int32_t tiles = 0;
        for (int32_t i = 0; i < problem_count; ++i)
        {
            cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
            BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
            tiles += problem_tile_count(problem);
        }
        return tiles;
    }

    /// Copies from `data` to `workspace`
    Status copy_to_workspace(void* workspace, void* data, size_t bytes)
    {
        cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
        if (cuda_error != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            cuda_error = cudaGetLastError();
            CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
            return Status::kErrorInternal;
        }

        return Status::kSuccess;
    }

    /// Precomputes scheduling information for the grouped GEMM
    Status precompute(Arguments const& args, int32_t tile_count, void* workspace)
    {
        size_t workspace_bytes = get_workspace_size(args);
        std::vector<uint8_t> host_workspace(workspace_bytes);
        BaseKernel::ProblemVisitor::host_precompute(
            args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data());
        return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
    }

    /// Reorders `data` according to `indices`
    template <typename T>
    static void reorder_array(T* data, std::vector<size_t> const& indices)
    {
        // For now, simply create a copy of the data and then copy over to the original.
        std::vector<T> copy(indices.size());
        for (size_t i = 0; i < indices.size(); ++i)
        {
            copy.at(i) = data[indices[i]];
        }

        memcpy(data, copy.data(), indices.size() * sizeof(T));
    }

public:
    /// Constructs the GEMM.
    BaseSplitkGrouped() {}

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const& args)
    {
        return BaseKernel::can_implement(args);
    }

    /// Gets the number of tiles in a problem
    static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem)
    {
        auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
        return BaseKernel::ProblemVisitor::tile_count(grid);
    }

    /// Gets the number of tiles across all problems in a group
    static int32_t group_tile_count(Arguments const& args)
    {
        if (args.host_problem_sizes == nullptr)
        {
            CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes`");
            return -1;
        }

        return group_tile_count(args.host_problem_sizes, args.problem_count);
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const& args)
    {
        size_t total_mn = 0;
        for (int i = 0; i < args.problem_count; i++)
        {
            total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
        }
        size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;

        if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
        {
            workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(
                args.host_problem_sizes, args.problem_count, args.threadblock_count);
        }
        return workSpaceSize;
    }

    /// Computes the grid shape
    static dim3 get_grid_shape(Arguments const& args)
    {
        return dim3(args.threadblock_count, 1, 1);
    }

    /// Computes the maximum number of active blocks per multiprocessor
    static int maximum_active_blocks(int smem_capacity = -1)
    {
        CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");

        int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

        CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");

        cudaError_t result;
        if (smem_size > (48 << 10))
        {
            result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess)
            {
                // Call cudaGetLastError() to clear the error bit
                result = cudaGetLastError();
                CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
                return -1;
            }
        }

        int max_active_blocks = -1;
        result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_active_blocks, Kernel<BaseKernel>, BaseKernel::kThreadCount, smem_size);

        if (result != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            result = cudaGetLastError();
            CUTLASS_TRACE_HOST(
                " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
            return -1;
        }

        CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
        return max_active_blocks;
    }

    /// Sorts each pointer passed in according to the indices that sort
    /// `problem_sizes_ptr` in descending order of problem-K dimension.
    static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
        int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
        int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr)
    {
        std::vector<size_t> indices(problem_count);
        std::iota(indices.begin(), indices.end(), 0);
        std::stable_sort(indices.begin(), indices.end(),
            [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); });

        reorder_array(problem_sizes_ptr, indices);
        reorder_array(lda_host_ptr, indices);
        reorder_array(ldb_host_ptr, indices);
        reorder_array(ldc_host_ptr, indices);
        reorder_array(ldd_host_ptr, indices);
        reorder_array(offset_A_ptr, indices);
        reorder_array(offset_B_ptr, indices);
        reorder_array(offset_C_ptr, indices);
        reorder_array(offset_D_ptr, indices);
    }

    /// Computes the number of threadblocks to launch for the grouped kernel
    static int sufficient(
        cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
    {
        // Determine the number of blocks that would be launched to fill up a single
        // wave on the GPU with each SM having maximum occupancy.
        int device_idx;
        cudaError_t result = cudaGetDevice(&device_idx);
        if (result != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            result = cudaGetLastError();
            CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
            return 0;
        }

        int multiprocessor_count;
        result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
        if (result != cudaSuccess)
        {
            CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
            return 0;
        }

        bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
        if (override_sm_count)
        {
            available_sm_count = multiprocessor_count;
        }

        int max_active_blocks = maximum_active_blocks();
        if (max_active_blocks <= 0)
        {
            return 0;
        }

        int occupancy_based_block_count = available_sm_count * max_active_blocks;

        if (problem_sizes_ptr == nullptr || problem_count == 0)
        {
            return occupancy_based_block_count;
        }

        int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);

        // If the group contains a single problem, launching the exact number of
        // threadblocks needed to cover the problem minimizes the work performed
        // per threadblock in finding the next tile to compute. We return total_tiles
        // unless the user has provided the SM count.
        if (problem_count == 1 && override_sm_count)
        {
            return total_tiles;
        }

        // Choose between the full wave of threadblocks and the tile count. If there
        // are fewer tiles in the group than threadblocks in the full wave, only
        // some threadblocks will be assigned tiles. Those threadblocks
        // which are not assigned tiles still need to perform the work of iterating through
        // problem sizes to determine that they have no work to do. This competes for cycles
        // with those threadblocks that are assigned tiles to compute.
        return std::min(total_tiles, occupancy_based_block_count);
    }

    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
    {
        CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
            << workspace << ", stream: " << (stream ? "non-null" : "null"));

        // Workspace
        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace)
        {
            return Status::kErrorWorkspaceNull;
        }

        if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
        {
            int32_t tile_count = group_tile_count(args);
            Status status = precompute(args, tile_count, workspace);
            if (status != Status::kSuccess)
            {
                return status;
            }

            gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
        }
        else
        {
            gemm_params_ = typename BaseKernel::Params(args, workspace);
        }

        // Specify shared memory capacity for kernel.
        int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

        if (smem_size >= (48 << 10))
        {
            cudaError_t result
                = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess)
            {
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }

    /// Lightweight update given a subset of arguments
    Status update(Arguments const& args, void* workspace = nullptr)
    {
        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace)
        {
            return Status::kErrorWorkspaceNull;
        }

        if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
        {
            int32_t tile_count = group_tile_count(args);
            Status status = precompute(args, tile_count, workspace);
            if (status != Status::kSuccess)
            {
                return status;
            }

            gemm_params_.update(args, workspace, tile_count);
        }
        else
        {
            gemm_params_.update(args, workspace);
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr)
    {
        if (!gemm_params_.problem_visitor.problem_count)
        {
            return Status::kSuccess;
        }

        //
        // Launch kernels
        //

        // Launch the split-K grouped GEMM
        {
            dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
            dim3 block(BaseKernel::kThreadCount, 1, 1);

            int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
            cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);

            cudaError_t result = cudaGetLastError();

            if (result != cudaSuccess)
            {
                CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
                return Status::kErrorInternal;
            }
        }

        // Launch splitkReduction
        {
            dim3 grid(32, gemm_params_.problem_visitor.problem_count);
            dim3 block(256);
            splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
                gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices,
                gemm_params_.splitk_buffer_offsets);

            cudaError_t result = cudaGetLastError();

            if (result != cudaSuccess)
            {
                CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr)
    {
        return run(stream);
    }

    /// Initializes and runs the kernel.
    Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr)
    {
        Status status = initialize(args, workspace, stream);

        if (status == Status::kSuccess)
        {
            status = run(stream);
        }

        return status;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// GEMM Grouped
template <typename GemmKernel_>
class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_>
{
public:
    using GemmKernel = GemmKernel_;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
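// A minimal host-side sketch of the intended flow for the grouped split-K GEMM above
// (illustrative only; `MyGroupedKernel`, `make_args`, and the host arrays are placeholders):
//
//   using Gemm = cutlass::gemm::device::SplitkGemmGrouped<MyGroupedKernel>;
//   Gemm::sort_problems(problem_count, problem_sizes, lda, ldb, ldc, ldd,
//                       offs_A, offs_B, offs_C, offs_D);   // descending problem-K order
//   Gemm::Arguments args = make_args(...);                 // includes split_k_slices
//   void* workspace = nullptr;
//   size_t workspace_bytes = Gemm::get_workspace_size(args);
//   if (workspace_bytes) { cudaMalloc(&workspace, workspace_bytes); }
//   Gemm gemm;
//   if (gemm.initialize(args, workspace, stream) == cutlass::Status::kSuccess) {
//       gemm.run(stream);   // grouped GEMM kernel, then splitkReduction
//   }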
@@ -0,0 +1,53 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cutlass/gemm/dispatch_policy.hpp"

namespace cutlass::gemm {

//////////////////////////////////////////////////////////////////////////////

// FP8-related policies (including blocked scaled accumulation).
// `ScaleGranularityM` specifies the scaling granularity along M; a zero value
// indicates that the granularity is `size<0>(TileShape_MNK{})` along M.
template <int ScaleGranularityM = 0>
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
    : KernelTmaWarpSpecializedCooperative {};

// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, warp-
// specialized dynamic schedule, for FP8 kernels with block scaling.
template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>,
    class KernelSchedule = KernelTmaWarpSpecialized,
    int ScaleGranularityM = 0 // `ScaleGranularityM` specifies the scaling granularity
                              // along M; a zero value indicates that the granularity
                              // is `size<0>(TileShape_MNK{})` along M.
    >
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
    : MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_, KernelSchedule> {
    static_assert(
        cute::is_same_v<
            KernelSchedule,
            KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>,
        "KernelSchedule must be one of the warp specialized policies");
};

//////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm
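// A minimal instantiation sketch that satisfies the static_assert above (the stage count and
// cluster shape are illustrative assumptions, not values taken from this repository):
//
//   using Schedule = cutlass::gemm::
//       KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum</*ScaleGranularityM=*/128>;
//   using Mainloop = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<
//       /*Stages_=*/4, cute::Shape<cute::_1, cute::_2, cute::_1>, Schedule, /*ScaleGranularityM=*/128>;
//
// Any other schedule (e.g. plain KernelTmaWarpSpecialized) trips the static_assert, which ties
// the mainloop's scale granularity to the cooperative FP8 block-scaled schedule.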
@@ -0,0 +1,190 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/half.h"
#include "cutlass/layout/matrix.h"

#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"

namespace cutlass
{
namespace gemm
{
namespace kernel
{

template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits
{
    static_assert(dependent_false<arch>, "Unrecognised parameterization");
};

template <typename Arch>
struct MixedGemmArchTraits<float, float, Arch>
{
    static constexpr int Stages = 2;
    using OperatorClass = cutlass::arch::OpClassSimt;
    using AccType = float;
    using LayoutB = cutlass::layout::ColumnMajor;

    static constexpr int ElementsPerAccessA = 1;
    static constexpr int ElementsPerAccessB = 1;
    static constexpr int ElementsPerAccessC = 1;
    static constexpr int ThreadblockK = 8;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

    using Operator = cutlass::arch::OpMultiplyAdd;
};

// ========================= Volta Traits ===========================
// Volta will always dequantize after the global memory load.
// This will instantiate any HMMA tensor-core kernels for Volta.
// Note that Volta does not have native bfloat16 support, so weights and activations are cast to
// fp16, compute happens in fp16, and results are converted back for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm70,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm70>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

    using Operator = typename LayoutDetails::Operator;
};

// ======================= Turing Traits ==============================
// Note that Turing does not have native bfloat16 support, so weights and activations are cast to
// fp16, compute happens in fp16, and results are converted back for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

    using Operator = typename LayoutDetails::Operator;
};

// ======================= Ampere Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

    using Operator = typename LayoutDetails::Operator;
};

// ======================= Ada Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;

    using Operator = typename LayoutDetails::Operator;
};

// FP8 A/B = fp8, C/D = fp32
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::float_e5m2_t>::value>::type>
{
private:
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    // Be careful: TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA>
    using TypeC = __nv_bfloat16;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeC>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;

    using Operator = typename LayoutDetails::Operator;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass
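// A small sketch of how the traits above are typically queried when stamping out a mixed-dtype
// GEMM (using cutlass::uint4b_t as the weight type is an illustrative choice):
//
//   using Traits = cutlass::gemm::kernel::MixedGemmArchTraits<
//       cutlass::half_t, cutlass::uint4b_t, cutlass::arch::Sm80>;
//   static_assert(Traits::ElementsPerAccessA == 8, "128-bit loads of fp16 give 8 elements");
//   // Traits::InstructionShape is GemmShape<16, 8, 16> on Sm80; Traits::LayoutB and
//   // Traits::Operator come from LayoutDetailsB for the chosen weight type.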
@@ -0,0 +1,57 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"

namespace cutlass
{
namespace gemm
{
namespace kernel
{

template <typename arch>
struct Int8GemmArchTraits
{
    using OperatorClass = cutlass::arch::OpClassSimt;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
};

// ======================= Turing Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm75>
{
    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
};

// ======================= Ampere Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm80>
{
    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass
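// Usage sketch (illustrative): the int8 traits select the tensor-op MMA shape per architecture.
//
//   using TuringTraits = cutlass::gemm::kernel::Int8GemmArchTraits<cutlass::arch::Sm75>;
//   // TuringTraits::InstructionShape is GemmShape<8, 8, 16>; Sm80 widens it to <16, 8, 32>,
//   // while the unspecialized template falls back to SIMT with a <1, 1, 1> shape.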
@@ -0,0 +1,568 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"

#include <type_traits>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail
{
template <typename>
inline constexpr bool dependent_false_v = false;
}

template <typename Mma_,          ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,           ///! Epilogue
    typename ThreadblockSwizzle_, ///! Threadblock swizzling function
    typename KernelArch,          ///! The architecture this kernel is compiled for. Used since SIMT
                                  ///  kernels lose top-level arch.
    bool SplitKSerial             ///! If true, code supporting split-K via serial reduction is enabled.
    >
struct GemmFpAIntB
{
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    static bool const kSplitKSerial = SplitKSerial;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename Mma::LayoutC;
    using ElementScale = ElementC;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformA;

    // Type definitions for the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;

    /// Argument structure
    struct Arguments
    {
        GemmUniversalMode mode = GemmUniversalMode::kGemm;

        cutlass::gemm::GemmCoord problem_size;
        int group_size;
        typename Mma::IteratorA::TensorRef ref_A;
        typename Mma::IteratorB::TensorRef ref_B;
        typename Mma::IteratorScale::TensorRef ref_scale;
        typename Mma::IteratorScale::TensorRef ref_zero;
        typename Epilogue::OutputTileIterator::TensorRef ref_C;
        typename Epilogue::OutputTileIterator::TensorRef ref_D;

        // Controls serial split-K
        int batch_count;

        typename EpilogueOutputOp::Params output_op;

        // For gather+scatter operations
        int const* gather_A_indices;
        int const* gather_B_indices;
        int const* scatter_D_indices;

        // Included so we can use Gemm Universal
        int batch_stride_D = 0;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Arguments() {}

        CUTLASS_HOST_DEVICE
        Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
            typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
            typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
            typename Epilogue::OutputTileIterator::TensorRef ref_C,
            typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor,
            typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(),
            int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr,
            int const* scatter_D_indices = nullptr)
            : problem_size(problem_size)
            , group_size(group_size)
            , ref_A(ref_A)
            , ref_B(ref_B)
            , ref_scale(ref_scale)
            , ref_zero(ref_zero)
            , ref_C(ref_C)
            , ref_D(ref_D)
            , batch_count(serial_split_k_factor)
            , output_op(output_op)
            , gather_A_indices(gather_A_indices)
            , gather_B_indices(gather_B_indices)
            , scatter_D_indices(scatter_D_indices)
        {
        }
    };

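    // Illustrative construction of the argument structure above (the tensor refs and epilogue
    // parameters are placeholders supplied by the caller):
    //
    //   Arguments args(problem_size, /*group_size=*/128, ref_A, ref_B, ref_scale, ref_zero,
    //                  ref_C, ref_D, /*serial_split_k_factor=*/1);
    //
    // A group_size of 64 or 128 is what can_implement() accepts for fine-grained quantization.
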
/// Parameters structure
|
||||
struct Params
|
||||
{
|
||||
cutlass::gemm::GemmCoord problem_size;
|
||||
int group_size;
|
||||
cutlass::gemm::GemmCoord grid_tiled_shape;
|
||||
int swizzle_log_tile;
|
||||
typename Mma::IteratorA::Params params_A;
|
||||
typename Mma::IteratorA::TensorRef ref_A;
|
||||
typename Mma::IteratorB::Params params_B;
|
||||
typename Mma::IteratorB::TensorRef ref_B;
|
||||
typename Mma::IteratorScale::Params params_scale;
|
||||
typename Mma::IteratorScale::TensorRef ref_scale;
|
||||
typename Mma::IteratorScale::TensorRef ref_zero;
|
||||
typename Epilogue::OutputTileIterator::Params params_C;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C;
|
||||
typename Epilogue::OutputTileIterator::Params params_D;
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_D;
|
||||
typename EpilogueOutputOp::Params output_op;
|
||||
int* semaphore;
|
||||
int gemm_k_size;
|
||||
// For gather+scatter operations
|
||||
int const* gather_A_indices;
|
||||
int const* gather_B_indices;
|
||||
int const* scatter_D_indices;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params()
|
||||
: swizzle_log_tile(0)
|
||||
, semaphore(0)
|
||||
, gemm_k_size(0)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
|
||||
void* workspace = nullptr)
|
||||
: problem_size(args.problem_size)
|
||||
, group_size(args.group_size)
|
||||
, grid_tiled_shape(grid_tiled_shape)
|
||||
, swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape))
|
||||
, params_A(args.ref_A.layout())
|
||||
, ref_A(args.ref_A)
|
||||
, params_B(args.ref_B.layout())
|
||||
, ref_B(args.ref_B)
|
||||
, params_scale(args.ref_scale.layout())
|
||||
, ref_scale(args.ref_scale)
|
||||
, ref_zero(args.ref_zero)
|
||||
, params_C(args.ref_C.layout())
|
||||
, ref_C(args.ref_C)
|
||||
, params_D(args.ref_D.layout())
|
||||
, ref_D(args.ref_D)
|
||||
, output_op(args.output_op)
|
||||
, semaphore(static_cast<int*>(workspace))
|
||||
, gemm_k_size(gemm_k_size)
|
||||
, gather_A_indices(args.gather_A_indices)
|
||||
, gather_B_indices(args.gather_B_indices)
|
||||
, scatter_D_indices(args.scatter_D_indices)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
/// Shared memory storage structure
|
||||
union SharedStorage
|
||||
{
|
||||
typename Mma::SharedStorage main_loop;
|
||||
typename Epilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
GemmFpAIntB() {}
|
||||
|
||||
/// Determines whether kernel satisfies alignment
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
static int const kAlignmentA
|
||||
= (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
|
||||
: (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorA::AccessType::kElements;
|
||||
static int const kAlignmentB
|
||||
= (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<32>>::value) ? 32
|
||||
: (platform::is_same<typename Mma::IteratorB::Layout, layout::RowMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Mma::IteratorB::AccessType::kElements;
|
||||
|
||||
static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements;
|
||||
|
||||
static int const kAlignmentC = (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
|
||||
layout::ColumnMajorInterleaved<32>>::value)
|
||||
? 32
|
||||
: (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
|
||||
layout::ColumnMajorInterleaved<64>>::value)
|
||||
? 64
|
||||
: Epilogue::OutputTileIterator::kElementsPerAccess;
|
||||
|
||||
if (!TensorRef_aligned(args.ref_A, kAlignmentA))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_B, kAlignmentB))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_scale, kAlignmentScale))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_zero, kAlignmentScale))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_C, kAlignmentC))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!TensorRef_aligned(args.ref_D, kAlignmentC))
|
||||
{
|
||||
return Status::kErrorMisalignedOperand;
|
||||
}
|
||||
|
||||
if (!args.ref_scale.good())
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
|
||||
if constexpr (hasZero(Mma::QuantOp))
|
||||
{
|
||||
if (!args.ref_zero.good())
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (args.ref_zero.good())
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (isFinegrained(Mma::QuantOp))
|
||||
{
|
||||
if (args.group_size != 64 && args.group_size != 128)
|
||||
{
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::kSuccess;
|
||||
}

    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
    {
        return 0;
    }

    // Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator
    // has a different constructor signature than a regular cutlass iterator
    template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
    CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
        typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
        typename IteratorScale::TensorCoord extent, int thread_id,
        typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
    {
        return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size);
    }

    template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<!isFinegrained(op), bool> = true>
    CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params,
        typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero,
        typename IteratorScale::TensorCoord extent, int thread_id,
        typename IteratorScale::TensorCoord const& threadblock_offset, int group_size)
    {
        return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
    }

    CUTLASS_DEVICE
    void run_kernel_(Params const& params, SharedStorage& shared_storage)
    {
        using LayoutB = typename Mma::IteratorB::Layout;
        static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
                || platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
            "B must be row major/col major OR col major interleaved.");

        // Compute threadblock location
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord threadblock_tile_offset
            = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // Early exit if CTA is out of range
        if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
            || params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
        {
            return;
        }

        // Compute initial location in logical coordinates
        cutlass::MatrixCoord tb_offset_A{
            threadblock_tile_offset.m() * Mma::Shape::kM,
            threadblock_tile_offset.k() * params.gemm_k_size,
        };

        cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
            threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};

        typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
        typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
        cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};

        // Problem size is a function of threadblock index in the K dimension
        int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);

        // Compute threadblock-scoped matrix multiply-add
        int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
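
        // Worked arithmetic (illustrative): with Mma::Shape::kK = 64 and a K slice covering
        // rows [0, 4096), gemm_k_iterations = (4096 - 0 + 63) / 64 = 64 mainloop steps;
        // the "+ kK - 1" term rounds a partial final tile up to one extra iteration.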

        // Compute position within threadblock
        int thread_idx = threadIdx.x;

        // Construct iterators to A and B operands
        typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
            {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);

        typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
            {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
            params.gather_B_indices);

        typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
        typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
            params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
            {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
        int lane_idx = threadIdx.x % 32;

        //
        // Main loop
        //

        // Construct thread-scoped matrix multiply
        Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);

        typename Mma::FragmentC accumulators;

        accumulators.clear();

        if (!kSplitKSerial || gemm_k_iterations > 0)
        {
            // Compute threadblock-scoped matrix multiply-add
            mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
        }

        //
        // Epilogue
        //

        EpilogueOutputOp output_op(params.output_op);

        //
        // Masked tile iterators constructed from members
        //

        threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // assume identity swizzle
        MatrixCoord threadblock_offset(
            threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);

        int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

        // Construct the semaphore.
        Semaphore semaphore(params.semaphore + block_idx, thread_idx);

        // If performing a reduction via split-K, fetch the initial synchronization
        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {
            // Fetch the synchronization lock initially but do not block.
            semaphore.fetch();

            // Indicate which position in a serial reduction the output operator is currently updating
            output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
        }

        // Tile iterator loading from source tensor.
        typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(),
            params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices);

        // Tile iterator writing to destination tensor.
        typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(),
            params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices);

        Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);

        // Wait on the semaphore - this latency may have been covered by iterator construction
        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {
            // For subsequent threadblocks, the source matrix is held in the 'D' tensor.
            if (threadblock_tile_offset.k())
            {
                iterator_C = iterator_D;
            }

            semaphore.wait(threadblock_tile_offset.k());
        }

        // Execute the epilogue operator to update the destination tensor.
        epilogue(output_op, iterator_D, accumulators, iterator_C);

        //
        // Release the semaphore
        //

        if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
        {
            int lock = 0;
            if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
            {
                // The final threadblock resets the semaphore for subsequent grids.
                lock = 0;
            }
            else
            {
                // Otherwise, the semaphore is incremented
                lock = threadblock_tile_offset.k() + 1;
            }

            semaphore.release(lock);
        }
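
        // Illustrative note on the serial split-K handshake: slice k waits until the
        // semaphore reaches k, accumulates into D, then releases k + 1; the final slice
        // releases 0 so the semaphore is reset for the next launch over this output tile.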
    }

    template <typename CompilationArch>
    CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
    {
        if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
        {
            run_kernel_(params, shared_storage);
        }
        else
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    }

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
    */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
        run_kernel<arch::Sm70>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
        run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ == 890)
        run_kernel<arch::Sm89>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
        CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#else
        static_assert(
            false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,93 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Scheduler for grouped GEMM
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
#include "cutlass/matrix_coord.h"

#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ThreadblockShape,
          GroupScheduleMode GroupScheduleMode_,
          int PrefetchTileCount,
          int ThreadCount,
          bool Transposed = false>
struct GemmMoeProblemVisitor
    : public MoeProblemVisitor<
          detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>,
          ThreadblockShape,
          GroupScheduleMode_,
          PrefetchTileCount,
          ThreadCount> {
  static bool const kTransposed = Transposed;

  using ProblemSizeHelper =
      detail::GemmGroupedProblemSizeHelper<ThreadblockShape, Transposed>;
  using Base = MoeProblemVisitor<ProblemSizeHelper,
                                 ThreadblockShape,
                                 GroupScheduleMode_,
                                 PrefetchTileCount,
                                 ThreadCount>;
  using Params = typename Base::Params;
  using SharedStorage = typename Base::SharedStorage;

  //
  // Methods
  //
  CUTLASS_DEVICE
  GemmMoeProblemVisitor(Params const& params_,
                        SharedStorage& shared_storage_,  // NOLINT
                        int32_t block_idx)
      : Base(params_, shared_storage_, block_idx) {}
};
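
// Illustrative sketch (not part of this header): a grouped kernel typically drives the
// visitor as below; `params` and `shared_storage` members are hypothetical names.
//
//   GemmMoeProblemVisitor<ThreadblockShape, GroupScheduleMode::kDeviceOnly, 0, 128> visitor(
//       params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
//   while (visitor.next_tile()) {
//     cutlass::gemm::GemmCoord problem = visitor.problem_size();
//     // map visitor.threadblock_idx() to an (m, n) tile of `problem` and run the mainloop
//     visitor.advance(gridDim.x);
//   }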

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,587 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief GEMM kernel to support the epilogue visitor model
           for customized softmax partial reduction epilogue fusion.

    This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once
    its usage has been stabilized. For now, it is included in this example to demonstrate
    some basic output fusion options.

    original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h
*/

#pragma once

#include "cutlass/complex.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
#include "cutlass/trace.h"

#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h"

namespace tk = common;

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Mma_,         ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,          ///! Epilogue
    typename ThreadblockSwizzle_ ///! Threadblock swizzling function
    >
struct GemmWithEpilogueVisitor
{
public:
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueVisitor = typename Epilogue::Visitor;
    using ThreadblockSwizzle = ThreadblockSwizzle_;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using TensorRefA = TensorRef<ElementA, LayoutA>;

    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using TensorRefB = TensorRef<ElementB, LayoutB>;

    using ElementCompute = typename EpilogueVisitor::ElementCompute;
    using LayoutAlphaCol = cutlass::layout::RowMajor;
    using LayoutAlphaRow = cutlass::layout::ColumnMajor;
    using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
    using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;

    using ElementC = typename EpilogueVisitor::ElementOutput;
    using LayoutC = typename Epilogue::Layout;
    using TensorRefC = TensorRef<ElementC, LayoutC>;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformB;
    using Operator = typename Mma::Operator;

    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;
    using EpilogueOutputOp
        = typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    /// Split-K preserves splits that are 128b aligned
    static int const kSplitKAlignment
        = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
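
    // Worked example (illustrative): for ElementA = ElementB = half_t (16 bits each),
    // kSplitKAlignment = const_max(128 / 16, 128 / 16) = 8 elements, i.e. every split-K
    // slice starts on a 128-bit (8 x fp16) boundary.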

    //
    // Structures
    //

    /// Argument structure
    struct Arguments
    {
        //
        // Data members
        //

        GemmUniversalMode mode;
        GemmCoord problem_size;
        int batch_count;

        TensorRefA ref_A;
        TensorRefB ref_B;
        tk::QuantMode quant_option;
        TensorRefAlphaCol ref_alpha_col;
        TensorRefAlphaRow ref_alpha_row;
        TensorRefC ref_C;
        TensorRefC ref_D;

        int64_t batch_stride_A;
        int64_t batch_stride_B;
        int64_t batch_stride_D;

        typename EpilogueVisitor::Arguments epilogue_visitor;

        //
        // Methods
        //

        Arguments()
            : mode(GemmUniversalMode::kGemm)
            , batch_count(1)
        {
        }

        /// constructs an arguments structure
        Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_,
            TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_,
            TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_,
            int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_)
            : mode(mode_)
            , problem_size(problem_size_)
            , batch_count(batch_count_)
            , ref_A(ref_A_)
            , ref_B(ref_B_)
            , quant_option(quant_option_)
            , ref_alpha_col(ref_alpha_col_)
            , ref_alpha_row(ref_alpha_row_)
            , ref_C(ref_C_)
            , ref_D(ref_D_)
            , batch_stride_A(batch_stride_A_)
            , batch_stride_B(batch_stride_B_)
            , batch_stride_D(0)
            , epilogue_visitor(epilogue_visitor_)
        {
        }
    };

    //
    // Structure for precomputing values in host memory and passing to kernels
    //

    /// Parameters structure
    struct Params
    {
        cutlass::gemm::GemmCoord problem_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;

        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorB::Params params_B;
        typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
        typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
        typename EpilogueVisitor::OutputTileIterator::Params params_C;
        typename EpilogueVisitor::OutputTileIterator::Params params_D;

        GemmUniversalMode mode;
        int batch_count;
        int gemm_k_size;

        void* ptr_A;
        void* ptr_B;
        tk::QuantMode quant_option;
        typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col;
        typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row;
        ElementC* ptr_C;
        ElementC* ptr_D;

        int64_t batch_stride_A;
        int64_t batch_stride_B;

        typename EpilogueVisitor::Params epilogue_visitor;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Params()
            : swizzle_log_tile(0)
            , params_A(0)
            , params_B(0)
            , params_alpha_col(0)
            , params_C(0)
            , params_D(0)
            , batch_count(0)
            , gemm_k_size(0)
            , mode(cutlass::gemm::GemmUniversalMode::kGemm)
            , ptr_A(nullptr)
            , ptr_B(nullptr)
            , ptr_alpha_col(nullptr)
            , ptr_alpha_row(nullptr)
            , ptr_C(nullptr)
            , ptr_D(nullptr)
            , batch_stride_A(0)
            , batch_stride_B(0)
        {
        }

        Params(
            Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_)
            : problem_size(args.problem_size)
            , swizzle_log_tile(0)
            , params_A(args.ref_A.layout())
            , params_B(args.ref_B.layout())
            , params_alpha_col(args.ref_alpha_col.layout())
            , params_alpha_row(args.ref_alpha_col.layout())
            , params_C(args.ref_C.layout())
            , params_D(args.ref_D.layout())
            , mode(args.mode)
            , batch_count(args.batch_count)
            , gemm_k_size(args.problem_size.k())
            , ptr_A(args.ref_A.data())
            , ptr_B(args.ref_B.data())
            , quant_option(args.quant_option)
            , ptr_alpha_col(args.ref_alpha_col.data())
            , ptr_alpha_row(args.ref_alpha_row.data())
            , ptr_C(args.ref_C.data())
            , ptr_D(args.ref_D.data())
            , batch_stride_A(args.batch_stride_A)
            , batch_stride_B(args.batch_stride_B)
            , epilogue_visitor(args.epilogue_visitor)
        {
            ThreadblockSwizzle threadblock_swizzle;

            grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size,
                {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);

            if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
            {
                int const kAlignK
                    = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

                gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

                if (gemm_k_size)
                {
                    grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
                }
            }
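
            // Worked example (illustrative): for problem_size.k() = 4096 split into
            // batch_count = 3 slices with fp16 operands (kAlignK = 8):
            //   ceil_div(4096, 3) = 1366, round_up(1366, 8) = 1368,
            //   grid_tiled_shape.k() = ceil_div(4096, 1368) = 3.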

            swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
        }
    };

    /// Shared memory storage structure
    union SharedStorage
    {
        typename Mma::SharedStorage main_loop;

        struct
        {
            typename Epilogue::SharedStorage epilogue;
            typename EpilogueVisitor::SharedStorage visitor;
        } epilogue;
    };

public:
    //
    // Methods
    //

    CUTLASS_DEVICE
    GemmWithEpilogueVisitor() {}

    /// Determines whether kernel satisfies alignment
    static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
    {
        CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");

        static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
        static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;

        bool isAMisaligned = false;
        bool isBMisaligned = false;
        bool isCMisaligned = false;

        if (platform::is_same<LayoutA, layout::RowMajor>::value)
        {
            isAMisaligned = problem_size.k() % kAlignmentA;
        }
        else if (platform::is_same<LayoutA, layout::ColumnMajor>::value)
        {
            isAMisaligned = problem_size.m() % kAlignmentA;
        }
        else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value)
        {
            isAMisaligned = problem_size.k() % kAlignmentA;
        }

        if (platform::is_same<LayoutB, layout::RowMajor>::value)
        {
            isBMisaligned = problem_size.n() % kAlignmentB;
        }
        else if (platform::is_same<LayoutB, layout::ColumnMajor>::value)
        {
            isBMisaligned = problem_size.k() % kAlignmentB;
        }
        else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
            || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value)
        {
            isBMisaligned = problem_size.k() % kAlignmentB;
        }

        if (platform::is_same<LayoutC, layout::RowMajor>::value)
        {
            isCMisaligned = problem_size.n() % kAlignmentC;
        }
        else if (platform::is_same<LayoutC, layout::ColumnMajor>::value)
        {
            isCMisaligned = problem_size.m() % kAlignmentC;
        }
        else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
            || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value)
        {
            isCMisaligned = problem_size.n() % kAlignmentC;
        }

        if (isAMisaligned)
        {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");
            return Status::kErrorMisalignedOperand;
        }

        if (isBMisaligned)
        {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");
            return Status::kErrorMisalignedOperand;
        }

        if (isCMisaligned)
        {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");
            return Status::kErrorMisalignedOperand;
        }

        CUTLASS_TRACE_HOST("  returning kSuccess");

        return Status::kSuccess;
    }

    static Status can_implement(Arguments const& args)
    {
        return can_implement(args.problem_size);
    }

    static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape)
    {
        return 0;
    }

#define SPLIT_K_ENABLED 1

    /// Executes one GEMM
    CUTLASS_DEVICE
    void run_kernel_(Params const& params, SharedStorage& shared_storage)
    {
        // Compute threadblock location
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord threadblock_tile_offset
            = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // Early exit if CTA is out of range
        if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
            || params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
        {
            return;
        }

        int offset_k = 0;
        int problem_size_k = params.problem_size.k();

        ElementA* ptr_A = static_cast<ElementA*>(params.ptr_A);
        ElementB* ptr_B = static_cast<ElementB*>(params.ptr_B);

#if SPLIT_K_ENABLED
        //
        // Fetch pointers based on mode.
        //
        if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel)
        {
            if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k())
            {
                problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
            }

            offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
        }
        else if (params.mode == GemmUniversalMode::kBatched)
        {
            ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
            ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
        }
        else if (params.mode == GemmUniversalMode::kArray)
        {
            ptr_A = static_cast<ElementA* const*>(params.ptr_A)[threadblock_tile_offset.k()];
            ptr_B = static_cast<ElementB* const*>(params.ptr_B)[threadblock_tile_offset.k()];
        }
#endif
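
        // Illustrative note: in kBatched mode, the grid's k coordinate walks over batches,
        // so batch b reads A at ptr_A + b * batch_stride_A; in kArray mode, ptr_A instead
        // points to an array of per-batch base pointers that is indexed directly.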

        // Compute initial location in logical coordinates
        cutlass::MatrixCoord tb_offset_A{
            threadblock_tile_offset.m() * Mma::Shape::kM,
            offset_k,
        };

        cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};

        // Compute position within threadblock
        int thread_idx = threadIdx.x;

        // Construct iterators to A and B operands
        typename Mma::IteratorA iterator_A(
            params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);

        typename Mma::IteratorB iterator_B(
            params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);

        int lane_idx = threadIdx.x % 32;

        //
        // Main loop
        //

        // Construct thread-scoped matrix multiply
        Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

        typename Mma::FragmentC accumulators;

        accumulators.clear();

        // Compute threadblock-scoped matrix multiply-add
        int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add
        mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

        //
        // Masked tile iterators constructed from members
        //

        threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // assume identity swizzle
        MatrixCoord threadblock_offset(
            threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);

        int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

        //
        // Construct the epilogue visitor
        //

        EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor,
            params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C,
            params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C,
            params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m());

        if (params.mode == GemmUniversalMode::kGemm)
        {
            // Indicate which position in a serial reduction the output operator is currently updating
            epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
        }
        else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray)
        {
            epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
        }

        // Construct the epilogue
        Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);

        // Execute the epilogue operator to update the destination tensor.
        epilogue(epilogue_visitor, accumulators);
    }

    template <typename CompilationArch>
    CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
    {
        if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
        {
            run_kernel_(params, shared_storage);
        }
        else
        {
            CUTLASS_NOT_IMPLEMENTED();
        }
    }

    /*
        To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
        to the ArchTag of the cutlass kernel operator.
    */
    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 720)
        run_kernel<arch::Sm70>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
        run_kernel<arch::Sm72>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
        run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
        run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
        // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
        run_kernel<arch::Sm80>(params, shared_storage);
#else
        static_assert(
            false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
        CUTLASS_NOT_IMPLEMENTED();
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,153 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
/*
  This file exists so that we use the same weight layout for MoE grouped GEMM and regular GEMM when the weight is
  quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices
  to be consumed by CUTLASS.

  Note that for int4, ThreadblockK MUST be 64.
*/

#pragma once

#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/platform/platform.h"

#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/tile_interleaved_layout.h"

namespace cutlass
{
namespace gemm
{
namespace kernel
{

template <typename TypeA, typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB
{
};

// Volta specializations. Volta will dequantize before STS, so we need a different operator.
template <typename TypeA, typename TypeB>
struct LayoutDetailsB<TypeA, TypeB, arch::Sm70>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 8;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks.
// TODO - Switch this to column major for weights since gemms should be more performant.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, half_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, bfloat16_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<bfloat16_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename TypeA>
struct LayoutDetailsB<TypeA, cutlass::float_e4m3_t, arch::Sm89>
{
    static constexpr int ThreadblockK = 64;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
    // for fast accumulation
    // using Operator = cutlass::arch::OpMultiplyAddFastAccum;
};

// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability < 90>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint8_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint8_t>::value;
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
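
// Worked example (illustrative): with TypeA = half_t (16 bits), ThreadblockK = 128 * 8 / 16 = 64.
// A 128-byte cache line holds 128 * 8 / 8 = 128 int8 values, so ColumnsInterleaved = 128 / 64 = 2:
// two 64-row column fragments are interleaved so one cache line feeds a threadblock's full K slice.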

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch,
    typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability < 90>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;

private:
    static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits<uint4b_t>::value;
    static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;

public:
    using Layout = layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<uint4b_t>::value;
    using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename TypeA, typename Arch>
struct LayoutDetailsB<TypeA, uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
    static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits<TypeA>::value;
    using Layout = layout::ColumnMajor;
    static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,357 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Base scheduler for grouped problems, using MoE
*/

#pragma once

#include "cutlass/gemm/kernel/grouped_problem_visitor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Visitor class to abstract away the algorithm for iterating over tiles
template <typename ProblemSizeHelper, typename ThreadblockShape_>
struct BaseMoeProblemVisitor {
  using ThreadblockShape = ThreadblockShape_;

  struct ProblemInfo {
    static int32_t const kNoPrefetchEntry = -1;
    int32_t problem_idx;
    int32_t problem_start;

    CUTLASS_DEVICE
    ProblemInfo()
        : problem_idx(kNoPrefetchEntry), problem_start(kNoPrefetchEntry) {}

    CUTLASS_DEVICE
    ProblemInfo(int32_t problem_idx_, int32_t problem_start_)
        : problem_idx(problem_idx_), problem_start(problem_start_) {}
  };

  struct Params {
    int64_t const* last_row_for_problem;
    int64_t total_rows;
    int64_t gemm_n;
    int64_t gemm_k;
    int32_t problem_count;
    void const* workspace;
    int32_t tile_count;

    //
    // Methods
    //

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params()
        : last_row_for_problem(nullptr),
          total_rows(-1),
          gemm_n(0),
          gemm_k(0),
          problem_count(0),
          workspace(nullptr),
          tile_count(0) {}

    /// Ctor
    CUTLASS_HOST_DEVICE
    Params(int64_t const* last_row_for_problem,
           int64_t total_rows,
           int64_t gemm_n,
           int64_t gemm_k,
           int32_t problem_count,
           void const* workspace = nullptr,
           int32_t tile_count = 0)
        : last_row_for_problem(last_row_for_problem),
          total_rows(total_rows),
          gemm_n(gemm_n),
          gemm_k(gemm_k),
          problem_count(problem_count),
          workspace(workspace),
          tile_count(tile_count) {}
  };

  Params const& params;
  int32_t tile_idx;
  int32_t problem_tile_start;
  int32_t problem_idx;

  //
  // Methods
  //
  CUTLASS_DEVICE
  BaseMoeProblemVisitor(Params const& params_, int32_t block_idx)
      : params(params_),
        tile_idx(block_idx),
        problem_tile_start(0),
        problem_idx(0) {}

  /// Get the grid shape
  CUTLASS_HOST_DEVICE
  static cutlass::gemm::GemmCoord grid_shape(
      const cutlass::gemm::GemmCoord& problem) {
    return cutlass::gemm::GemmCoord(
        ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
        ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN),
        1);
  }
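
  // Worked example (illustrative): for a 300 x 1024 problem with a 128 x 128 threadblock
  // tile, grid_shape returns ((300 + 127) / 128, (1024 + 127) / 128, 1) = (3, 8, 1),
  // i.e. the familiar ceil-division of each extent by the tile size.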

  /// Gets the global tile index
  CUTLASS_HOST_DEVICE
  int32_t tile_index() const { return tile_idx; }

  /// Gets the index of the problem
  CUTLASS_HOST_DEVICE
  int32_t problem_index() const { return problem_idx; }

  CUTLASS_HOST_DEVICE
  int32_t threadblock_idx() const { return tile_idx - problem_tile_start; }

  CUTLASS_DEVICE
  void advance(int32_t grid_size) { tile_idx += grid_size; }

  CUTLASS_HOST_DEVICE
  static void possibly_transpose_problem(
      cutlass::gemm::GemmCoord& problem) {  // NOLINT
    ProblemSizeHelper::possibly_transpose_problem(problem);
  }

  /// Returns the problem size for the current problem
  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size() const {
    return problem_size(problem_idx);
  }

  CUTLASS_HOST_DEVICE
  cutlass::gemm::GemmCoord problem_size(int idx) const {
    int64_t gemm_m = 0;

    if (params.total_rows < 0) {
      const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1];
      const int64_t current_problem_row = params.last_row_for_problem[idx];
      gemm_m = current_problem_row - prev_problem_row;
    } else {
      gemm_m = params.last_row_for_problem[idx];
    }

    GemmCoord problem(GemmCoord::Index(gemm_m),
                      GemmCoord::Index(params.gemm_n),
                      GemmCoord::Index(params.gemm_k));
    ProblemSizeHelper::possibly_transpose_problem(problem);
    return problem;
  }
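
  // Illustrative note: when total_rows < 0, last_row_for_problem holds an inclusive prefix
  // sum of rows per expert. E.g. last_row_for_problem = {4, 4, 10} describes three experts
  // with 4, 0, and 6 rows, so problem 2 gets gemm_m = 10 - 4 = 6.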

  CUTLASS_HOST_DEVICE
  static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) {
    return ProblemSizeHelper::tile_count(grid);
  }

  static int32_t group_tile_count(
      const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
      int32_t problem_count) {
    int32_t total_tiles = 0;
    for (int32_t i = 0; i < problem_count; ++i) {
      auto problem = host_problem_sizes_ptr[i];
      possibly_transpose_problem(problem);
      auto grid = grid_shape(problem);
      total_tiles += tile_count(grid);
    }

    return total_tiles;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename ProblemSizeHelper,
          typename ThreadblockShape,
          GroupScheduleMode GroupScheduleMode_,
          int PrefetchTileCount,
          int ThreadCount>
struct MoeProblemVisitor;

/////////////////////////////////////////////////////////////////////////////////////////////////
// ProblemVisitor that performs all scheduling on device
//
template <typename ProblemSizeHelper,
          typename ThreadblockShape,
          int PrefetchTileCount,
          int ThreadCount>
struct MoeProblemVisitor<ProblemSizeHelper,
                         ThreadblockShape,
                         GroupScheduleMode::kDeviceOnly,
                         PrefetchTileCount,
                         ThreadCount>
    : public BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape> {
  using Base = BaseMoeProblemVisitor<ProblemSizeHelper, ThreadblockShape>;
  using Params = typename Base::Params;
  static int const kThreadCount = ThreadCount;
  static bool const kRequiresPrecomputation = false;
  static int const kThreadsPerWarp = 32;

  struct SharedStorage {};

  // Final tile of the problem loaded by this thread. Each thread will hold
  // a separate value.
  int32_t problem_ending_tile;

  SharedStorage& shared_storage;

  //
  // Methods
  //
  CUTLASS_DEVICE
  MoeProblemVisitor(Params const& params_,
                    SharedStorage& shared_storage_,  // NOLINT
                    int32_t block_idx)
      : Base(params_, block_idx),
        problem_ending_tile(0),
        shared_storage(shared_storage_) {
    this->problem_idx = -1 * kThreadsPerWarp;
    this->problem_tile_start = 0;
  }

  CUTLASS_DEVICE
  bool next_tile() {
    // Check whether the tile to compute is within the range of the current
    // problem.
    int32_t problem_tile_end = __shfl_sync(
        0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp);
    if (this->tile_idx < problem_tile_end) {
      return true;
    }

    // Check whether the tile to compute is within the current group of problems
    // fetched by the warp. The last tile for this group is the final tile of
    // the problem held by the final thread in the warp.
    int32_t group_tile_end =
        __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);

    // Keep the starting problem for this group in `problem_idx`. This is done
    // to reduce register pressure. The starting problem for this group is
    // simply the first problem in the group most recently fetched by the warp.
    int32_t& group_problem_start = this->problem_idx;
    group_problem_start =
        (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp;

    // Keep the starting tile for this group in `problem_tile_start`. This is
    // done to reduce register pressure.
    int32_t& group_tile_start = this->problem_tile_start;

    // Each thread in the warp processes a separate problem to advance until
    // reaching a problem whose starting tile is less than tile_idx.
    while (group_tile_end <= this->tile_idx) {
      group_problem_start += kThreadsPerWarp;
      if (group_problem_start > this->params.problem_count) {
        return false;
      }

      // Since `group_tile_start` is a reference to `this->problem_tile_start`,
      // this also sets `this->problem_tile_start`. The fact that
      // `this->problem_tile_start` is also set here is used later in
      // `next_tile`.
      group_tile_start = group_tile_end;

      int lane_idx = threadIdx.x % kThreadsPerWarp;
      int32_t lane_problem = group_problem_start + lane_idx;

      // Compute the number of tiles in the problem assigned to each thread.
      problem_ending_tile = 0;
      if (lane_problem < this->params.problem_count) {
        cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem);
        cutlass::gemm::GemmCoord grid = this->grid_shape(problem);
        problem_ending_tile = this->tile_count(grid);
      }

      // Compute a warp-wide inclusive prefix sum to compute the ending tile
      // index of each thread's problem.
      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < kThreadsPerWarp; i <<= 1) {
        int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i);
        if (lane_idx >= i) {
          problem_ending_tile += val;
        }
      }
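
      // Worked example (illustrative): if lanes 0..3 hold per-problem tile counts
      // {3, 1, 4, 2}, the five shuffle rounds (i = 1, 2, 4, 8, 16) leave the inclusive
      // prefix sums {3, 4, 8, 10}, so lane i now knows the ending tile of problem i.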

      // The total tile count for this group is now in the final position of the
      // prefix sum.
      int32_t tiles_in_group =
          __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1);

      problem_ending_tile += group_tile_start;
      group_tile_end += tiles_in_group;
    }

    // The next problem to process is the first one whose ending tile position
    // is greater than the tile index.
    int32_t problem_idx_in_group = __popc(
        __ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx));

    this->problem_idx = group_problem_start + problem_idx_in_group;

    // The starting tile for this problem is the ending tile of the previous
    // problem. In cases where `problem_idx_in_group` is the first problem in
    // the group, we do not need to reset `problem_tile_start`, because it is
    // set to the previous group's ending tile in the while loop above.
    if (problem_idx_in_group > 0) {
      this->problem_tile_start = __shfl_sync(
          0xffffffff, problem_ending_tile, problem_idx_in_group - 1);
    }

    return true;
  }

  static size_t get_workspace_size(
      const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
      int32_t problem_count,
      int32_t block_count) {
    return 0;
  }

  static void host_precompute(
      const cutlass::gemm::GemmCoord* host_problem_sizes_ptr,
      int32_t problem_count,
      int32_t block_count,
      void* host_workspace_ptr) {}
};

}  // namespace kernel
}  // namespace gemm
}  // namespace cutlass
@@ -0,0 +1,494 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/complex.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/matrix_coord.h"
|
||||
#include "cutlass/semaphore.h"
|
||||
|
||||
#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
|
||||
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/trace.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace kernel
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,                   ///! Epilogue
    typename ThreadblockSwizzle_,         ///! Threadblock swizzling function
    GroupScheduleMode GroupScheduleMode_, ///! Type of scheduling to perform
    bool Transposed = false>
struct SplitkGemmGrouped
{
public:
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueOutputOp = typename Epilogue::OutputOp;
    using ThreadblockSwizzle = ThreadblockSwizzle_;
    static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;
    static bool const kTransposed = Transposed;

    // Optional transpose
    using MapArguments = kernel::detail::MapArguments<typename Mma::IteratorA::Element, typename Mma::IteratorA::Layout,
        Mma::kTransformA, Mma::IteratorA::AccessType::kElements, typename Mma::IteratorB::Element,
        typename Mma::IteratorB::Layout, Mma::kTransformB, Mma::IteratorB::AccessType::kElements, typename Mma::LayoutC,
        kTransposed>;

    // Public-facing type definitions related to operand element type, layout, and complex conjugate
    // operation. Must interact with the 'kTransposed' notion.
    using ElementA = typename MapArguments::ElementA;
    using LayoutA = typename MapArguments::LayoutA;
    using ElementB = typename MapArguments::ElementB;
    using LayoutB = typename MapArguments::LayoutB;
    using ElementC = typename Epilogue::OutputTileIterator::Element;
    using LayoutC = typename MapArguments::LayoutC;

    using ElementFinalOutput = typename MapArguments::ElementA;

    static ComplexTransform const kTransformA = MapArguments::kTransformA;
    static ComplexTransform const kTransformB = MapArguments::kTransformB;

    // Type definitions about the mainloop.
    using Operator = typename Mma::Operator;
    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = MapArguments::kAlignmentA;
    static int const kAlignmentB = MapArguments::kAlignmentB;
    static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    using ProblemVisitor
        = GemmGroupedProblemVisitor<ThreadblockShape, kGroupScheduleMode, kThreadCount, kThreadCount, kTransposed>;

    //
    // Structures
    //

    /// Argument structure
    struct Arguments
    {

        //
        // Data members
        //

        GemmCoord* problem_sizes;
        int problem_count;
        int threadblock_count;

        typename EpilogueOutputOp::Params output_op;

        ElementA** ptr_A;
        ElementB** ptr_B;
        ElementFinalOutput** ptr_C;
        ElementFinalOutput** ptr_D;

        typename LayoutA::Stride::LongIndex* lda;
        typename LayoutB::Stride::LongIndex* ldb;
        typename LayoutC::Stride::LongIndex* ldc;
        typename LayoutC::Stride::LongIndex* ldd;

        // Only used by device-level operator
        GemmCoord* host_problem_sizes;

        // splitK
        int split_k_slices;
        int64_t* splitk_buffer_offsets;

        //
        // Methods
        //

        /// Default ctor
        CUTLASS_HOST_DEVICE
        Arguments()
            : problem_count(0)
            , threadblock_count(0)
            , ptr_A(nullptr)
            , ptr_B(nullptr)
            , ptr_C(nullptr)
            , ptr_D(nullptr)
            , lda(nullptr)
            , ldb(nullptr)
            , ldc(nullptr)
            , ldd(nullptr)
            , host_problem_sizes(nullptr)
            , split_k_slices(1)
            , splitk_buffer_offsets(nullptr)
        {
        }

        /// Ctor
        CUTLASS_HOST_DEVICE
        Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count,
            typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C,
            ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda,
            typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc,
            typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices,
            int64_t* splitk_buffer_offsets)
            : problem_sizes(problem_sizes)
            , problem_count(problem_count)
            , threadblock_count(threadblock_count)
            , output_op(output_op)
            , ptr_A(ptr_A)
            , ptr_B(ptr_B)
            , ptr_C(ptr_C)
            , ptr_D(ptr_D)
            , lda(lda)
            , ldb(ldb)
            , ldc(ldc)
            , ldd(ldd)
            , host_problem_sizes(host_problem_sizes)
            , split_k_slices(split_k_slices)
            , splitk_buffer_offsets(splitk_buffer_offsets)
        {
        }
    };

    //
    // Structure for precomputing values in host memory and passing to kernels
    //

    /// Parameters structure
    struct Params
    {

        typename ProblemVisitor::Params problem_visitor;
        int threadblock_count;

        typename EpilogueOutputOp::Params output_op;

        ElementA** ptr_A;
        ElementB** ptr_B;
        ElementFinalOutput** ptr_C;
        ElementFinalOutput** ptr_D;
        ElementC* ptr_C_split;
        ElementC* ptr_D_split;

        typename LayoutA::Stride::LongIndex* lda;
        typename LayoutB::Stride::LongIndex* ldb;
        typename LayoutC::Stride::LongIndex* ldc;
        typename LayoutC::Stride::LongIndex* ldd;

        // splitk
        GemmCoord grid_tiled_shape;
        int swizzle_log_tile;
        int gemm_k_size;
        GemmCoord* host_problem_sizes;
        int split_k_slices;
        int64_t* splitk_buffer_offsets;

        //
        // Methods
        //

        CUTLASS_HOST_DEVICE
        Params()
            : ptr_A(nullptr)
            , ptr_B(nullptr)
            , ptr_C(nullptr)
            , ptr_D(nullptr)
            , ptr_C_split(nullptr)
            , ptr_D_split(nullptr)
            , lda(nullptr)
            , ldb(nullptr)
            , ldc(nullptr)
            , ldd(nullptr)
            , swizzle_log_tile(0)
            , gemm_k_size(0)
            , host_problem_sizes(nullptr)
            , split_k_slices(1)
            , splitk_buffer_offsets(nullptr)
        {
        }

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
            : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count)
            , threadblock_count(args.threadblock_count)
            , output_op(args.output_op)
            , ptr_A(args.ptr_A)
            , ptr_B(args.ptr_B)
            , ptr_C(args.ptr_C)
            , ptr_D(args.ptr_D)
            , ptr_C_split(static_cast<ElementC*>(workspace))
            , ptr_D_split(static_cast<ElementC*>(workspace))
            , lda(args.lda)
            , ldb(args.ldb)
            , ldc(args.ldc)
            , ldd(args.ldd)
            , host_problem_sizes(args.host_problem_sizes)
            , split_k_slices(args.split_k_slices)
            , splitk_buffer_offsets(args.splitk_buffer_offsets)
        {
            // Determine grid shape
            ThreadblockSwizzle threadblock_swizzle;
            grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0],
                {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices);
            swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape);

            // Only problems sharing the same K extent are supported, so the
            // per-slice K size is derived from the first problem.
            int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK;
            int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k();

            gemm_k_size = gemm_k_iterations * Mma::Shape::kK;
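            // Worked example (hypothetical sizes): K = 4096, Mma::Shape::kK = 64,
            // split_k_slices = 4 -> full_gemm_k_iterations = 64,
            // gemm_k_iterations = 16, gemm_k_size = 1024, i.e. each K-slice
            // covers 1024 elements of the K dimension.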
        }

        CUTLASS_HOST_DEVICE
        void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0)
        {

            problem_visitor =
                typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count);
            threadblock_count = args.threadblock_count;
            output_op = args.output_op;
            ptr_A = args.ptr_A;
            ptr_B = args.ptr_B;
            ptr_C = args.ptr_C;
            ptr_D = args.ptr_D;
            ptr_C_split = static_cast<ElementC*>(workspace);
            ptr_D_split = static_cast<ElementC*>(workspace);

            lda = args.lda;
            ldb = args.ldb;
            ldc = args.ldc;
            ldd = args.ldd;
        }
    };

    /// Shared memory storage structure
    struct SharedStorage
    {
        union
        {
            typename Mma::SharedStorage main_loop;
            typename Epilogue::SharedStorage epilogue;
        } kernel;
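
        // Note: the mainloop and epilogue of one tile never run concurrently
        // (operator() separates them with __syncthreads()), so aliasing their
        // shared-memory allocations through the union above is safe.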

        // ProblemVisitor shared storage can't be overlapped with others
        typename ProblemVisitor::SharedStorage problem_visitor;
    };

public:
    //
    // Methods
    //

    CUTLASS_DEVICE
    SplitkGemmGrouped() {}

    /// Determines whether kernel satisfies alignment
    static Status can_implement(cutlass::gemm::GemmCoord const& problem_size)
    {
        return Status::kSuccess;
    }

    static Status can_implement(Arguments const& args)
    {
        return Status::kSuccess;
    }

    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const& params, SharedStorage& shared_storage)
    {

        //
        // These types shadow the type-level definitions and support the ability to implement
        // a 'transposed' GEMM that computes the transposed problems.
        //
        using ElementA = typename Mma::IteratorA::Element;
        using LayoutA = typename Mma::IteratorA::Layout;
        using ElementB = typename Mma::IteratorB::Element;
        using LayoutB = typename Mma::IteratorB::Layout;
        using ElementC = typename Epilogue::OutputTileIterator::Element;
        using LayoutC = typename Epilogue::OutputTileIterator::Layout;

        //
        // Problem visitor.
        //
        ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);

        // Outer 'persistent' loop to iterate over tiles
        while (problem_visitor.next_tile())
        {

            GemmCoord problem_size = problem_visitor.problem_size();
            int32_t problem_idx = problem_visitor.problem_index();
            int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());

            GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);

            // Load element pointers. Exchange pointers and strides if working on the transpose
            ElementA* ptr_A
                = reinterpret_cast<ElementA*>((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx]));
            typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]);

            ElementB* ptr_B
                = reinterpret_cast<ElementB*>((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx]));
            typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]);

            // Compute threadblock location
            ThreadblockSwizzle threadblock_swizzle;
            GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

            cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM,
                int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0);

            // Compute initial location in logical coordinates
            cutlass::MatrixCoord tb_offset_A{
                threadblock_offset.m(),
                threadblock_tile_offset.k() * params.gemm_k_size,
            };

            cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()};
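            // Example (hypothetical): with gemm_k_size == 1024, the K-slice with
            // threadblock_tile_offset.k() == 2 starts reading both A and B at
            // K offset 2048.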

            // Problem size is a function of threadblock index in the K dimension
            int problem_size_k;
            if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k())
            {
                problem_size_k = problem_size.k();
            }
            else
            {
                problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
            }

            // Compute threadblock-scoped matrix multiply-add
            int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
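            // (problem_size_k - tb_offset_A.column()) is the K extent owned by
            // this slice; rounding up by Mma::Shape::kK gives the mainloop trip count.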

            // Compute position within threadblock
            int thread_idx = threadIdx.x;

            // Construct iterators to A and B operands
            typename Mma::IteratorA iterator_A(
                LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);

            typename Mma::IteratorB iterator_B(
                LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B);

            typename Mma::FragmentC accumulators;

            accumulators.clear();

            // Broadcast the warp_id computed by lane 0 to ensure dependent code
            // is compiled as warp-uniform.
            int warp_idx = canonical_warp_idx_sync();

            int lane_idx = threadIdx.x % 32;

            //
            // Matrix multiply phase
            //

            // Construct thread-scoped matrix multiply
            Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx);

            // Wait for all threads to finish their epilogue phases from the previous tile.
            __syncthreads();

            // Compute threadblock-scoped matrix multiply-add
            mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

            //
            // Epilogue
            //

            EpilogueOutputOp output_op(params.output_op);

            ElementC* ptr_C = params.ptr_C_split;
            ElementC* ptr_D = params.ptr_D_split;

            LayoutC layout_C(params.ldc[problem_idx]);
            LayoutC layout_D(params.ldd[problem_idx]);

            typename Epilogue::OutputTileIterator::Params params_C(layout_C);
            typename Epilogue::OutputTileIterator::Params params_D(layout_D);

            // assume identity swizzle
            MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n());

            // Tile iterator loading from source tensor.
            typename Epilogue::OutputTileIterator iterator_C(
                params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C);

            iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
                + gridDim.z * params.splitk_buffer_offsets[problem_idx]);

            // Tile iterator writing to destination tensor.
            typename Epilogue::OutputTileIterator iterator_D(
                params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C);
            iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k()
                + gridDim.z * params.splitk_buffer_offsets[problem_idx]);
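            // Each K-slice writes to its own M x N slab of the split-K
            // workspace; splitk_buffer_offsets[problem_idx] appears to hold the
            // per-problem element offset, scaled here by gridDim.z (the number
            // of K-slices).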

            Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx);

            // Execute the epilogue operator to update the destination tensor.
            epilogue(output_op, iterator_D, accumulators, iterator_C);

            // Next tile
            problem_visitor.advance(gridDim.x);
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,125 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"

namespace cutlass
{
namespace gemm
{
namespace threadblock
{
////////////////////////////////////////////////////////////////////////////////

// We need to distinguish the two cases here because we want Volta support.
// Writing the shared-memory iterators that Volta would likely need is too much
// effort, so we allow converters both after the LDG (for Volta) and after the
// LDS (for Turing+).
template <
    /// Iterator for B matrix in global memory
    typename IteratorB,
    /// Warp level Mma
    typename MmaOperator,
    /// Math operation performed by warp level operator
    typename MathOperator>
struct SetConverters
{
};

// Dequantize after LDG, so set transforms accordingly
template <
    /// Iterator for B matrix in global memory
    typename IteratorB,
    /// Mma Policy
    typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd>
{
    using TransformAfterLDG
        = FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
            typename IteratorB::Element, IteratorB::Fragment::kElements>;

    using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
        typename MmaOperator::ArchMmaOperator::ElementB, MmaOperator::FragmentB::kElements>;
};

// Dequantize after LDS, so set transforms accordingly
template <
    /// Iterator for B matrix in global memory
    typename IteratorB,
    /// Mma Policy
    typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA>
{
    using TransformAfterLDG = NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element,
        IteratorB::Fragment::kElements>;

    using TransformAfterLDS
        = FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
            typename TransformAfterLDG::result_type::Element, MmaOperator::FragmentB::kElements>;
};
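
// Usage sketch (hypothetical types): with an IteratorB that loads uint8_t and
// an ArchMmaOperator that computes in half_t, the OpMultiplyAdd path above
// dequantizes right after the global load (TransformAfterLDG: uint8_t ->
// half_t; TransformAfterLDS is then an identity over half_t), while the
// OpMultiplyAddDequantizeInterleavedBToA path keeps uint8_t in shared memory
// (identity after LDG) and dequantizes on register fragments after the LDS.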

////////////////////////////////////////////////////////////////////////////////

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for the input scale
    typename ElementScale_,
    /// Layout for the scale operand
    typename LayoutScale_,
    /// Access granularity of Scales in unit of elements
    int kAlignmentScale,
    /// Element type for internal accumulation
    typename ElementAccumulator_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Operator class tag
    typename OperatorClass_,
    /// Tag indicating architecture to tune for
    typename ArchTag_,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Operation performed by GEMM
    typename Operator_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    ///
    typename Enable = void>
struct DqMma;
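
// Only the specializations below (and those in the *_multistage / *_pipelined
// headers) provide a definition; instantiating this primary template for an
// unsupported configuration is therefore a compile-time error.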

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,302 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/arch/mma.h"

#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "cutlass_extensions/tile_interleaved_layout.h"

#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
    typename Enable = void>
struct DefaultScaleIteratorsMultistage;

// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
    std::enable_if_t<isFinegrained(QuantOp)>>
{
    using IteratorScale
        = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
            Layout, 0, Alignment>;

    using SmemIteratorScale = IteratorScale;
};

// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
    std::enable_if_t<!isFinegrained(QuantOp)>>
{
    // ThreadMap for scale iterator
    static_assert((MmaShape::kN % Alignment) == 0, "MmaShape::kN must be divisible by Alignment");

private:
    using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
        MmaShape::kN / Alignment, Alignment>;

public:
    // Define iterators over tiles from the scale operand
    using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
        Element, Layout, 0, IteratorScaleThreadMap, Alignment>;

    using SmemIteratorScale = IteratorScale;
};

////////////////////////////////////////////////////////////////////////////////

template <
    /// Type for element A
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Type for element B
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for the input scale
    typename ElementScale,
    /// Layout for the scale operand
    typename LayoutScale,
    /// Access granularity of Scales in unit of elements
    int kAlignmentScale,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Stages in GEMM
    int kStages,
    /// Operator performed by GEMM
    typename Operator_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
    kStages, Operator_, SharedMemoryClear,
    typename platform::enable_if<(
        ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{

    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
            || platform::is_same<ElementA, float_e4m3_t>::value,
        "Element A must be fp16, fp8 or bf16");

    using OperatorInfo = arch::DetagOperator<Operator_>;
    using Operator = typename OperatorInfo::Operator;
    static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
        "Mma multistage must dequantize after ldsm");

    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
        "Element B must be uint8 or uint4");

    static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    // Define the MmaCore components.
    // The MmaCore does not depend on the stage count, so pass in at least 3
    // here so that the multistage mainloop pieces are created.
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
        Operator, false, CacheOpA, CacheOpB>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
        AccessTypeA>;

    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
        AccessTypeB>;

    using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
        OperatorInfo::QuantOp, kAlignmentScale>;

    // Define iterators over tiles from the scale operand
    using IteratorScale = typename ScaleIterators::IteratorScale;

    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

    using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
        MmaCore::MmaPolicy::Operator::FragmentB::kElements>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
        typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
};

// Specialization to handle column major interleave B
template <
    /// Type for element A
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Type for element B
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for the input scale
    typename ElementScale,
    /// Layout for the scale operand
    typename LayoutScale,
    /// Access granularity of Scales in unit of elements
    int kAlignmentScale,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Stages in GEMM
    int kStages,
    /// Operator performed by GEMM
    typename Operator_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
    kStages, Operator_, SharedMemoryClear,
    typename platform::enable_if<(
        ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{

    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
            || platform::is_same<ElementA, float_e4m3_t>::value,
        "Element A must be fp16, fp8 or bf16");

    using OperatorInfo = arch::DetagOperator<Operator_>;
    using Operator = typename OperatorInfo::Operator;
    static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
        "Mma multistage must dequantize after ldsm");

    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
        "Element B must be uint8 or uint4");

    static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
        ? cutlass::arch::CacheOperation::Global
        : cutlass::arch::CacheOperation::Always;

    // Define the MmaCore components.
    // The MmaCore does not depend on the stage count, so pass in at least 3
    // here so that the multistage mainloop pieces are created.
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
        std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;

    // Define iterators over tiles from the A operand
    using ThreadMapA = typename MmaCore::IteratorThreadMapA;
    using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
    using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
        AccessTypeA>;

private:
    static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
    static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
    static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "kN must be divisible by ColumnsInterleaved");
    static_assert(RowsPerTile == MmaCore::Shape::kK, "RowsPerTile must equal MmaCore::Shape::kK");

    using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
    using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
    static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved),
        "kStrided must be divisible by ColumnsInterleaved");

    using GmemIteratorShape
        = MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
    using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
        layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
        layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
            OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
        MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
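    // Example (hypothetical shapes): ColumnsInterleaved = 4 with a 64 x 128
    // (K x N) threadblock tile of B remaps the global fetch to a 256 x 32
    // column-major block, so the four interleaved columns of a row are read
    // contiguously.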

public:
    // Define iterators over tiles from the B operand
    using ThreadMapB = typename MmaCore::IteratorThreadMapB;
    using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
    using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
        layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;

    using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
        OperatorInfo::QuantOp, kAlignmentScale>;

    // Define iterators over tiles from the scale operand
    using IteratorScale = typename ScaleIterators::IteratorScale;

    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

    using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
        MmaCore::MmaPolicy::Operator::FragmentB::kElements>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
        MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
        typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,284 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/arch/mma.h"

#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "cutlass_extensions/tile_interleaved_layout.h"

#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
    typename Enable = void>
struct DefaultScaleIteratorsPipelined;

// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
    std::enable_if_t<isFinegrained(QuantOp)>>
{
private:
    using SmemScaleType = half_t;

public:
    using IteratorScale
        = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
            Layout, 0, Alignment>;

    using SmemIteratorScale
        = cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
            SmemScaleType, Layout, 0, Alignment>;
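
    // Note: on this pre-SM80 (pipelined) path the scales are staged in shared
    // memory as half_t (SmemScaleType), independent of the global-memory
    // Element type.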
};

// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
    std::enable_if_t<!isFinegrained(QuantOp)>>
{
    static_assert((MmaShape::kN % Alignment) == 0, "MmaShape::kN must be divisible by Alignment");

private:
    // ThreadMap for scale iterator
    using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
        MmaShape::kN / Alignment, Alignment>;
    using SmemScaleType = half_t;

public:
    // Define iterators over tiles from the scale operand
    using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
        Element, Layout, 0, IteratorScaleThreadMap, Alignment>;

    using SmemIteratorScale
        = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType,
            Layout, 0, IteratorScaleThreadMap, Alignment>;
};

////////////////////////////////////////////////////////////////////////////////

template <
    /// Type for element A
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Type for element B
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for the input scale
    typename ElementScale,
    /// Layout for the scale operand
    typename LayoutScale,
    /// Access granularity of Scales in unit of elements
    int kAlignmentScale,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator_>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
    Operator_, SharedMemoryClearOption::kNone,
    typename platform::enable_if<(
        ArchTag::kMinComputeCapability < 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{

    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
        "Element A must be fp16 or bf16");

    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
        "Element B must be uint8 or uint4");

    using OperatorInfo = arch::DetagOperator<Operator_>;
    using Operator = typename OperatorInfo::Operator;
    static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY,
        "Pipelined (pre-SM80) DqMma only supports per-column scales");

    static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
    using MmaCoreElementA = half_t;
    using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;

    // Define the MmaCore components
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2,
        Operator>;

    // Define iterators over tiles from the A operand
    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
        typename MmaCore::IteratorThreadMapA, kAlignmentA>;

    // Define iterators over tiles from the B operand
    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
        typename MmaCore::IteratorThreadMapB, kAlignmentB>;

    using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
        OperatorInfo::QuantOp, kAlignmentScale>;

    // Define iterators over tiles from the scale operand
    using IteratorScale = typename ScaleIterators::IteratorScale;

    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

    using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
        ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
        typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
};

// Specialization to handle column major interleave B
template <
    /// Type for element A
    typename ElementA,
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Type for element B
    typename ElementB,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for the input scale
    typename ElementScale,
    /// Layout for the scale operand
    typename LayoutScale,
    /// Access granularity of Scales in unit of elements
    int kAlignmentScale,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Operator class tag
    typename OperatorClass,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator_>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
    ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
    Operator_, SharedMemoryClearOption::kNone,
    typename platform::enable_if<(
        ArchTag::kMinComputeCapability < 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{

    static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
        "Element A must be fp16 or bf16");

    static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
        "Element B must be uint8 or uint4");

    using OperatorInfo = arch::DetagOperator<Operator_>;
    using Operator = typename OperatorInfo::Operator;

    static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
    using MmaCoreElementA = half_t;
    using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;

    // Define the MmaCore components
    using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
        MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor,
        OperatorClass, 2, Operator>;

    // Define iterators over tiles from the A operand
    using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
        typename MmaCore::IteratorThreadMapA, kAlignmentA>;

private:
    static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
    static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
    static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "kN must be divisible by ColumnsInterleaved");
    static_assert(RowsPerTile == MmaCore::Shape::kK, "RowsPerTile must equal MmaCore::Shape::kK");

    using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
    using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
    static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved),
        "kStrided must be divisible by ColumnsInterleaved");

    using GmemIteratorShape
        = MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
    using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
        layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
        layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
            OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
        MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;

public:
    // Define iterators over tiles from the B operand
    using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<GmemIteratorShape, ElementB,
        layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;

    // ThreadMap for scale iterator
    static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "kN must be divisible by kAlignmentScale");
    using IteratorScaleThreadMap
        = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
            MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;

    using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
        OperatorInfo::QuantOp, kAlignmentScale>;

    // Define iterators over tiles from the scale operand
    using IteratorScale = typename ScaleIterators::IteratorScale;

    using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;

    using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;

    // Define the threadblock-scoped pipelined matrix multiply
    using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
        typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
        ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
        typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,351 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h"

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
template <
    /// Layout type for A matrix operand
    typename LayoutA,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA,
    /// Layout type for B matrix operand
    typename LayoutB,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB,
    /// Element type for internal accumulation
    typename ElementAccumulator,
    /// Tag indicating architecture to tune for
    typename ArchTag,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape,
    /// Operation performed by GEMM
    typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
    layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{

private:
    static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
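    // 128-bit accesses over 16-bit scale elements: 128 / 16 = 8 elements per access.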

using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of stages
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of stages
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;

using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of stages
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
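// Note: the dequant scales in this fp8 path are still half_t (see the DqMma below), so the
// scale alignment is computed from half_t rather than from float_e4m3_t.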

using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

#endif

// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stages. Helps avoid register spills on
// large tiles when not enough shared memory is present to do 3+ stages
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
SharedMemoryClear, GatherA, GatherB>
{

// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
half_t, LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;

// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, half_t, LayoutA, 1, ThreadMapA, AccessTypeA,
GatherA>;

// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, half_t, LayoutB, 0, ThreadMapB, AccessTypeB,
GatherB>;

// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
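// Note: MmaCore is built with 3 stages so the multistage (cp.async) components are selected,
// while the pipeline depth passed to MmaMultistage just above is 2, matching this specialization.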
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,353 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
SharedMemoryClear, GatherA, GatherB>
{

private:
// Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS.
static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80;
using MmaElementA = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
using MmaElementB = typename platform::conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
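// For a pre-SM80 ArchTag (kMinComputeCapability < 80) both mma element types collapse to half_t,
// so bf16 inputs are converted to fp16 before the shared-memory store, as the comment above notes.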

public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, MmaElementA,
LayoutA, MmaElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 2, Operator>;

using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, bfloat16_t, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>;

// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, bfloat16_t, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, ElementAccumulator,
layout::RowMajor, typename MmaCore::MmaPolicy>;
};

// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stages. Helps avoid register spills on
// large tiles when not enough shared memory is present to do 3+ stages
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t, LayoutA, kAlignmentA, bfloat16_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator,
false, SharedMemoryClear, GatherA, GatherB>
{

// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape, bfloat16_t,
LayoutA, bfloat16_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;

// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, bfloat16_t, LayoutA, 1, ThreadMapA,
AccessTypeA, GatherA>;

// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, bfloat16_t, LayoutB, 0, ThreadMapB,
AccessTypeB, GatherB>;

// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, 2>;
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of stages
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight, mma multistage (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of stages
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{

private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;

using Mma = DqMma<bfloat16_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, bfloat16_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;

public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;

// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;

// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;

// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -0,0 +1,257 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass_extensions/weight_only_quant_op.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////
// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
// correct warp level mma. On Volta, all data is stored to shared memory as FP16.
template <typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C,
int const warp_tileB_k_offset)
{
warp_mma(D, A, B, C);
}

template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B,
typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset)
{
warp_mma(D, A, B, C, warp_tileB_k_offset);
}
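// Dispatch sketch (a reading of the SFINAE above, not new behavior): the second overload only
// participates when WarpMma defines kExpansionFactor and transformed fragment types, as the
// dequantizing tensor-op mmas do; warp mmas without them, e.g. the Volta path, can only bind the
// first overload, which ignores warp_tileB_k_offset.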

////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// The type of the scales
typename ElementScale_,
/// Number of stages,
int Stages,
/// The dequantizing op to be performed.
WeightOnlyQuantOp DequantOp,
/// Used for partial specialization,
typename Enable = bool>
class DqMmaBase
{
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;

///< Policy describing tuning details
using Policy = Policy_;

///< Type of the scale to be loaded
using ElementScale = ElementScale_;

static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, "");

// Finegrained scales get streamed in via cp.async
static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1;
// We always have scales.
static constexpr int ScaleElementsPerStage = Shape::kN;
// We sometimes have a bias
static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0;
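// Illustrative sizing (hypothetical numbers): with a fine-grained quant op, Stages == 4 and
// Shape::kN == 128, the scale buffer spans 4 stages of 128 elements; zeros are only allocated
// when hasZero(DequantOp) holds.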

//
// Dependent types
//

/// Warp-level Mma
using Operator = typename Policy::Operator;

/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;

/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;

/// Number of warp-level GEMM operations
static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);

static constexpr int kNumKIterationsPerWarpBLoad
= Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;

static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
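// Worked example with hypothetical shapes: WarpGemm::kK == 64 and an instruction kK of 16 give
// kWarpGemmIterations == 4; if one B load covers two instruction steps
// (Operator::IteratorB::InstructionShape::kRow == 32), then kNumKIterationsPerWarpBLoad == 2
// and only kWarpGemmIterationsForB == 2 B loads are issued per warp tile.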

/// Number of stages
static int const kStages = Stages;

/// Tensor reference to the A operand
using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

/// Tensor reference to the B operand
using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

//
// Nested structs
//

/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage
{
public:
//
// Type definitions
//

/// Shape of the A matrix operand in shared memory
using ShapeA
= MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

/// Shape of the B matrix operand in shared memory
using ShapeB
= MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;

/// Shape of the shared memory buffer for the scales for the B matrix.
using ShapeScale = MatrixShape<ScalebiasStages, ScaleElementsPerStage>;
/// Shape of the shared memory buffer for the biases of the B matrix.
using ShapeZero = MatrixShape<ScalebiasStages, BiasElementsPerStage>;
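// Sizing sketch (illustrative numbers, ignoring padding): for Shape 128x128x64 with kStages == 3,
// ShapeA is 128 x 192 elements and ShapeB is 192 x 128 elements; the scale/zero buffers add
// ScalebiasStages x kN elements each on top of the operand tiles.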

public:
//
// Data members
//

/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

/// Buffer to hold scales for threadblock
AlignedBuffer<ElementScale, ShapeScale::kCount> operand_scale;

/// Buffer to hold zeros for threadblock
AlignedBuffer<ElementScale, ShapeZero::kCount> operand_zero;

public:
//
// Methods
//

/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA()
{
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}

/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB()
{
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}

/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref()
{
return TensorRefA{operand_A.data(), LayoutA()};
}

/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref()
{
return TensorRefB{operand_B.data(), LayoutB()};
}
};

protected:
//
// Data members
//

/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;

/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;

public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage& shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx)
, warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
{
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,110 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = void>
class DqMmaMultistage;

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h"
@@ -0,0 +1,753 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Iterators over scales in global memory
typename IteratorScale_,
/// Iterators over scales in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
SharedMemoryClear, std::enable_if_t<isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;

using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;

using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;

static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

using TransformBAfterLDS = TransformBAfterLDS_;

static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//

/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;

/// Warp-level Mma
using Operator = typename Policy::Operator;

/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;

using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
LayoutScale, 32, QuantOp>;

/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;

/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;

static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, "");
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");

/// Internal structure exposed for introspection.
struct Detail
{

static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");

/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;

/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;

/// Number of stages
static int const kStages = Stages;

/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA
= (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;

/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB
= (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
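// Both group counts use a ceiling division so that one stage's cp.async instructions are spread
// as evenly as possible across the kWarpGemmIterations math steps, overlapping copies with mma work.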
};

private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;

using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");

private:
//
// Data members
//

/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;

/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;

/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
SmemIteratorScale smem_iterator_scale_;

public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
/// The group size for quantization
int const group_size,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension

int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
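// Example mapping (assuming WarpCount == {2, 2, 1}): warp_idx 3 yields warp_idx_mn == 3,
// warp_idx_k == 0, warp_idx_m == 1 and warp_idx_n == 1.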

// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}

CUTLASS_DEVICE
void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1)
{
static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1.");

typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale();
typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero();

typename IteratorScale::AccessType* smem_scale_ptr
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_scale());
typename IteratorScale::AccessType* smem_zero_ptr
= reinterpret_cast<typename IteratorScale::AccessType*>(this->smem_iterator_scale_.get_zero());

int const kSrcBytes = sizeof_bits<typename IteratorScale::Element>::value * IteratorScale::kAlignment / 8;

cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid());

if (gmem_zero_ptr != nullptr)
{
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid());
}

if (iterator_scale.group_size_ == 64)
{
iterator_scale.add_tile_offset({1, 0});
}
else if (iterator_scale.group_size_ == 128)
{
if constexpr (Shape::kK == 128)
{
iterator_scale.add_tile_offset({1, 0});
}
else if constexpr (Shape::kK == 64)
{
if (iterator_scale.row_groupsize64_ & 0x1)
{
iterator_scale.add_tile_offset({1, 0});
}
}
else
{
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
}
}
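// Net cadence of the branches above: with group_size 64 the scale row advances on every k tile;
// with group_size 128 it advances every tile when Shape::kK == 128, and only after every second
// tile (odd row_groupsize64_) when Shape::kK == 64.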

iterator_scale.row_groupsize64_++;

this->smem_iterator_scale_.add_tile_offset({1, 0});
}

CUTLASS_DEVICE
void copy_tiles_and_advance(
IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
{
iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);

// Async Copy for operand A
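// Each call copies one "group" of the current stage: up to Detail::kAccessesPerGroupA iterations
// starting at group_start_A, with the bounds check below skipping accesses past the stage's end.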
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
|
||||
{
|
||||
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
|
||||
{
|
||||
typename IteratorA::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
|
||||
* IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_A.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
|
||||
}
|
||||
|
||||
++iterator_A;
|
||||
}
|
||||
|
||||
++this->smem_iterator_A_;
|
||||
}
|
||||
}
|
||||
|
||||
iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
|
||||
this->smem_iterator_B_.set_iteration_index(group_start_B);
|
||||
|
||||
// Async Copy for operand B
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
|
||||
{
|
||||
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
|
||||
{
|
||||
typename IteratorB::AccessType* dst_ptr
|
||||
= reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());
|
||||
|
||||
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
|
||||
* IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
|
||||
{
|
||||
auto gmem_ptr = iterator_B.get();
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
|
||||
{
|
||||
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
else
|
||||
{
|
||||
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
|
||||
}
|
||||
|
||||
++iterator_B;
|
||||
}
|
||||
++this->smem_iterator_B_;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
///< problem size of GEMM
|
||||
int gemm_k_iterations,
|
||||
///< destination accumulator tile
|
||||
FragmentC& accum,
|
||||
///< iterator over A operand in global memory
|
||||
IteratorA iterator_A,
|
||||
///< iterator over B operand in global memory
|
||||
IteratorB iterator_B,
|
||||
///< iterator over scale operand in global memory
|
||||
IteratorScale iterator_scale,
|
||||
///< initial value of accumulator
|
||||
FragmentC const& src_accum)
|
||||
{
|
||||
|
||||
//
|
||||
// Prologue
|
||||
//
|
||||
|
||||
TransformBAfterLDS lds_converter;
|
||||
|
||||
// Issue several complete stages
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
|
||||
{
|
||||
|
||||
iterator_A.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_B.clear_mask(gemm_k_iterations == 0);
|
||||
iterator_scale.clear_mask(gemm_k_iterations == 0);
|
||||
|
||||
iterator_A.set_iteration_index(0);
|
||||
this->smem_iterator_A_.set_iteration_index(0);
|
||||
|
||||
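            // The prologue always uses the zfill variant so that predicated-off (out-of-bounds)
            // accesses write zeros to shared memory instead of leaving stale data.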
            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
                {
                    int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
                        * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                        dst_ptr + v, iterator_A.get(), iterator_A.valid());

                    ++iterator_A;
                }

                ++this->smem_iterator_A_;
            }

            iterator_B.set_iteration_index(0);
            this->smem_iterator_B_.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
                {
                    int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
                        * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                        dst_ptr + v, iterator_B.get(), iterator_B.valid());

                    ++iterator_B;
                }

                ++this->smem_iterator_B_;
            }

            copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations);

            // Move to the next stage
            iterator_A.add_tile_offset({0, 1});
            iterator_B.add_tile_offset({1, 0});

            this->smem_iterator_A_.add_tile_offset({0, 1});
            this->smem_iterator_B_.add_tile_offset({1, 0});

            // Defines the boundary of a stage of cp.async.
            cutlass::arch::cp_async_fence();
        }

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        //
        // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
        // so that all accumulator elements outside the GEMM footprint are zero.
        //

        if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
        {
            /// Iterator to write threadblock-scoped tile of A operand to shared memory
            SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);

            typename IteratorA::AccessType zero_A;
            zero_A.clear();

            last_smem_iterator_A.set_iteration_index(0);

            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());

                *dst_ptr = zero_A;

                ++last_smem_iterator_A;
            }

            /// Iterator to write threadblock-scoped tile of B operand to shared memory
            SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
            typename IteratorB::AccessType zero_B;

            zero_B.clear();
            last_smem_iterator_B.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());

                *dst_ptr = zero_B;

                ++last_smem_iterator_B;
            }
        }

        // Wait until we have at least one committed global fetch stage.
        // (#uncommitted = Base::kStages - 1 - #committed)
        cutlass::arch::cp_async_wait<Base::kStages - 2>();
        __syncthreads();

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];
        typename Dequantizer::FragmentScale warp_frag_scales;
        typename Dequantizer::FragmentZero warp_frag_zeros;

        Operator warp_mma;

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;
        warp_dequantizer_.add_pointer_offset(Shape::kN);

        iterator_A.clear_mask(gemm_k_iterations == 0);
        iterator_B.clear_mask(gemm_k_iterations == 0);
        iterator_scale.clear_mask(gemm_k_iterations == 0);

        int smem_write_stage_idx = Base::kStages - 1;
        int smem_read_stage_idx = 0;

        //
        // Mainloop
        //
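        // The induction variable runs below zero: once global fetches are exhausted (and the
        // iterator masks cleared), the final (kStages - 1) iterations drain the stages that
        // are already resident in shared memory.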
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > (-Base::kStages + 1);)
        {
            //
            // Loop over GEMM K dimension
            //

            // Computes a warp-level GEMM on data held in shared memory
            // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
            {
                // Load warp-level tiles from shared memory, wrapping to k offset if
                // this is the last group as the case may be.

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
                {
                    this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                typename TransformBAfterLDS::result_type converted_frag_B
                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);
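                // The dequantized fragment is still in the scale element type; convert it
                // element-wise (round-to-nearest) into the A element type expected by the
                // warp-level MMA.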
                using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
                constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
                constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
                static_assert(ConversionVectorWidth == FragmentOperandB::kElements);

                using Converter
                    = cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;

                FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);

                run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
                    warp_tileB_k_compute_offset);

                // Issue global->shared copies for this stage
                if (warp_mma_k < Base::kWarpGemmIterations - 1)
                {
                    int group_start_iteration_A, group_start_iteration_B;

                    group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);

                    // This is the first group of a given stage, so we issue the loads for the B scales immediately.
                    if (group_start_iteration_B == 0)
                    {
                        copy_scales_and_advance(iterator_scale);
                    }
                }

                if (warp_mma_k + 2 == Base::kWarpGemmIterations)
                {
                    int group_start_iteration_A, group_start_iteration_B;
                    group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);

                    // Inserts a memory fence between stages of cp.async instructions.
                    cutlass::arch::cp_async_fence();

                    // Wait until we have at least one committed global fetch stage.
                    // (#uncommitted = Base::kStages - 1 - #committed)
                    arch::cp_async_wait<Base::kStages - 2>();
                    __syncthreads();

                    // Move to the next stage
                    iterator_A.add_tile_offset({0, 1});
                    iterator_B.add_tile_offset({1, 0});

                    this->smem_iterator_A_.add_tile_offset({0, 1});
                    this->smem_iterator_B_.add_tile_offset({1, 0});

                    // Add negative offsets to return iterators to the 'start' of the
                    // circular buffer in shared memory
                    if (smem_write_stage_idx == (Base::kStages - 1))
                    {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                        this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
                        smem_write_stage_idx = 0;
                    }
                    else
                    {
                        ++smem_write_stage_idx;
                    }
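                    // The read-side wrap mirrors the write side with a lag: reads begin at
                    // stage 0 while writes begin at stage (kStages - 1), so loads always trail
                    // the in-flight cp.async stores.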
                    if (smem_read_stage_idx == (Base::kStages - 1))
                    {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                        warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
                        smem_read_stage_idx = 0;
                    }
                    else
                    {
                        ++smem_read_stage_idx;
                    }

                    --gemm_k_iterations;
                    iterator_A.clear_mask(gemm_k_iterations == 0);
                    iterator_B.clear_mask(gemm_k_iterations == 0);
                    iterator_scale.clear_mask(gemm_k_iterations == 0);
                }
            }

            // Load the scale needed for the next tile iteration.
            warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros);
            // Update internal pointer to set of scales in shared memory.
            warp_dequantizer_.add_pointer_offset(Shape::kN);
        }

        if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
        {
            // Commit and drain all pending and predicated cp.async (LDGSTS) instructions from the GEMM mainloop.
            cutlass::arch::cp_async_fence();
            cutlass::arch::cp_async_wait<0>();
            __syncthreads();
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,647 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Iterators over scales in global memory
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear>
class DqMmaMultistage<Shape_, IteratorA_, SmemIteratorA_, CacheOpA, IteratorB_, SmemIteratorB_, CacheOpB,
    IteratorScale_, SmemIteratorScale_, ElementC_, LayoutC_, Policy_, Stages, TransformBAfterLDS_, QuantOp_,
    SharedMemoryClear, std::enable_if_t<!isFinegrained(QuantOp_)>>
    : public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>
{
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages, QuantOp_>;
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;
    ///< Iterates over tiles of A operand in global memory
    using IteratorA = IteratorA_;
    ///< Iterates over tiles of B operand in global memory
    using IteratorB = IteratorB_;
    ///< Data type of accumulator matrix
    using ElementC = ElementC_;
    ///< Layout of accumulator matrix
    using LayoutC = LayoutC_;
    ///< Policy describing tuning details
    using Policy = Policy_;

    using IteratorScale = IteratorScale_;
    using ElementScale = typename IteratorScale::Element;
    using LayoutScale = typename IteratorScale::Layout;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
    static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

    using TransformBAfterLDS = TransformBAfterLDS_;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

    //
    // Dependent types
    //

    /// Fragment of operand Scale loaded from global memory
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Minimum architecture is Sm80 to support cp.async
    using ArchTag = arch::Sm80;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale,
        LayoutScale, 32, QuantOp>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    /// Internal structure exposed for introspection.
    struct Detail
    {
        static_assert(Base::kWarpGemmIterations > 1,
            "The pipelined structure requires at least two warp-level "
            "GEMM operations.");

        /// Number of cp.async instructions to load one stage of operand A
        static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;

        /// Number of cp.async instructions to load one stage of operand B
        static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;

        /// Number of stages
        static int const kStages = Stages;

        /// Number of cp.async instructions to load one group of operand A
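        // (Rounded-up division: the stage's copies are spread as evenly as possible across
        // the kWarpGemmIterations groups; a trailing partial group is guarded by the bounds
        // check in copy_tiles_and_advance.)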
        static int const kAccessesPerGroupA
            = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;

        /// Number of cp.async instructions to load one group of operand B
        static int const kAccessesPerGroupB
            = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
    };

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;
    Dequantizer warp_dequantizer_;

    using ElementA = typename IteratorA::Element;
    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave
        = layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
        "Layout K must match threadblockK");

private:
    //
    // Data members
    //

    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaMultistage(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        typename Base::SharedStorage& shared_storage,
        ///< Group size for quantization. Not used by this mainloop, since it assumes per-column quantization.
        int const group_size,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx)
        : Base(shared_storage, thread_idx, warp_idx, lane_idx)
        , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
              (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
        , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
        , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
        , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {
        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    CUTLASS_DEVICE
    void copy_tiles_and_advance(
        IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
    {
        iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
        this->smem_iterator_A_.set_iteration_index(group_start_A);

        // Async Copy for operand A
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupA; ++j)
        {
            if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());

                int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
                    * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
                {
                    auto gmem_ptr = iterator_A.get();

                    if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
                    {
                        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
                    }
                    else
                    {
                        cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(dst_ptr + v, gmem_ptr, iterator_A.valid());
                    }

                    ++iterator_A;
                }

                ++this->smem_iterator_A_;
            }
        }

        iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
        this->smem_iterator_B_.set_iteration_index(group_start_B);

        // Async Copy for operand B
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupB; ++j)
        {
            if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());

                int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
                    * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
                {
                    auto gmem_ptr = iterator_B.get();

                    if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
                    {
                        cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
                    }
                    else
                    {
                        cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(dst_ptr + v, gmem_ptr, iterator_B.valid());
                    }

                    ++iterator_B;
                }
                ++this->smem_iterator_B_;
            }
        }
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(
        ///< problem size of GEMM
        int gemm_k_iterations,
        ///< destination accumulator tile
        FragmentC& accum,
        ///< iterator over A operand in global memory
        IteratorA iterator_A,
        ///< iterator over B operand in global memory
        IteratorB iterator_B,
        ///< iterator over scale operand in global memory
        IteratorScale iterator_scale,
        ///< initial value of accumulator
        FragmentC const& src_accum)
    {
        //
        // Prologue
        //

        TransformBAfterLDS lds_converter;

        // NOTE - switch to ldg.sts
        // Issue this first, so cp.async.commit_group will commit this load as well.
        // Note: we do not commit here and this load will commit in the same group as
        // the first load of A.
        FragmentScale tb_frag_scales;
        tb_frag_scales.clear();
        iterator_scale.load(tb_frag_scales);
        this->smem_iterator_scale_.store(tb_frag_scales);

        // Issue several complete stages
        CUTLASS_PRAGMA_UNROLL
        for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations)
        {
            iterator_A.clear_mask(gemm_k_iterations == 0);
            iterator_B.clear_mask(gemm_k_iterations == 0);

            iterator_A.set_iteration_index(0);
            this->smem_iterator_A_.set_iteration_index(0);

            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v)
                {
                    int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value
                        * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;

                    int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                        dst_ptr + v, iterator_A.get(), iterator_A.valid());

                    ++iterator_A;
                }

                ++this->smem_iterator_A_;
            }

            iterator_B.set_iteration_index(0);
            this->smem_iterator_B_.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v)
                {
                    int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value
                        * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                        dst_ptr + v, iterator_B.get(), iterator_B.valid());

                    ++iterator_B;
                }

                ++this->smem_iterator_B_;
            }

            // Move to the next stage
            iterator_A.add_tile_offset({0, 1});
            iterator_B.add_tile_offset({1, 0});

            this->smem_iterator_A_.add_tile_offset({0, 1});
            this->smem_iterator_B_.add_tile_offset({1, 0});

            // Defines the boundary of a stage of cp.async.
            cutlass::arch::cp_async_fence();
        }

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        //
        // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels
        // so that all accumulator elements outside the GEMM footprint are zero.
        //

        if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage)
        {
            /// Iterator to write threadblock-scoped tile of A operand to shared memory
            SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);

            typename IteratorA::AccessType zero_A;
            zero_A.clear();

            last_smem_iterator_A.set_iteration_index(0);

            // Async Copy for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j)
            {
                typename IteratorA::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorA::AccessType*>(last_smem_iterator_A.get());

                *dst_ptr = zero_A;

                ++last_smem_iterator_A;
            }

            /// Iterator to write threadblock-scoped tile of B operand to shared memory
            SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
            typename IteratorB::AccessType zero_B;

            zero_B.clear();
            last_smem_iterator_B.set_iteration_index(0);

            // Async Copy for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j)
            {
                typename IteratorB::AccessType* dst_ptr
                    = reinterpret_cast<typename IteratorB::AccessType*>(last_smem_iterator_B.get());

                *dst_ptr = zero_B;

                ++last_smem_iterator_B;
            }
        }

        // Wait until we have at least one committed global fetch stage.
        // (#uncommitted = Base::kStages - 1 - #committed)
        cutlass::arch::cp_async_wait<Base::kStages - 2>();
        __syncthreads();

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];
        typename Dequantizer::FragmentScale warp_frag_scales;

        Operator warp_mma;

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);
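        // Per-column quantization: a single fragment of scales covers the full K extent of
        // the threadblock tile, so it is loaded once here rather than once per stage.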
        warp_dequantizer_.load(warp_frag_scales);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        iterator_A.clear_mask(gemm_k_iterations == 0);
        iterator_B.clear_mask(gemm_k_iterations == 0);

        int smem_write_stage_idx = Base::kStages - 1;
        int smem_read_stage_idx = 0;

        //
        // Mainloop
        //

        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > (-Base::kStages + 1);)
        {
            //
            // Loop over GEMM K dimension
            //

            // Computes a warp-level GEMM on data held in shared memory
            // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
            {
                // Load warp-level tiles from shared memory, wrapping to k offset if
                // this is the last group as the case may be.

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
                {
                    this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                typename TransformBAfterLDS::result_type converted_frag_B
                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);

                using FragmentOperandB = cutlass::Array<ElementA, Operator::FragmentB::kElements>;
                constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
                constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
                static_assert(ConversionVectorWidth == FragmentOperandB::kElements);

                using Converter
                    = cutlass::NumericArrayConverter<ElementA, ElementScale, ConversionVectorWidth, RoundStyle>;

                FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
                run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
                    warp_tileB_k_compute_offset);

                // Issue global->shared copies for this stage
                if (warp_mma_k < Base::kWarpGemmIterations - 1)
                {
                    int group_start_iteration_A, group_start_iteration_B;

                    group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
                }
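                // Second-to-last warp MMA of the k-group: issue the stage's remaining copies,
                // close the cp.async stage, and rotate the shared-memory ring buffers.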
                if (warp_mma_k + 2 == Base::kWarpGemmIterations)
                {
                    int group_start_iteration_A, group_start_iteration_B;
                    group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
                    group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;

                    copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);

                    // Inserts a memory fence between stages of cp.async instructions.
                    cutlass::arch::cp_async_fence();

                    // Wait until we have at least one committed global fetch stage.
                    // (#uncommitted = Base::kStages - 1 - #committed)
                    arch::cp_async_wait<Base::kStages - 2>();
                    __syncthreads();

                    // Move to the next stage
                    iterator_A.add_tile_offset({0, 1});
                    iterator_B.add_tile_offset({1, 0});

                    this->smem_iterator_A_.add_tile_offset({0, 1});
                    this->smem_iterator_B_.add_tile_offset({1, 0});

                    // Add negative offsets to return iterators to the 'start' of the
                    // circular buffer in shared memory
                    if (smem_write_stage_idx == (Base::kStages - 1))
                    {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                        smem_write_stage_idx = 0;
                    }
                    else
                    {
                        ++smem_write_stage_idx;
                    }

                    if (smem_read_stage_idx == (Base::kStages - 1))
                    {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                        smem_read_stage_idx = 0;
                    }
                    else
                    {
                        ++smem_read_stage_idx;
                    }

                    --gemm_k_iterations;
                    iterator_A.clear_mask(gemm_k_iterations == 0);
                    iterator_B.clear_mask(gemm_k_iterations == 0);
                }
            }
        }

        if (SharedMemoryClear == SharedMemoryClearOption::kZfill)
        {
            // Commit and drain all pending and predicated cp.async (LDGSTS) instructions from the GEMM mainloop.
            cutlass::arch::cp_async_fence();
            cutlass::arch::cp_async_wait<0>();
            __syncthreads();
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -0,0 +1,106 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"

#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Iterators over scales in global memory
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Converter for B matrix applied immediately after the LDG (before STS)
    typename TransformBAfterLDG_,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_,
    /// Used for partial specialization
    typename Enable = void>
class DqMmaPipelined;

} // namespace threadblock
} // namespace gemm
} // namespace cutlass
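// The Enable parameter selects between the two partial specializations defined in the
// headers below, dispatched on isFinegrained(QuantOp_).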
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"
|
||||
@@ -0,0 +1,486 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/aligned_buffer.h"
|
||||
#include "cutlass/array.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
|
||||
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
|
||||
#include "cutlass_extensions/interleaved_numeric_conversion.h"
|
||||
|
||||
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass
|
||||
{
|
||||
namespace gemm
|
||||
{
|
||||
namespace threadblock
|
||||
{
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
/// Iterates over tiles of A operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorA_,
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
|
||||
typename IteratorB_,
|
||||
/// Iterates over tiles of B operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorB_,
|
||||
/// Iterators over scales in global memory
|
||||
typename IteratorScale_,
|
||||
/// Iterators over scales in shared memory
|
||||
typename SmemIteratorScale_,
|
||||
/// Data type of accumulator matrix
|
||||
typename ElementC_,
|
||||
/// Layout of accumulator matrix
|
||||
typename LayoutC_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Converter for B matrix applied immediately after the LDG (before STS)
|
||||
typename TransformBAfterLDG_,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
|
||||
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
|
||||
std::enable_if_t<isFinegrained(QuantOp_)>>
|
||||
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
|
||||
{
|
||||
public:
|
||||
///< Base class
|
||||
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
|
||||
|
||||
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
|
||||
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
|
||||
using ElementC = ElementC_; ///< Data type of accumulator matrix
|
||||
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
|
||||
using Policy = Policy_; ///< Policy describing tuning details
|
||||
|
||||
using IteratorScale = IteratorScale_;
|
||||
using ElementScale = typename IteratorScale::Element;
|
||||
using LayoutScale = typename IteratorScale::Layout;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
using SmemIteratorB = SmemIteratorB_;
|
||||
using SmemIteratorScale = SmemIteratorScale_;
|
||||
|
||||
using TransformBAfterLDG = TransformBAfterLDG_;
|
||||
using TransformBAfterLDS = TransformBAfterLDS_;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
//
|
||||
// Dependent types
|
||||
//
|
||||
|
||||
/// Fragment of operand A loaded from global memory
|
||||
using FragmentA = typename IteratorA::Fragment;
|
||||
|
||||
/// Fragment of operand B loaded from global memory
|
||||
using FragmentB = typename IteratorB::Fragment;
|
||||
|
||||
/// Fragment of operand Scale loaded from global memory;
|
||||
using FragmentScale = typename IteratorScale::Fragment;
|
||||
|
||||
/// Fragment of accumulator tile
|
||||
using FragmentC = typename Policy::Operator::FragmentC;
|
||||
|
||||
/// Warp-level Mma
|
||||
using Operator = typename Policy::Operator;
|
||||
|
||||
    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
        typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // Statically assert that kStages for DqMmaPipelined is two (double-buffered pipeline)
    static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");

    static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
    static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;
    Dequantizer warp_dequantizer_;

    using WarpFragmentScale = typename Dequantizer::FragmentScale;
    using WarpFragmentZero = typename Dequantizer::FragmentZero;

    using ElementA = typename IteratorA::Element;
    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave
        = layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
        "Layout K must match threadblockK");

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaPipelined(typename Base::SharedStorage&
            shared_storage,   ///< Shared storage needed for internal use by threadblock-scoped GEMM
        int const group_size, ///< The group size for quantization
        int thread_idx,       ///< ID within the threadblock
        int warp_idx,         ///< ID of warp
        int lane_idx          ///< ID of each thread within a warp
        )
        : Base(shared_storage, thread_idx, warp_idx, lane_idx)
        , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
              {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
              (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
        , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
        , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
        , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
              shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
    {

        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension
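        // Worked example (hypothetical WarpCount of {kM = 2, kN = 2, kK = 2}):
        // warp_idx 5 -> warp_idx_mn = 5 % 4 = 1, warp_idx_k = 5 / 4 = 1,
        // warp_idx_m = 1 % 2 = 1, warp_idx_n = 1 / 2 = 0.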

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    CUTLASS_DEVICE
    void copy_scales_and_advance(IteratorScale& iterator_scale)
    {
        using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element,
            typename FragmentScale::Element, FragmentScale::kElements>;

        FragmentScale tb_frag_scales;
        FragmentScale tb_frag_zeros;
        tb_frag_scales.clear();
        tb_frag_zeros.clear();

        TransformScale transformScale;

        using FragmentElement = typename FragmentScale::Element;

        auto gmem_scale_ptr = iterator_scale.get_scale();
        auto gmem_zero_ptr = iterator_scale.get_zero();

        arch::global_load<FragmentScale, sizeof(FragmentScale)>(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());

        if (gmem_zero_ptr != nullptr)
        {
            arch::global_load<FragmentScale, sizeof(FragmentScale)>(
                tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
        }

        typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales);
        typename TransformScale::result_type tb_frag_zeros_fp16;
        if (gmem_zero_ptr != nullptr)
            tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);

        auto frag_scale_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_scales_fp16);
        auto frag_zero_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_zeros_fp16);
        auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
        auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();

        if (iterator_scale.valid())
        {
            auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
            arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_scale_ptr_fp16);

            if (gmem_zero_ptr != nullptr)
            {
                smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
                arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_zero_ptr_fp16);
            }
        }
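
        // Advance cadence of the global scale iterator (derived from the logic
        // below): with group_size == 64 every k-tile consumes a fresh scale row,
        // while with group_size == 128 and Shape::kK == 64 a row spans two
        // k-tiles, so the iterator advances only on every other call, tracked by
        // row_groupsize64_.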
        if (iterator_scale.group_size_ == 64)
        {
            iterator_scale.add_tile_offset({1, 0});
        }
        else if (iterator_scale.group_size_ == 128)
        {
            if constexpr (Shape::kK == 128)
            {
                iterator_scale.add_tile_offset({1, 0});
            }
            else if constexpr (Shape::kK == 64)
            {
                if (iterator_scale.row_groupsize64_ & 0x1)
                {
                    iterator_scale.add_tile_offset({1, 0});
                }
            }
            else
            {
                static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
            }
        }

        iterator_scale.row_groupsize64_++;

        this->smem_iterator_scale_.add_tile_offset({1, 0});
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
        FragmentC& accum,                  ///< destination accumulator tile
        IteratorA iterator_A,              ///< iterator over A operand in global memory
        IteratorB iterator_B,              ///< iterator over B operand in global memory
        IteratorScale iterator_scale,      ///< iterator over scale operand in global memory
        FragmentC const& src_accum)        ///< source accumulator tile
    {

        //
        // Prologue
        //
        TransformBAfterLDG ldg_converter;
        TransformBAfterLDS lds_converter;

        using TransformA
            = NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;

        // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
        // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
        TransformA transformA;

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        FragmentA tb_frag_A;
        FragmentB tb_frag_B;

        tb_frag_A.clear();
        tb_frag_B.clear();

        // The last kblock is loaded in the prologue
        iterator_A.load(tb_frag_A);
        iterator_B.load(tb_frag_B);

        ++iterator_A;
        ++iterator_B;

        this->smem_iterator_A_.store(transformA(tb_frag_A));
        this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

        ++this->smem_iterator_A_;
        ++this->smem_iterator_B_;

        copy_scales_and_advance(iterator_scale);

        __syncthreads();

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];
        WarpFragmentScale warp_frag_scales;
        WarpFragmentZero warp_frag_zero;

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;
        warp_dequantizer_.add_pointer_offset(Shape::kN);

        Operator warp_mma;

        int smem_write_stage_idx = 1;

        // Avoid reading out of bounds
        iterator_A.clear_mask(gemm_k_iterations <= 1);
        iterator_B.clear_mask(gemm_k_iterations <= 1);
        iterator_scale.clear_mask(gemm_k_iterations <= 1);
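        // Note: clear_mask(true) disables the iterator's global loads. When
        // gemm_k_iterations <= 1 the prologue above already fetched the final
        // k-tile, so any further global loads would read out of bounds.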

        // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
        // shared memory loads (which have the tightest latency requirement).

        //
        // Mainloop
        //

        // Note: The main loop does not support Base::kWarpGemmIterations == 2.
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations)
        {
            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
            {

                // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
                // as the case may be.

                if (warp_mma_k == Base::kWarpGemmIterations - 1)
                {

                    // Write fragments to shared memory
                    this->smem_iterator_A_.store(transformA(tb_frag_A));

                    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

                    __syncthreads();

                    ++this->smem_iterator_A_;
                    ++this->smem_iterator_B_;

                    // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
                    if (smem_write_stage_idx == 1)
                    {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                        this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
                    }
                    else
                    {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                        warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
                    }

                    smem_write_stage_idx ^= 1;
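                    // With kStages == 2 the write stage simply alternates 1 -> 0 -> 1 ...,
                    // so one shared memory buffer is written while the other is read.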
                }

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
                {
                    this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                if (warp_mma_k == 0)
                {

                    iterator_A.load(tb_frag_A);
                    iterator_B.load(tb_frag_B);

                    ++iterator_A;
                    ++iterator_B;

                    copy_scales_and_advance(iterator_scale);

                    // Avoid reading out of bounds if this was the last loop iteration
                    iterator_A.clear_mask(gemm_k_iterations <= 2);
                    iterator_B.clear_mask(gemm_k_iterations <= 2);
                    iterator_scale.clear_mask(gemm_k_iterations <= 2);
                }

                typename TransformBAfterLDS::result_type converted_frag_B
                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero);
                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
            }

            // Load the scales needed for the next tile iteration
            warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
            // Update internal pointer to the set of scales in shared memory
            warp_dequantizer_.add_pointer_offset(Shape::kN);
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,399 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/

#pragma once

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"

#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
#include "cutlass_extensions/gemm_configs.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace threadblock
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Iterators over scales in global memory
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Converter for B matrix applied immediately after the LDG (before STS)
    typename TransformBAfterLDG_,
    /// Converter for B matrix applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// The quantization operator being used
    WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
    ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
    std::enable_if_t<!isFinegrained(QuantOp_)>>
    : public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;

    using Shape = Shape_;         ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
    using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
    using ElementC = ElementC_;   ///< Data type of accumulator matrix
    using LayoutC = LayoutC_;     ///< Layout of accumulator matrix
    using Policy = Policy_;       ///< Policy describing tuning details

    using IteratorScale = IteratorScale_;
    using ElementScale = typename IteratorScale::Element;
    using LayoutScale = typename IteratorScale::Layout;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    using TransformBAfterLDG = TransformBAfterLDG_;
    using TransformBAfterLDS = TransformBAfterLDS_;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

    //
    // Dependent types
    //

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of operand Scale loaded from global memory
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
        typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // Statically assert that kStages for DqMmaPipelined is two (double-buffered pipeline)
    static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;
    Dequantizer warp_dequantizer_;

    using ElementA = typename IteratorA::Element;
    using ElementB = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave
        = layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
        "Layout K must match threadblockK");

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    DqMmaPipelined(typename Base::SharedStorage&
            shared_storage,   ///< Shared storage needed for internal use by threadblock-scoped GEMM
        int const group_size, ///< Unused here; present only so the constructor signature matches the
                              ///< fine-grained specialization. This DqMmaPipelined is only enabled for
                              ///< sm < 80, so the argument has no effect on sm >= 80 builds.
        int thread_idx,       ///< ID within the threadblock
        int warp_idx,         ///< ID of warp
        int lane_idx          ///< ID of each thread within a warp
        )
        : Base(shared_storage, thread_idx, warp_idx, lane_idx)
        , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
              (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
        , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
        , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
        , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {

        // Compute warp location within threadblock tile by mapping the warp_id to
        // three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
    }

    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
        FragmentC& accum,                  ///< destination accumulator tile
        IteratorA iterator_A,              ///< iterator over A operand in global memory
        IteratorB iterator_B,              ///< iterator over B operand in global memory
        IteratorScale iterator_scale,      ///< iterator over scale operand in global memory
        FragmentC const& src_accum)        ///< source accumulator tile
    {

        //
        // Prologue
        //
        TransformBAfterLDG ldg_converter;
        TransformBAfterLDS lds_converter;

        using TransformA
            = NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;

        using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
            typename FragmentScale::Element, FragmentScale::kElements>;

        // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
        // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
        TransformA transformA;
        TransformScale transformScale;

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        FragmentA tb_frag_A;
        FragmentB tb_frag_B;
        FragmentScale tb_frag_scales;

        using WarpFragmentScale = typename Dequantizer::FragmentScale;
        WarpFragmentScale warp_frag_scales;

        tb_frag_A.clear();
        tb_frag_B.clear();
        tb_frag_scales.clear();

        // The last kblock is loaded in the prologue
        iterator_A.load(tb_frag_A);
        iterator_B.load(tb_frag_B);
        iterator_scale.load(tb_frag_scales);

        ++iterator_A;
        ++iterator_B;

        this->smem_iterator_A_.store(transformA(tb_frag_A));
        this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
        this->smem_iterator_scale_.store(transformScale(tb_frag_scales));

        ++this->smem_iterator_A_;
        ++this->smem_iterator_B_;

        __syncthreads();

        warp_dequantizer_.load(warp_frag_scales);
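        // Unlike the fine-grained specialization above, per-channel scales are
        // loaded exactly once here (in the prologue) and reused for every k-tile
        // of the mainloop; no scale traffic occurs inside the loop.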

        // Pair of fragments used to overlap shared memory loads and math instructions
        WarpFragmentA warp_frag_A[2];
        WarpFragmentB warp_frag_B[2];

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_A[0]);
        this->warp_tile_iterator_B_.load(warp_frag_B[0]);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        Operator warp_mma;

        int smem_write_stage_idx = 1;

        // Avoid reading out of bounds
        iterator_A.clear_mask(gemm_k_iterations <= 1);
        iterator_B.clear_mask(gemm_k_iterations <= 1);

        // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
        // shared memory loads (which have the tightest latency requirement).

        //
        // Mainloop
        //

        // Note: The main loop does not support Base::kWarpGemmIterations == 2.
        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > 0; --gemm_k_iterations)
        {
            //
            // Loop over GEMM K dimension
            //

            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
            {

                // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
                // as the case may be.

                if (warp_mma_k == Base::kWarpGemmIterations - 1)
                {

                    // Write fragments to shared memory
                    this->smem_iterator_A_.store(transformA(tb_frag_A));

                    this->smem_iterator_B_.store(ldg_converter(tb_frag_B));

                    __syncthreads();

                    ++this->smem_iterator_A_;
                    ++this->smem_iterator_B_;

                    // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
                    if (smem_write_stage_idx == 1)
                    {
                        this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
                    }
                    else
                    {
                        this->warp_tile_iterator_A_.add_tile_offset(
                            {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
                        this->warp_tile_iterator_B_.add_tile_offset(
                            {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
                    }

                    smem_write_stage_idx ^= 1;
                }

                this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
                ++this->warp_tile_iterator_A_;

                int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
                int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
                // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
                if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
                {
                    this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
                    this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
                    ++this->warp_tile_iterator_B_;
                }

                if (warp_mma_k == 0)
                {

                    iterator_A.load(tb_frag_A);
                    iterator_B.load(tb_frag_B);

                    ++iterator_A;
                    ++iterator_B;

                    // Avoid reading out of bounds if this was the last loop iteration
                    iterator_A.clear_mask(gemm_k_iterations <= 2);
                    iterator_B.clear_mask(gemm_k_iterations <= 2);
                }

                typename TransformBAfterLDS::result_type converted_frag_B
                    = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
                warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
                run_warp_mma(
                    warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
            }
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,107 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"

#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"

namespace cutlass
{
namespace gemm
{
namespace warp
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for m-by-n-by-kgroup
template <
    /// Shape of one matrix product operation (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix product operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A elements
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Data type of B elements
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Number of partitions along K dimension
    int PartitionsK,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<WarpShape_, InstructionShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
    arch::OpMultiplyAddDequantizeInterleavedBToA, PartitionsK, AccumulatorsInRowMajor>
{

private:
    // Shape for computing the FP16s
    using ComputeInstructionShape = InstructionShape_;

    // Chosen so we get K=16 for int8 and K=32 for int4.
    static constexpr int LoadInstructionK = 128 / sizeof_bits<ElementB>::value;

    // Shape for loading the narrow data type from shared memory
    using LoadInstructionShape = GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;

public:
    using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
        cutlass::arch::Mma<InstructionShape_, 32, ElementA, cutlass::layout::RowMajor, ElementA,
            cutlass::layout::ColumnMajor, ElementC, cutlass::layout::RowMajor, arch::OpMultiplyAdd>,
        cutlass::MatrixShape<1, 1>>;

    // Define the warp-level tensor op
    using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_, ElementA, LayoutA, ElementB, LayoutB,
        ElementC, LayoutC, Policy, LoadInstructionShape, PartitionsK, AccumulatorsInRowMajor>;
};
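
// Illustrative resolution of this trait (derived from the definitions above,
// not part of the original source): with ElementB = int4, the load shape widens
// to K = 128 / 4 = 32, so an m16n8k16 HMMA compute shape is paired with an
// m16n8k32 shared memory load shape, and `Type` resolves to a
// MmaTensorOpComputeBWithF16 built over that pair.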

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,306 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations targeting
      Tensor Cores.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/platform/platform.h"

#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/arch/mma_sm89.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"

#include "cutlass/gemm/warp/mma_tensor_op_policy.h"

#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace warp
{

/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Core instructions.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Data type of A elements
    typename ElementA_,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA_,
    /// Data type of B elements
    typename ElementB_,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB_,
    /// Element type of C matrix
    typename ElementC_,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC_,
    /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
    typename Policy_,
    /// Instruction shape to override shared memory iterators with
    typename SharedMemoryInstructionShape_,
    /// Number of partitions along K dimension
    int PartitionsK_ = 1,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor = false,
    /// Used for partial specialization
    typename Enable = bool>
class MmaTensorOpComputeBWithF16
{
public:
    /// Shape of warp-level matrix operation (concept: GemmShape)
    using Shape = Shape_;

    /// Data type of multiplicand A
    using ElementA = ElementA_;

    /// Layout of multiplicand A
    using LayoutA = LayoutA_;

    /// Data type of multiplicand B
    using ElementB = ElementB_;

    /// Layout of multiplicand B
    using LayoutB = LayoutB_;

    /// Data type of accumulator matrix C
    using ElementC = ElementC_;

    /// Layout of accumulator matrix C
    using LayoutC = LayoutC_;

    /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
    using Policy = Policy_;

    /// Underlying matrix multiply operator (concept: arch::Mma)
    using ArchMmaOperator = typename Policy::Operator;

    /// Indicates math operator
    using MathOperator = typename ArchMmaOperator::Operator;

    /// Architecture tag from underlying instruction
    using ArchTag = typename ArchMmaOperator::ArchTag;
    static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
                      && platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
            || (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
                && platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
                && ArchTag::kMinComputeCapability >= 80)
            || (platform::is_same<typename ArchMmaOperator::ElementA, float_e4m3_t>::value
                && platform::is_same<typename ArchMmaOperator::ElementB, float_e4m3_t>::value
                && ArchTag::kMinComputeCapability >= 89),
        "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA");

    static_assert(platform::is_same<ElementA, half_t>::value
            || (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80)
            || (platform::is_same<ElementA, float_e4m3_t>::value && ArchTag::kMinComputeCapability >= 89),
        "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada");

    /// Indicates class of matrix operator
    using OperatorClass = arch::OpClassTensorOp;

    /// Shape of underlying instruction
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Instruction shape to override shared memory iterators with
    using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;

    static_assert(
        SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load");
    static_assert(
        SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load");

    static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;
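    // Example (derived from the definitions above): an m16n8k32 shared memory
    // load shape over an m16n8k16 compute shape gives kExpansionFactor == 2,
    // i.e. each B fragment fetched from shared memory feeds two mma instructions.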

    static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = ComplexTransform::kNone;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = ComplexTransform::kNone;

    /// Number of threads participating in warp-level matrix product
    static int const kThreadCount = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

public:
    /// Iterates over the A operand in memory
    using IteratorA
        = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
            MatrixShape<InstructionShape::kM, InstructionShape::kK>, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

    /// Storage for A tile
    using FragmentA = typename IteratorA::Fragment;

    /// Storage for transformed A tile
    using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;

    /// Iterates over the B operand in memory
    using IteratorB = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB,
        LayoutB, MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>, Policy::OpDelta::kRow,
        kThreadCount, kPartitionsK>;

    /// Storage for B tile
    using FragmentB = typename IteratorB::Fragment;

    /// Storage for transformed B tile
    using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;

    /// Iterates over the C operand in memory
    using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
        typename ArchMmaOperator::Shape, typename Policy::OpDelta>;

    /// Storage for C tile
    using FragmentC = typename IteratorC::Fragment;

    /// Number of mma operations performed
    using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
        (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;
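    // Example (derived from the definitions above): a 64x64 warp tile over an
    // m16n8 mma instruction yields MmaIterations of 4x8, i.e. 32 mma issues per
    // warp-level K group.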

public:
    /// Underlying matrix multiply operator (concept: arch::Mma)
    ArchMmaOperator mma;

public:
    //
    // Methods
    //

    /// Ctor
    CUTLASS_DEVICE
    MmaTensorOpComputeBWithF16() {}

    /// Performs a warp-level matrix multiply-accumulate operation
    CUTLASS_DEVICE
    void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C,
        int const warp_tileB_k_offset) const
    {

        using MmaOperandA = typename ArchMmaOperator::FragmentA;
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        using MmaOperandC = typename ArchMmaOperator::FragmentC;

        static_assert(
            TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
            "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of "
            "B");

        D = C;

        MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
        MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
        MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
        // Serpentine visitation order maximizing reuse of Rb
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < MmaIterations::kColumn; ++n)
        {

            CUTLASS_PRAGMA_UNROLL
            for (int m = 0; m < MmaIterations::kRow; ++m)
            {

                int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);

                int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
                if (AccumulatorsInRowMajor)
                { // matrix B is reordered
                    mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB],
                        ptr_D[n + m_serpentine * MmaIterations::kColumn]);
                }
                else
                {
                    mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB],
                        ptr_D[m_serpentine + n * MmaIterations::kRow]);
                }
            }
        }
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
        // Serpentine visitation order maximizing reuse of Ra
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < MmaIterations::kRow; ++m)
        {

            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < MmaIterations::kColumn; ++n)
            {

                int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);

                int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
                if (AccumulatorsInRowMajor)
                { // matrix B is reordered
                    mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB],
                        ptr_D[n_serpentine + m * MmaIterations::kColumn]);
                }
                else
                {
                    mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB],
                        ptr_D[m + n_serpentine * MmaIterations::kRow]);
                }
            }
        }
#else
        assert(0);
#endif
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,648 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines the dequantizers used by warp-level matrix multiply operations targeting Tensor Cores.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"

#include "cutlass/functional.h"
#include "cutlass/platform/platform.h"

#include "cutlass_extensions/weight_only_quant_op.h"
#include <cuda_bf16.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace gemm
{
namespace warp
{

////////////////////////////////////////////////////////////////////////////////

template <
    /// Matrix multiply operator
    typename MmaOperator_,
    /// Size of the matrix to load (concept: MatrixShape)
    typename Shape_,
    /// Operand identity
    Operand Operand,
    /// Data type of Scale elements
    typename Element_,
    /// Layout of operand
    typename Layout_,
    /// Number of threads participating in one matrix operation
    int Threads,
    ///
    WeightOnlyQuantOp QuantOp_,
    ///
    typename Enable = void>
class MmaTensorOpDequantizer;

////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
template <
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_,
    ///
    WeightOnlyQuantOp QuantOp_>
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, bfloat16_t, layout::RowMajor, 32, QuantOp_,
    typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 80
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
{

public:
    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    // This is the ratio of the load instruction vs the compute instruction.
    static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;

    /// Type of the scales
    using ElementScale = bfloat16_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    // Fragment to hold scale data to apply to B before mma
    // We need 1 fp16 per matrix iteration in the N dimension
    static constexpr int kColsPerMmaPerThread = 1;
    using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
    using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;

    /// Warp mma shape
    using Shape = Shape_;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
    {
        int const warp_offset = warp_idx_n * Shape::kN;
        int const quad = lane_idx / 4;
        int const thread_offset = warp_offset + quad;
        pointer_scale_ = smem_scales.data() + thread_offset;
        if constexpr (hasZero(QuantOp))
        {
            pointer_zero_ = smem_zeros.data() + thread_offset;
        }
    }
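    // Note on the indexing above: lanes within the same quad (lane_idx / 4) map
    // to the same column group of the mma tile in the HMMA register layout, so
    // each thread only needs the scale for the column its quad owns.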

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
        : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
    {
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
        {
            scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
        using _MmaOperandB = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                == FragmentDequantizedOperand::kElements,
            "");

        __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
        {
            static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");

            __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
            __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);

            CUTLASS_PRAGMA_UNROLL
            for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
            {
                operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
            }
        }
#else
        // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
        // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
        // numerous conversion instructions in GEMM main loop.
        arch::device_breakpoint();
#endif
    }
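    // The math above is plain per-column scaling: each packed bf16x2 pair of
    // weights w is rescaled as w * s via __hmul2, with the scalar s broadcast to
    // both lanes by __bfloat162bfloat162. The zero-point overload below computes
    // w * s + z with __hfma2 instead.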

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
    {
        if constexpr (hasZero(QuantOp))
        {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
            {
                scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
                zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
            }
        }
        else
        {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
            {
                scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
            }
        }
    }

    CUTLASS_DEVICE
    void dequantize(
        FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
        using _MmaOperandB = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                == FragmentDequantizedOperand::kElements,
            "");

        __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
        __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag);

        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
        {
            static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");

            __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
            __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]);
            __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);

            if constexpr (hasZero(QuantOp))
            {
                CUTLASS_PRAGMA_UNROLL
                for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
                {
                    operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2);
                }
            }
            else
            {
                CUTLASS_PRAGMA_UNROLL
                for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii)
                {
                    operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
                }
            }
        }
#else
        // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
        // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
        // numerous conversion instructions in GEMM main loop.
        arch::device_breakpoint();
#endif
    }

    // Adds a pointer offset in units of elements.
    CUTLASS_DEVICE
    void add_pointer_offset(int64_t const& offset)
    {
        static_assert(sizeof(ElementScale) > 1, "");
        pointer_scale_ += offset;
        pointer_zero_ += offset;
    }
|
||||
|
||||
private:
|
||||
ElementScale const* pointer_scale_;
|
||||
ElementScale const* pointer_zero_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Specialization for Turing & Ampere
|
||||
template <
|
||||
/// Underlying matrix multiply operator (concept: MmaTensorOp)
|
||||
typename MmaOperator_,
|
||||
/// Shape of the warp level matrix multiply (concept: GemmShape)
|
||||
typename Shape_,
|
||||
///
|
||||
WeightOnlyQuantOp QuantOp_>
|
||||
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_,
|
||||
typename platform::enable_if<MmaOperator_::ArchTag::kMinComputeCapability >= 75
|
||||
&& platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
|
||||
{
|
||||
|
||||
public:
|
||||
/// Mma Operator
|
||||
using MmaOperator = MmaOperator_;
|
||||
|
||||
// The architecture specific mma ooperator being used
|
||||
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
|
||||
|
||||
// Mma Instruction Shape
|
||||
using InstructionShape = typename ArchMmaOperator::Shape;
|
||||
|
||||
// This is the ratio of the load instruction vs the compute instruction.
|
||||
static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
|
||||
|
||||
/// Type of the scales
|
||||
using ElementScale = half_t;
|
||||
|
||||
/// Fragment to hold B data before Mma
|
||||
using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;
|
||||
|
||||
// Fragment to hold scale data to apply to B before mma
|
||||
// We need 1 fp16 per matrix iteration in the N dimension
|
||||
static constexpr int kColsPerMmaPerThread = 1;
|
||||
using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
using FragmentZero = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
|
||||
|
||||
/// Warp mma shape
|
||||
using Shape = Shape_;
|
||||
|
||||
/// Layout of the scales in shared memory
|
||||
using Layout = layout::RowMajor;
|
||||
|
||||
/// TensorRef type for loading element from a tensor
|
||||
using TensorRef = TensorRef<ElementScale, Layout>;
|
||||
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
|
||||
{
|
||||
int const warp_offset = warp_idx_n * Shape::kN;
|
||||
int const quad = lane_idx / 4;
|
||||
int const thread_offset = warp_offset + quad;
|
||||
pointer_scale_ = smem_scales.data() + thread_offset;
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
pointer_zero_ = smem_zeros.data() + thread_offset;
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
|
||||
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
|
||||
{
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag)
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB
|
||||
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
multiplies<ExpandedMmaOperandB> mul_op;
|
||||
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void load(FragmentScale& scale_frag, FragmentScale& zero_frag)
|
||||
{
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(
|
||||
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB
|
||||
= Array<typename FragmentDequantizedOperand::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
multiplies<ExpandedMmaOperandB> mul_op;
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
plus<ExpandedMmaOperandB> plus_op;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
operand_frag_ptr[mma_n_iter]
|
||||
= plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
{
|
||||
operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds a pointer offset in units of elements.
|
||||
CUTLASS_DEVICE
|
||||
void add_pointer_offset(int64_t const& offset)
|
||||
{
|
||||
static_assert(sizeof(ElementScale) > 1, "");
|
||||
pointer_scale_ += offset;
|
||||
pointer_zero_ += offset;
|
||||
}
|
||||
|
||||
private:
|
||||
ElementScale const* pointer_scale_;
|
||||
ElementScale const* pointer_zero_;
|
||||
};
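
// Hedged usage sketch (illustrative, not part of the original file): a warp
// typically constructs the dequantizer over the shared-memory scales, loads a
// scale fragment once per tile, and rescales each B fragment before the mma.
// The names MmaOp, WarpShape, smem_scales_ref, warp_idx_n, lane_idx and frag_B
// are hypothetical placeholders.
//
//   using Dequantizer = MmaTensorOpDequantizer<MmaOp, WarpShape, Operand::kB,
//       half_t, layout::RowMajor, 32, QuantOp>;
//   Dequantizer deq(smem_scales_ref, warp_idx_n, lane_idx);
//   typename Dequantizer::FragmentScale scales;
//   deq.load(scales);
//   deq.dequantize(frag_B, scales); // frag_B now holds scale * quantized B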

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved gemm
template <
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_,
    ///
    WeightOnlyQuantOp QuantOp_>
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_,
    typename platform::enable_if<platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::RowMajor>::value>::type>
{

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture-specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B.
    static constexpr int ColsPerMmaTile = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale = Array<ElementScale, TileNIterations * 8>;
    using AccessType = Array<ElementScale, 8>;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
    static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
    {
        int const warp_offset = warp_idx_n * Shape::kN;
        int const base_col = lane_idx & 0xF8;
        int const thread_offset = warp_offset + base_col;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        AccessType* scale_frag_ptr = reinterpret_cast<AccessType*>(&scale_frag);

        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter)
        {
            // We jump by 32 here since Volta does <32x32x4> super mmas inside a warp.
            scale_frag_ptr[tile_iter] = *reinterpret_cast<AccessType const*>(pointer_ + ColsPerMmaTile * tile_iter);
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
    {
        static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");

        multiplies<FragmentDequantizedOperand> mul_op;
        operand_frag = mul_op(operand_frag, scale_frag);
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved gemm
template <
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_,
    ///
    WeightOnlyQuantOp QuantOp_>
class MmaTensorOpDequantizer<MmaOperator_, Shape_, Operand::kB, half_t, layout::RowMajor, 32, QuantOp_,
    typename platform::enable_if<platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type>
{

public:
    static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape, GemmShape<32, 32, 4>>::value, "");

    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture-specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Type of the scales
    using ElementScale = half_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Warp mma shape
    using Shape = Shape_;

    // Fragment to hold scale data to apply to B before mma
    // Each 32x32x4 matmul uses 8 elements from B.
    static constexpr int ColsPerMmaTile = 32;
    static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
    using FragmentScale = Array<ElementScale, TileNIterations * 2>;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
    static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");

    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
    {
        int const warp_offset = warp_idx_n * Shape::kN;
        // Parenthesized so this reads (lane_idx & 0xF8) + lane_idx % 4; '&' binds looser than '+'.
        int const base_col = (lane_idx & 0xF8) + lane_idx % 4;
        int const thread_offset = warp_offset + base_col;
        pointer_ = smem_scales.data() + thread_offset;
    }

    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter)
        {
            // We jump by 32 here since Volta does <32x32x4> super mmas inside a warp.
            // For col major B, each thread will jump 4 cols to get its next value inside
            // of the super mma.
            CUTLASS_PRAGMA_UNROLL
            for (int mma_iter = 0; mma_iter < 2; ++mma_iter)
            {
                scale_frag[tile_iter * 2 + mma_iter] = pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter];
            }
        }
    }

    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
    {
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        static constexpr int total_n_mmas = 2 * TileNIterations;
        static_assert(MmaOperandB::kElements * total_n_mmas == FragmentDequantizedOperand::kElements, "");

        multiplies<MmaOperandB> mul_op;

        MmaOperandB* operand_frag_ptr = reinterpret_cast<MmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter)
        {
            operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
        }
    }

private:
    ElementScale const* pointer_;
};

////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
254
custom_ops/gpu_ops/cutlass_extensions/gemm_configs.h
Normal file
@@ -0,0 +1,254 @@
/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cassert>
#include <iostream>
#include <sstream>
#include <string>

namespace cutlass_extensions
{
// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape
// in the kernel layout details when doing weight-only quantization (see the naming example after this enum).
enum class CutlassTileConfig
{
    // Signals that we should run heuristics to choose a config
    Undefined,

    // Signals that we should run heuristics to choose a config
    ChooseWithHeuristic,

    // SIMT config
    CtaShape128x128x8_WarpShape64x64x8,

    // TensorCore configs CTA_N = 128, CTA_K = 64
    // Warp configs for M=16
    CtaShape16x128x64_WarpShape16x32x64,
    // Warp configs for M=32
    CtaShape32x128x64_WarpShape32x32x64,

    // Warp configs for M=64
    CtaShape64x128x64_WarpShape32x64x64,
    CtaShape64x64x128_WarpShape32x64x64,
    CtaShape64x128x64_WarpShape64x32x64,

    // Warp configs for M=128
    CtaShape128x64x64_WarpShape64x32x64,
    CtaShape128x128x64_WarpShape64x32x64,
    CtaShape128x128x64_WarpShape64x64x64,
    CtaShape128x128x64_WarpShape128x32x64,
    CtaShape128x256x64_WarpShape64x64x64,

    // Warp configs for M=256
    CtaShape256x128x64_WarpShape64x64x64,

    // TensorCore config CTA_N = 64, CTA_K = 128
    CtaShape128x64x128_WarpShape64x32x128,

    // TensorCore config CTA_N = 256, CTA_K = 64
    CtaShape16x256x64_WarpShape16x64x64,

    // TensorCore config CTA_N = 256, CTA_K = 128
    CtaShape16x256x128_WarpShape16x64x128

};
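
// Naming example for the enum above: CtaShape64x128x64_WarpShape32x64x64 is a
// 64x128x64 (MxNxK) threadblock tile computed by warps that each own a
// 32x64x64 tile. For weight-only quantization, the trailing K (64 here) must
// match the K shape baked into the preprocessed weight layout, per the note
// above.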

enum class SplitKStyle
{
    NO_SPLIT_K,
    SPLIT_K_SERIAL,
    STREAM_K, // Sm80+
    // SPLIT_K_PARALLEL // Not supported yet
};

enum class CutlassTileConfigSM90
{
    // Signals that we should run heuristics to choose a config
    Undefined,

    // Signals that we should run heuristics to choose a config
    ChooseWithHeuristic,

    // CTA configs for M=64
    CtaShape64x16x128B,
    CtaShape64x32x128B,
    CtaShape64x64x128B,
    CtaShape64x128x128B,
    CtaShape64x256x128B,

    // CTA configs for M=128
    CtaShape128x16x128B,
    CtaShape128x32x128B,
    CtaShape128x64x128B,
    CtaShape128x128x128B,
    CtaShape128x256x128B,

    // CTA configs for M=256
    CtaShape256x128x128B,
};

enum class MainloopScheduleType
{
    AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this
         // defaults to the "legacy" main loop schedule.
};

enum class EpilogueScheduleType
{
    AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
         // architectures older than Hopper, the epilogue is always performed by the same thread block as the main loop.
};

enum class ClusterShape
{
    ClusterShape_1x1x1,
    ClusterShape_2x1x1,
    ClusterShape_1x2x1,
    ClusterShape_2x2x1,
    ClusterShape_1x8x1,
    ClusterShape_8x1x1
};

struct CutlassGemmConfig
{
    enum CandidateConfigTypeParam : int
    {
        NONE = 0,
        WEIGHT_ONLY = 1u << 0,
        SIMT_ONLY = 1u << 1,
        INT8_ONLY = 1u << 2,
        HOPPER = 1u << 3,
        GROUPED_GEMM = 1u << 4,
        FP8_ONLY = 1u << 5,
    };
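
    // The enumerators form a bitmask, so candidate filters can be combined,
    // e.g. (WEIGHT_ONLY | GROUPED_GEMM) requests weight-only grouped-GEMM candidates.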

    CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
    SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
    int split_k_factor = -1;
    int stages = -1;

    // config options for sm90
    CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic;
    MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
    EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
    ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
    bool is_sm90 = false;

    CutlassGemmConfig() {}

    CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
        : tile_config(tile_config)
        , split_k_style(split_k_style)
        , split_k_factor(split_k_factor)
        , stages(stages)
        , is_sm90(false)
    {
    }

    CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
        EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
        : tile_config_sm90(tile_config_sm90)
        , mainloop_schedule(mainloop_schedule)
        , epilogue_schedule(epilogue_schedule)
        , cluster_shape(cluster_shape)
        , is_sm90(true)
    {
    }

    std::string toString() const
    {
        std::stringstream tactic;
        tactic << "Cutlass GEMM Tactic";
        if (tile_config_sm90 != cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic)
        {
            assert(is_sm90 && "Invalid cutlass GEMM config");
            tactic << "\n\tstyle=TMA"
                   << "\n\ttile shape ID: " << (int) tile_config_sm90
                   << "\n\tcluster shape ID: " << (int) cluster_shape
                   << "\n\tmainloop sched: " << (int) mainloop_schedule
                   << "\n\tepi sched: " << (int) epilogue_schedule;
        }
        else if (tile_config != cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
        {
            assert(!is_sm90 && "Invalid cutlass GEMM config");
            tactic << "\n\tstyle=compatible"
                   << "\n\ttile shape ID: " << (int) tile_config
                   << "\n\tstages: " << (int) stages
                   << "\n\tsplit_k_style: " << (int) split_k_style
                   << "\n\tsplit k: " << (int) split_k_factor;
        }
        else
        {
            tactic << "\n\tundefined";
        }
        tactic << "\n";
        return tactic.str();
    }

    void fromString(const std::string& str)
    {
        std::istringstream stream(str);
        std::string line;

        while (std::getline(stream, line))
        {
            if (line.find("style=TMA") != std::string::npos)
            {
                is_sm90 = true;
                std::getline(stream, line);
                tile_config_sm90
                    = static_cast<cutlass_extensions::CutlassTileConfigSM90>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
                cluster_shape = static_cast<cutlass_extensions::ClusterShape>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
                mainloop_schedule
                    = static_cast<cutlass_extensions::MainloopScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
                epilogue_schedule
                    = static_cast<cutlass_extensions::EpilogueScheduleType>(std::stoi(line.substr(line.find(':') + 1)));
            }
            else if (line.find("style=compatible") != std::string::npos)
            {
                is_sm90 = false;
                std::getline(stream, line);
                tile_config = static_cast<cutlass_extensions::CutlassTileConfig>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
                stages = std::stoi(line.substr(line.find(':') + 1));
                std::getline(stream, line);
                split_k_style = static_cast<cutlass_extensions::SplitKStyle>(std::stoi(line.substr(line.find(':') + 1)));
                std::getline(stream, line);
                split_k_factor = std::stoi(line.substr(line.find(':') + 1));
            }
        }
    }
};
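
// Hedged usage sketch (illustrative, not part of the original API): fromString()
// parses the exact line layout that toString() emits, so a config can be
// persisted as text and restored, e.g. by an autotuner cache. The function name
// is hypothetical.
inline bool cutlass_gemm_config_round_trip_example()
{
    CutlassGemmConfig config(CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
        SplitKStyle::SPLIT_K_SERIAL, /*split_k_factor=*/2, /*stages=*/3);

    CutlassGemmConfig restored;
    restored.fromString(config.toString());

    // The non-SM90 branch serializes tile shape, stages, split-k style and factor.
    return restored.tile_config == config.tile_config && restored.stages == config.stages
        && restored.split_k_style == config.split_k_style && restored.split_k_factor == config.split_k_factor;
}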

inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config)
{
    // clang-format off
    if (config.is_sm90)
    {
        out << "tile_config_sm90_enum: " << int(config.tile_config_sm90)
            << ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
            << ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
            << ", cluster_shape_enum: " << int(config.cluster_shape);
    }
    else
    {
        out << "tile_config_enum: " << int(config.tile_config)
            << ", split_k_style_enum: " << int(config.split_k_style)
            << ", split_k_factor: " << config.split_k_factor
            << ", stages: " << config.stages;
    }
    // clang-format on
    return out;
}

} // namespace cutlass_extensions
@@ -0,0 +1,447 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*!
    \file
    \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
*/

#pragma once

#include "cutlass/arch/arch.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"

namespace cutlass
{

// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter
{
};

template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
{
    using result_type = Array<half_t, 4>;
    using source_type = Array<uint8_t, 4>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        static constexpr uint32_t mask_for_elt_01 = 0x5250;
        static constexpr uint32_t mask_for_elt_23 = 0x5351;
        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
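
        // Each fp16 half built by the prmt instructions above has exponent byte
        // 0x64 and a payload byte b, i.e. the value 1024 + b (the fp16 ulp is 1
        // on [1024, 2048)). 0x6480 is fp16 1152 = 1024 + 128, so the subtraction
        // below yields b - 128, undoing the +128 bias in a single instruction.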

        // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
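
// Hedged usage sketch (illustrative, not part of the original file): inside a
// kernel, four bias-128 bytes packed in a register dequantize to fp16 like so.
// The function name is hypothetical.
CUTLASS_DEVICE Array<half_t, 4> convert_biased_bytes_example(Array<uint8_t, 4> const& packed)
{
    FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> converter;
    // Element i of the result is fp16(int(packed[i]) - 128), with the even/odd
    // interleaving described above undone.
    return converter(packed);
}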

template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N>
{
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be a multiple of 4.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint8_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i)
        {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4>
{
    using result_type = Array<bfloat16_t, 4>;
    using source_type = Array<uint8_t, 4>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        static constexpr uint32_t fp32_base = 0x4B000000;
        float fp32_intermediates[4];

        // Construct FP32s; bfloat does not have enough mantissa for the IADD trick
        uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
        fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
        fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
        fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
        fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);

        // Subtract out fp32_base + 128 to make the unsigned integer signed.
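        // Each __byte_perm above produces the fp32 bit pattern 0x4B0000xx, i.e.
        // 2^23 + b for a payload byte b, so subtracting 8388736 = 2^23 + 128
        // leaves the signed value b - 128 as a float.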
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 4; ++ii)
        {
            fp32_intermediates[ii] -= 8388736.f;
        }

        // Truncate the fp32 representation and pack up as bfloat16s.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 2; ++ii)
        {
            bf16_result_ptr[ii]
                = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should convert directly to FP16 using the FP16 converters.
        result.clear(); // Suppress compiler warning
        arch::device_breakpoint();
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N>
{
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be a multiple of 4.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint8_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i)
        {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8>
{
    using result_type = Array<half_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
        static constexpr uint32_t TOP_MASK = 0x00f000f0;
        static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

        // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
        // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
        // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
        // elt_67 to fp16 without having to shift them to the bottom bits beforehand.

        // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
        // immediately before required.
        const uint32_t top_i4s = i4s >> 8;
        // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[1])
                     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[2])
                     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[3])
                     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

        // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
        // half2 ctor. In this case, I chose performance reliability over code readability.

        // This is the half2 {1032, 1032} represented as an integer.
        static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
        // This is the half2 {1 / 16, 1 / 16} represented as an integer.
        static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
        // This is the half2 {-72, -72} represented as an integer.
        static constexpr uint32_t NEG_72 = 0xd480d480;
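
        // Why these constants: a bottom nibble x (already biased by +8) sits in
        // its half as fp16 1024 + x, so subtracting 1032 = 1024 + 8 recovers the
        // signed value. A top nibble sits in its half as 1024 + 16x; multiplying
        // by 1/16 gives 64 + x, and subtracting 72 = 64 + 8 again yields x - 8.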

        // Finally, we construct the output numbers.
        // Convert elt_01
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_23
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
        // Convert elt_45
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_67
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N>
{
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be a multiple of 8.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint4b_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i)
        {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8>
{
    using result_type = Array<bfloat16_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate bf16 number.
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t MASK = 0x000f000f;
        static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

        // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
        // No shift needed for the first item.
        uint32_t i4s = source_i4s;
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 1; ii < result_type::kElements / 2; ++ii)
        {
            i4s >>= sizeof_bits<typename source_type::Element>::value;
            // (i4s & 0x000f000f) | 0x43004300
            asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                         : "=r"(h[ii])
                         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        }

        // This is the BF16 {-136, -136} represented as an integer.
        static constexpr uint32_t BF16_BIAS = 0xC308C308;
        static constexpr uint32_t BF16_ONE = 0x3F803F80;
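
        // Why -136: each extracted nibble x (biased by +8) sits in its half as
        // the bf16 value 128 + x (exponent pattern 0x4300 = 128, ulp 1 on
        // [128, 256)), so fma(h, 1, -136) = x - 8 recovers the signed value.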

        // Finally, we construct the output numbers.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < result_type::kElements / 2; ++ii)
        {
            // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
            asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should convert directly to FP16 using the FP16 converters.
        arch::device_breakpoint();
        result.clear(); // Suppress compiler warning.
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
{
    static constexpr int VEC_WIDTH = 8;
    static_assert(!(N % VEC_WIDTH), "N must be a multiple of 8.");

    using result_type = Array<bfloat16_t, N>;
    using source_type = Array<uint4b_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        using scalar_result_type = typename result_type::Element;
        using scalar_source_type = typename source_type::Element;
        FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
            convert_vector_;

        result_type result;
        using vec_result = Array<scalar_result_type, VEC_WIDTH>;
        using vec_source = Array<scalar_source_type, VEC_WIDTH>;

        vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
        vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < N / VEC_WIDTH; ++i)
        {
            result_ptr[i] = convert_vector_(source_ptr[i]);
        }

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,66 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines new layouts needed for MoE
*/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/pitch_linear_coord.h"

namespace cutlass
{
namespace layout
{

template <int RowsPerTile, int ColumnsInterleaved>
struct ColumnMajorTileInterleave
{
    static constexpr int kRowsPerTile = RowsPerTile;
    static constexpr int kColumnsInterleaved = ColumnsInterleaved;
};

template <class T>
struct IsColumnMajorTileInterleave
{
    static constexpr bool value = false;
};

template <int U, int V>
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>>
{
    static constexpr bool value = true;
};
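
// Usage sketch: kernel code branches on this trait to detect interleaved
// weight layouts, e.g. IsColumnMajorTileInterleave<ColumnMajorTileInterleave<64, 2>>::value
// is true, while IsColumnMajorTileInterleave<layout::ColumnMajor>::value is false.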

} // namespace layout
} // namespace cutlass
@@ -0,0 +1,250 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM
    quantization.
*/

#pragma once

#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"

////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////

namespace cutlass
{
namespace transform
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

template <typename Shape, typename Element, typename Layout, int AdvanceRank, int Alignment>
class FineGrainedScaleZeroIterator;

template <typename Shape_, typename Element_, int Alignment_>
class FineGrainedScaleZeroIterator<Shape_, Element_, layout::RowMajor, 0, Alignment_>
{
public:
    using Shape = Shape_;
    using Element = Element_;
    using Layout = layout::RowMajor;
    static int const kAdvanceRank = 0;
    static int const kAlignment = Alignment_;

    static int const kAccessesPerVector = 1;

    /// Row index of scales corresponding to the groupsize of 64
    int row_groupsize64_;
    int group_size_;

    using Index = typename Layout::Index;
    using LongIndex = typename Layout::LongIndex;

    using TensorRef = TensorRef<Element, Layout>;
    using TensorView = TensorView<Element, Layout>;
    using TensorCoord = typename Layout::TensorCoord;
    using Pointer = Element*;
    using NonConstPointer = typename platform::remove_const<Element>::type*;

    using AccessType = AlignedArray<Element, kAlignment>;

    using Fragment = cutlass::Array<Element, kAlignment>;

    // For compatibility with the existing iterator interface
    struct Params
    {
        LongIndex stride_ = 0;

        /// Amount (in bytes) to increment the pointer from the first access of the current
        /// tile to the first access of the next tile
        LongIndex inc_advance_ = 0;

        // Default ctor
        CUTLASS_HOST_DEVICE
        Params() {}

        /// Construct the Params object given a pitch-linear tensor's layout
        CUTLASS_HOST_DEVICE
        Params(Layout const& layout)
            : stride_(layout.stride(0))
        {
            inc_advance_ = Shape::kRow * stride_ * sizeof_bits<Element>::value / 8;
        }
    };

private:
    /// Internal pointer type permits fast address arithmetic
    using BytePointer = char*;

private:
    //
    // Data members
    //

    /// Parameters object with precomputed internal state
    Params const params_;

    /// Internal pointer to first access of tile
    BytePointer pointer_scale_;
    BytePointer pointer_zero_;

    bool is_valid_ = false;

public:
    /// Constructs a TileIterator from its precomputed state, threadblock offset,
    /// and thread ID
    CUTLASS_DEVICE
    FineGrainedScaleZeroIterator(
        ///< Precomputed parameters object
        Params const& params,
        ///< Pointer to start of scale tensor
        Pointer pointer_scale,
        ///< Pointer to start of zero tensor
        Pointer pointer_zero,
        ///< Extent of the scale and bias
        TensorCoord extent,
        ///< ID of each participating thread
        int thread_id,
        ///< Initial offset of threadblock
        TensorCoord const& threadblock_offset,
        ///< Group size
        int group_size)
        : params_(params)
        , pointer_scale_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_scale)))
        , pointer_zero_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer_zero)))
    {
        row_groupsize64_ = threadblock_offset.row();
        group_size_ = group_size;
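
        // threadblock_offset.row() is expressed in groupsize-64 row units (see
        // row_groupsize64_ above); dividing by (group_size / 64) below converts
        // it to a row in the scale tensor, where one scale row covers
        // group_size weight rows.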
        const LongIndex tb_row_byte_offset
            = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits<Element>::value / 8;
        const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits<Element>::value / 8;
        pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset);

        if (pointer_zero_ != nullptr)
        {
            pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset);
        }

        static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;

        int const thread_row = thread_id / THREADS_PER_ROW;
        int const thread_col = thread_id % THREADS_PER_ROW;

        const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8;
        const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8;
        pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset);
        if (pointer_zero_ != nullptr)
        {
            pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset);
        }

        // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on
        // a given iteration. The same threads will be responsible for issuing reads since the number of scales
        // read in a given iteration is a constant. Therefore, we should never have to update is_valid_
        // outside of the constructor.
        int const global_row = threadblock_offset.row() + thread_row;
        int const global_col = threadblock_offset.column() + thread_col * kAlignment;

        bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
        bool const col_in_bounds = global_col < extent.column();

        is_valid_ = row_in_bounds && col_in_bounds;
    }

    /// Construct a PredicatedTileAccessIterator with zero threadblock offset
    CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object
        Pointer pointer_scale,                                             ///< Pointer to start of scale tensor
        Pointer pointer_zero,                                              ///< Pointer to start of zero tensor
        TensorCoord extent,                                                ///< Extent of tensor
        int thread_id,                                                     ///< ID of each participating thread
        int group_size)
        : FineGrainedScaleZeroIterator(
            params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size)
    {
    }

    CUTLASS_DEVICE
    void add_tile_offset(TensorCoord const& tile_offset)
    {
        const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_;
        const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits<Element>::value / 8;
        pointer_scale_ += row_byte_offset + col_byte_offset;
        if (pointer_zero_ != nullptr)
        {
            pointer_zero_ += row_byte_offset + col_byte_offset;
        }
    }

    /// Clears the predicate set efficiently
    CUTLASS_HOST_DEVICE void clear_mask(bool enable = true)
    {
        is_valid_ &= (!enable);
    }

    /// Returns whether access is valid or not
    CUTLASS_HOST_DEVICE
    bool valid() const
    {
        return is_valid_;
    }

    /// Returns a scale pointer
    CUTLASS_HOST_DEVICE
    AccessType* get_scale() const
    {
        return reinterpret_cast<AccessType*>(pointer_scale_);
    }

    /// Returns a zero pointer
    CUTLASS_HOST_DEVICE
    AccessType* get_zero() const
    {
        return reinterpret_cast<AccessType*>(pointer_zero_);
    }
};

} // namespace threadblock
} // namespace transform
} // namespace cutlass
181
custom_ops/gpu_ops/cutlass_extensions/util/gather_tensor.hpp
Normal file
@@ -0,0 +1,181 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"

using namespace cute;

/// Function object that applies an index to its argument
template <class Iter>
struct IndexedGather
{
    CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {})
        : indices_(indices)
    {
    }

    template <typename I>
    CUTE_HOST_DEVICE constexpr auto operator()(I i) const
    {
        return indices_[i];
    }

    CUTE_HOST_DEVICE friend void print(IndexedGather const& s)
    {
        cute::print("Indexed{");
        print(s.indices_);
        print("}");
    }

    Iter indices_;
};

/// Custom stride object that applies a function followed by a stride
template <class Func, class Stride>
struct CustomStride
{
    CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride)
        : func_(func)
        , stride_(stride)
    {
    }

    template <class I>
    CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s)
    {
        return s.func_(i) * s.stride_;
    }

    template <class I>
    CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i)
    {
        return s.func_(i) * s.stride_;
    }

    CUTE_HOST_DEVICE friend void print(CustomStride const& s)
    {
        cute::print("Custom{");
        print(s.func_);
        cute::print(",");
        print(s.stride_);
        cute::print("}");
    }

    template <class Div>
    CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
    {
        return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
    }

    // Circumvent the requirement on make_layout that shape and stride are integral
    template <class Shape>
    CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride)
    {
        return Layout<Shape, CustomStride>(shape, stride);
    }

    Func func_;
    Stride stride_;
};
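
// Example (illustrative; the index array `idx` and leading dimension `ldm` below are
// hypothetical): pairing CustomStride with IndexedGather turns a logical coordinate
// into a gathered offset.
//
//     auto s = CustomStride{IndexedGather{idx}, ldm};
//     // i * s evaluates to idx[i] * ldm
//
// A layout built with this stride therefore addresses row idx[i] where a plain layout
// would address row i.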

template <class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func)
{
    // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride
    auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; });
    constexpr int I = decltype(idx)::value;
    return make_layout(
        repeat_like(stride, _1{}), replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
}

/// Helper function to optionally create a gather tensor
template <class Iterator, class Shape, class Stride, class Func>
CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func)
{
    Layout matrix_layout = make_identity_layout(shape);
    auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
    Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
    return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
}

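// Example (illustrative; `ptr`, `indices`, `M`, `K`, and `ldm` are hypothetical):
//
//     auto gA = make_gather_tensor(ptr,                     // data pointer
//                                  make_shape(M, K),        // logical shape
//                                  make_stride(ldm, _1{}),  // row-major strides
//                                  IndexedGather{indices}); // row remapping
//
// Reading gA(i, j) then loads ptr[indices[i] * ldm + j], i.e. logical row i is
// transparently redirected through the index buffer.
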
namespace cute
{

template <int N, int I, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride)
{
    if constexpr (is_tuple<Shape>::value)
    {
        return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N, I>(s, d); });
    }
    else if constexpr (is_scaled_basis<Stride>::value)
    {
        if constexpr (Stride::mode() == I)
        {
            return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
        }
        else
        {
            return make_layout(shape, stride);
        }
    }
    else
    {
        return upcast<N>(shape, stride);
    }

    CUTE_GCC_UNREACHABLE;
}

template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr auto upcast(
    ComposedLayout<Layout<OuterShape, OuterStride>, Offset, Layout<Shape, Stride>> const& layout)
{
    // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
    auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; });
    constexpr int I = decltype(idx)::value;

    // Upcast the outer layout (works as expected)
    auto outer = upcast<N>(layout.layout_a());

    // Upcast the accumulated offset along stride-1 mode
    auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));

    // Upcast the inner layout's shape along stride-1 mode
    auto inner = upcast<N, I>(layout.layout_b().shape(), layout.layout_b().stride());

    return composition(outer, offset, inner);
}

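// Intuition (illustrative, not from the original file): upcast<N> reinterprets a layout on
// narrow elements as a layout on N-times-wider chunks, e.g. viewing sub-byte weights through
// a byte pointer. Only the stride-1 mode actually packs elements together, which is why the
// ComposedLayout overload above rescales just that mode's inner shape and accumulated offset.
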
} // namespace cute
58
custom_ops/gpu_ops/cutlass_extensions/weight_only_quant_op.h
Normal file
@@ -0,0 +1,58 @@
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines the weight-only quantization modes (per-column vs. fine-grained, with or
           without zero points) shared by the mixed-input GEMM kernels.
*/

#pragma once

namespace cutlass
{

enum class WeightOnlyQuantOp
{
    UNDEFINED,
    PER_COLUMN_SCALE_ONLY,
    FINEGRAINED_SCALE_ONLY,
    FINEGRAINED_SCALE_AND_ZEROS
};

constexpr bool isFinegrained(WeightOnlyQuantOp op)
{
    return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
}

constexpr bool hasZero(WeightOnlyQuantOp op)
{
    return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
}

} // namespace cutlass
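
// Usage sketch (illustrative; `QuantOp` as a kernel template parameter and the
// `DequantTraits` helper are assumptions, not shown in this header):
//
//     template <cutlass::WeightOnlyQuantOp QuantOp>
//     struct DequantTraits
//     {
//         // Fine-grained modes stage one scale per quantization group;
//         // FINEGRAINED_SCALE_AND_ZEROS additionally stages a zero point.
//         static constexpr bool kFineGrained = cutlass::isFinegrained(QuantOp);
//         static constexpr bool kHasZero = cutlass::hasZero(QuantOp);
//     };
//
//     static_assert(cutlass::hasZero(cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS), "");
//     static_assert(!cutlass::hasZero(cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY), "");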