/*************************************************************************************************** * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register */ #pragma once #include "cutlass/arch/arch.h" #include "cutlass/array.h" #include "cutlass/half.h" #include "cutlass/numeric_types.h" #include "cutlass/trace.h" namespace cutlass { template __device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); return res; } // This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low // bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally // signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. // This converter will uninterleave the data and subtract the bias while converting to the result type. template struct FastInterleavedAndBiasedNumericArrayConverter; template <> struct FastInterleavedAndBiasedNumericArrayConverter { using result_type = Array; using source_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source) { result_type result; uint32_t* h = reinterpret_cast(&result); uint32_t const i8s = reinterpret_cast(source); static constexpr uint32_t mask_for_elt_01 = 0x5250; static constexpr uint32_t mask_for_elt_23 = 0x5351; static constexpr uint32_t start_byte_for_fp16 = 0x64646464; asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); return result; } CUTLASS_DEVICE result_type operator()(source_type const& s) { return convert(s); } }; template struct FastInterleavedAndBiasedNumericArrayConverter { static constexpr int VEC_WIDTH = 4; static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); using result_type = Array; using source_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source) { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverter convert_vector_; result_type result; using vec_result = Array; using vec_source = Array; vec_result* result_ptr = reinterpret_cast(&result); vec_source const* source_ptr = reinterpret_cast(&source); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / VEC_WIDTH; ++i) { result_ptr[i] = convert_vector_(source_ptr[i]); } return result; } CUTLASS_DEVICE result_type operator()(source_type const& s) { return convert(s); } }; template <> struct FastInterleavedAndBiasedNumericArrayConverter { using result_type = Array; using source_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source) { result_type result; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) uint32_t* bf16_result_ptr = reinterpret_cast(&result); uint32_t const i8s = reinterpret_cast(source); static constexpr uint32_t fp32_base = 0x4B000000; float fp32_intermediates[4]; // Construct FP32s, bfloat does not have enough mantissa for IADD trick uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); // Subtract out fp32_base + 128 to make the unsigned integer signed. CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < 4; ++ii) { fp32_intermediates[ii] -= 8388736.f; } // Truncate the fp32 representation and pack up as bfloat16s. CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < 2; ++ii) { bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); } #else // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. result.clear(); // Suppress compiler warning arch::device_breakpoint(); #endif return result; } CUTLASS_DEVICE result_type operator()(source_type const& s) { return convert(s); } }; template struct FastInterleavedAndBiasedNumericArrayConverter { static constexpr int VEC_WIDTH = 4; static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); using result_type = Array; using source_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source) { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverter convert_vector_; result_type result; using vec_result = Array; using vec_source = Array; vec_result* result_ptr = reinterpret_cast(&result); vec_source const* source_ptr = reinterpret_cast(&source); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / VEC_WIDTH; ++i) { result_ptr[i] = convert_vector_(source_ptr[i]); } return result; } CUTLASS_DEVICE result_type operator()(source_type const& s) { return convert(s); } }; template <> struct FastInterleavedAndBiasedNumericArrayConverter { using result_type = Array; using source_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source) { result_type result; uint32_t* h = reinterpret_cast(&result); uint32_t const i4s = reinterpret_cast(source); // First, we extract the i4s and construct an intermediate fp16 number. static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; static constexpr uint32_t BOTTOM_MASK = 0x000f000f; static constexpr uint32_t TOP_MASK = 0x00f000f0; static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and // elt_67 to fp16 without having to shift them to the bottom bits before hand. // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue // immediately before required. const uint32_t top_i4s = i4s >> 8; // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the // half2 ctor. In this case, I chose performance reliability over code readability. // This is the half2 {1032, 1032} represented as an integer. static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; // This is the half2 {1 / 16, 1 / 16} represented as an integer. static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; // This is the half2 {-72, -72} represented as an integer. static constexpr uint32_t NEG_72 = 0xd480d480; // Finally, we construct the output numbers. // Convert elt_01 asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); // Convert elt_23 asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); // Convert elt_45 asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); // Convert elt_67 asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); return result; } CUTLASS_DEVICE result_type operator()(source_type const& s) { return convert(s); } }; template struct FastInterleavedAndBiasedNumericArrayConverter { static constexpr int VEC_WIDTH = 8; static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); using result_type = Array; using source_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source) { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverter convert_vector_; result_type result; using vec_result = Array; using vec_source = Array; vec_result* result_ptr = reinterpret_cast(&result); vec_source const* source_ptr = reinterpret_cast(&source); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / VEC_WIDTH; ++i) { result_ptr[i] = convert_vector_(source_ptr[i]); } return result; } CUTLASS_DEVICE result_type operator()(source_type const& s) { return convert(s); } }; template <> struct FastInterleavedAndBiasedNumericArrayConverter { using result_type = Array; using source_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source) { result_type result; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) uint32_t* h = reinterpret_cast(&result); uint32_t const source_i4s = reinterpret_cast(source); // First, we extract the i4s and construct an intermediate fp16 number. static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. // No shift needed for first item. uint32_t i4s = source_i4s; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); CUTLASS_PRAGMA_UNROLL for (int ii = 1; ii < result_type::kElements / 2; ++ii) { i4s >>= sizeof_bits::value; // (i4s & 0x000f000f) | 0x43004300 asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[ii]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); } // This is the BF16 {-136, -136} represented as an integer. static constexpr uint32_t BF16_BIAS = 0xC308C308; static constexpr uint32_t BF16_ONE = 0x3F803F80; // Finally, we construct the output numbers. CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < result_type::kElements / 2; ++ii) { // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); } #else // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. arch::device_breakpoint(); result.clear(); // Suppress compiler warning. #endif return result; } CUTLASS_DEVICE result_type operator()(source_type const& s) { return convert(s); } }; template struct FastInterleavedAndBiasedNumericArrayConverter { static constexpr int VEC_WIDTH = 8; static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); using result_type = Array; using source_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source) { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverter convert_vector_; result_type result; using vec_result = Array; using vec_source = Array; vec_result* result_ptr = reinterpret_cast(&result); vec_source const* source_ptr = reinterpret_cast(&source); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / VEC_WIDTH; ++i) { result_ptr[i] = convert_vector_(source_ptr[i]); } return result; } CUTLASS_DEVICE result_type operator()(source_type const& s) { return convert(s); } }; template <> struct FastInterleavedAndBiasedNumericArrayConverter { using result_type = Array; using source_type = Array; using ScaleComputeT = float; using code_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp) { uint32_t const i8s = reinterpret_cast(source); // 2^23 = 8388608 static constexpr uint32_t FP32_BASE = 0x4B000000; float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); int32_t decode_value[4]; ScaleComputeT new_code_zp = code_zp + 0.5f; decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp)); return convert_impl(decode_value); } CUTLASS_DEVICE static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp) { uint32_t const i8s = reinterpret_cast(source); // 2^23 = 8388608 static constexpr uint32_t FP32_BASE = 0x4B000000; float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); int32_t decode_value[4]; decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); return convert_impl(decode_value); } CUTLASS_DEVICE static result_type convert_impl(int32_t* decode_value) { result_type result; static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; static constexpr uint32_t MASK = 0x003F003F; // 2^10 = 1024 static constexpr uint32_t EX = 0x64006400; uint32_t* h = reinterpret_cast(&result); int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); h[0] = lop3(q0 >> 9, MASK, EX); h[1] = lop3(q0 >> 6, MASK, EX); h[2] = lop3(q0 >> 3, MASK, EX); h[3] = lop3(q0, MASK, EX); h[4] = lop3(q1 >> 9, MASK, EX); h[5] = lop3(q1 >> 6, MASK, EX); h[6] = lop3(q1 >> 3, MASK, EX); h[7] = lop3(q1, MASK, EX); // 1024 + 32 = 1056 static constexpr uint32_t SUB = 0x64206420; asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); return result; } CUTLASS_DEVICE result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp) { return convert(s, code_scale, code_zp); } }; template <> struct FastInterleavedAndBiasedNumericArrayConverter { using result_type = Array; using source_type = Array; using ScaleComputeT = float; using code_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp) { uint32_t const i8s = reinterpret_cast(source); // 2^23 = 8388608 static constexpr uint32_t FP32_BASE = 0x4B000000; float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); int32_t decode_value[4]; ScaleComputeT new_code_zp = code_zp + 0.5f; decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp)); decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp)); decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp)); decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp)); return convert_impl(decode_value); } CUTLASS_DEVICE static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp) { uint32_t const i8s = reinterpret_cast(source); // 2^23 = 8388608 static constexpr uint32_t FP32_BASE = 0x4B000000; float fp32_intermediates[4]; uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650); fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651); fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652); fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE)); asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE)); int32_t decode_value[4]; decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f)); decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f)); decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f)); decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f)); return convert_impl(decode_value); } CUTLASS_DEVICE static result_type convert_impl(int32_t* decode_value) { result_type result; static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA; static constexpr uint32_t MASK = 0x003F003F; // 2^7 = 128 static constexpr uint32_t EX = 0x43004300; uint32_t* h = reinterpret_cast(&result); int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410); int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410); h[0] = lop3(q0 >> 9, MASK, EX); h[1] = lop3(q0 >> 6, MASK, EX); h[2] = lop3(q0 >> 3, MASK, EX); h[3] = lop3(q0, MASK, EX); h[4] = lop3(q1 >> 9, MASK, EX); h[5] = lop3(q1 >> 6, MASK, EX); h[6] = lop3(q1 >> 3, MASK, EX); h[7] = lop3(q1, MASK, EX); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16)) // 128 + 32 = 160 static constexpr uint32_t SUB = 0x43204320; asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB)); asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB)); asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB)); asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB)); asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB)); asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB)); asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB)); asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB)); #else // 1.0 static constexpr uint32_t MUL = 0x3F803F80; // -160 static constexpr uint32_t ADD = 0xC320C320; asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MUL), "r"(ADD)); asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MUL), "r"(ADD)); asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MUL), "r"(ADD)); asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MUL), "r"(ADD)); asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(h[4]), "r"(MUL), "r"(ADD)); asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(MUL), "r"(ADD)); asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(MUL), "r"(ADD)); asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(MUL), "r"(ADD)); #endif return result; } CUTLASS_DEVICE result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp) { return convert(s, code_scale, code_zp); } }; template struct FastInterleavedAndBiasedNumericArrayConverter { static_assert(platform::is_same::value || platform::is_same::value, "T must be fp16 or bf16"); static constexpr int kVecWidth = 16; static_assert(!(N % kVecWidth), "N must be multiple of 16."); using result_type = Array; using source_type = Array; using code_type = Array; CUTLASS_DEVICE static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp) { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; FastInterleavedAndBiasedNumericArrayConverter convert_vector_; result_type result; using vec_result = Array; using vec_source = Array; vec_result* result_ptr = reinterpret_cast(&result); vec_source const* source_ptr = reinterpret_cast(&source); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / kVecWidth; ++i) { result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]); } return result; } CUTLASS_DEVICE static result_type convert(source_type const& source, Array const& code_scale, Array const& code_zp) { using scalar_result_type = typename result_type::Element; using scalar_source_type = typename source_type::Element; using Converter = FastInterleavedAndBiasedNumericArrayConverter; result_type result; using vec_result = typename Converter::result_type; using vec_source = typename Converter::source_type; using vec_code = typename Converter::code_type; vec_result* result_ptr = reinterpret_cast(&result); vec_source const* source_ptr = reinterpret_cast(&source); vec_code const* code_scale_ptr = reinterpret_cast(&code_scale); vec_code const* code_zp_ptr = reinterpret_cast(&code_zp); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / kVecWidth; ++i) { result_ptr[i] = Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]); } return result; } CUTLASS_DEVICE result_type operator()(source_type const& s, code_type const& code_scale, code_type const& code_zp) { return convert(s, code_scale, code_zp); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////