mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-26 18:10:32 +08:00 
			
		
		
		
	 8f426c1690
			
		
	
	8f426c1690
	
	
	
		
			
			* Change wint2 to ColumnMajor. Change-Id: I6b44d02946a685f8fe24d9f2c7be258b51e16da2 * Unify default_wint2x_mma. Change-Id: I9e77b0e8e6cecab01fedc0b24b536ee0a1a89ff7 * Change wint2 to ColumnMajorTileInterleave. Change-Id: I593cbe36f991c0c5044989d65f0014087587c624 * Enable async copy for B. Change-Id: Ia3ac37ad162a8cf3ccce4f268e81bd06c8ac3c46 * Add wint2x Dequantizer * Remove TileDequanterB related codes. Change-Id: Id8e65703b72a8984d367f584ff41b7726017fbb8 * Implement FastInterleavedAndBiasedNumericArrayConverter for wint2. Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca * Implement Wint2ParamsAccessor to load extra quant params from global memory. Change-Id: Ic3750cd9b767df8893501820880c3342a4b47233 * Implement FastInterleavedAndBiasedNumericArrayConverter for wint2. Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca * Use async copy for local_scale. Change-Id: Ib882ba41c3d2354bda4d25b40e2408ad3b2f7893 * Check and correct the load and dequantize of weights. Change-Id: Ie8dca505b39987144964fe6407d465b3b5953790 * Change for performance tuning. Change-Id: I1da026fb1d1533a9d70350c7ba23c27e896cfc29 * Optimize the global memory access size of local_scale reading. Change-Id: I4cbe3a2ef5951723d415c2d3252ce912394beaf5 * Specialize mma_tensor_op for wint2 to enable fine-grained pipeline. Change-Id: Icbb4d48f90a41136f42d6ffff42d68de32f408da * Minor fix. Change-Id: I14d4ac9d267ee05442a3b47f00c26bee13d79e6f * optimizing dequant performance with LOP3 * optimizing dequant performance with LOP3 * Avoid redundant dequantization of local_scale and use bf16 as computing type. Change-Id: I63239ebc8f8e4a92d6281af59840ba50600b4334 * Add Multiplier and remove some logs. Change-Id: Ifa199d81e6aeb472d2247c63f85ef30213684bcd * optimizing dequant performance with LOP3 * Use __byte_perm to implement int8 to float32 conversion for performance improvement. * Use lop3 to optimize the dequantize of local_scale. 
Change-Id: I6189759970cb5b8dcbef769724784b8a7533b63c * Minor fix and remove some logs. Change-Id: I6279ba9926d5041093b1c6aea200acf2e4c49d46 * Fix stages for test. Change-Id: I6f7b7cac612ef2c678e9d49f5ffa60eb53d3ae29 * Fix stages for test and add clock64 to profile. Change-Id: Iffaf7324beaa910ce9ee56f47ae289de98f1a267 * Use __byte_perm to replace shift-and-or operations for faster integer merging. * Split the uint2b convert. Change-Id: I78da672ce8968e21f685285140ba546a161521b4 * Optimize convert of unscale. Change-Id: I6795da1cdf5e8ab38ddaa9836240921b5312913a * Minor optimization. Change-Id: I1800aec34c3f4621abb02658208108f54da44d88 * Optimize mma pipeline and refine codes. Change-Id: Id3075cf7b88f2813a11ccd1d3b49c62c978f36b8 * Add missing support. Change-Id: Id65b7bc2c25fbb1a5b232c6bc9fb8c9093f691a8 * Accelerate FP16 dequantization performance * Support tile shape as Xx64x64. Change-Id: Ib8fd37e1ba1d06f7d11f2956e7f1367b0a92bcac * Remove debugging codes and minor optimization. Change-Id: I6b79bd56a6e8dd823efc169967ecd3cc9a43baf4 * Fix offset bug. Change-Id: Id7aeb91e99d6f51836f2aff22187b4f79607395e * Fix typo. Change-Id: I19dde93fc1c1f7e19605905c90dc46298e203952 * Restore some codes and remove some debugging logs. Change-Id: I8d44daf82ad1c6f8174134d195e7b3fe9a3afdfb --------- Co-authored-by: baoqiwen <baoqiwen@baidu.com>
		
			
				
	
	
		
			778 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			778 lines
		
	
	
		
			32 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| /*!
 | |
|     \file
 | |
|     \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
 | |
| */
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include "cutlass/arch/arch.h"
 | |
| #include "cutlass/array.h"
 | |
| #include "cutlass/half.h"
 | |
| #include "cutlass/numeric_types.h"
 | |
| #include "cutlass/trace.h"
 | |
| 
 | |
| namespace cutlass {
 | |
| 
 | |
/// Thin wrapper around the PTX `lop3.b32` instruction: computes an arbitrary
/// three-input bitwise function of (a, b, c) selected by the compile-time
/// 8-bit truth table `lut` (e.g. (0xf0 & 0xcc) | 0xaa encodes (a & b) | c).
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}
 | |
| 
 | |
// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
// This converter will uninterleave the data and subtract the bias while converting to the result type.
//
//   T: destination element type (half_t or bfloat16_t in the specializations below)
//   S: quantized source element type (uint8_t, uint4b_t, or uint2b_t below)
//   N: number of elements converted per invocation
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter;
 | |
| 
 | |
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4>
{
    using result_type = Array<half_t, 4>;
    using source_type = Array<uint8_t, 4>;

    /// Converts 4 biased uint8 values packed in one 32-bit register to 4 half_t,
    /// un-interleaving them and removing the +128 bias, using only 2 prmt + 2 sub.f16x2.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        // prmt selectors 0x5250 / 0x5351 pair source bytes {0,2} (into h[0]) and {1,3}
        // (into h[1]) with the 0x64 filler byte, so each fp16 lane holds (1024 + e).
        static constexpr uint32_t mask_for_elt_01 = 0x5250;
        static constexpr uint32_t mask_for_elt_23 = 0x5351;
        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
        asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));

        // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16.
        // (0x6480 is fp16 1152 = 1024 + 128: removes the exponent offset and the +128 bias in one op.)
        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
 | |
| 
 | |
| template <int N>
 | |
| struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N>
 | |
| {
 | |
|     static constexpr int VEC_WIDTH = 4;
 | |
|     static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
 | |
| 
 | |
|     using result_type = Array<half_t, N>;
 | |
|     using source_type = Array<uint8_t, N>;
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     static result_type convert(source_type const& source)
 | |
|     {
 | |
|         using scalar_result_type = typename result_type::Element;
 | |
|         using scalar_source_type = typename source_type::Element;
 | |
|         FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
 | |
|             convert_vector_;
 | |
| 
 | |
|         result_type result;
 | |
|         using vec_result = Array<scalar_result_type, VEC_WIDTH>;
 | |
|         using vec_source = Array<scalar_source_type, VEC_WIDTH>;
 | |
| 
 | |
|         vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
 | |
|         vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
 | |
| 
 | |
|         CUTLASS_PRAGMA_UNROLL
 | |
|         for (int i = 0; i < N / VEC_WIDTH; ++i)
 | |
|         {
 | |
|             result_ptr[i] = convert_vector_(source_ptr[i]);
 | |
|         }
 | |
| 
 | |
|         return result;
 | |
|     }
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     result_type operator()(source_type const& s)
 | |
|     {
 | |
|         return convert(s);
 | |
|     }
 | |
| };
 | |
| 
 | |
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4>
{
    using result_type = Array<bfloat16_t, 4>;
    using source_type = Array<uint8_t, 4>;

    /// Converts 4 biased uint8 values to 4 bfloat16_t, un-interleaving them and
    /// removing the +128 bias. Requires SM80+; on older architectures the result is
    /// cleared and a breakpoint is raised.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        // Bit pattern of fp32 2^23; placing byte e in the low mantissa bits yields (2^23 + e).
        static constexpr uint32_t fp32_base = 0x4B000000;
        float fp32_intermediates[4];

        // Construct FP32s, bfloat does not have enough mantissa for IADD trick.
        // Selectors 0x7650/0x7652/0x7651/0x7653 un-interleave: intermediates 0/1 take
        // source bytes 0/2 (even elements), intermediates 2/3 take bytes 1/3 (odd elements).
        uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
        fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
        fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
        fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
        fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);

        // Subtract out fp32_base + 128 (8388736.f = 2^23 + 128) to make the unsigned integer signed.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 4; ++ii)
        {
            fp32_intermediates[ii] -= 8388736.f;
        }

        // Truncate the fp32 representation and pack up as bfloat16s
        // (bf16 is the upper 16 bits of fp32; 0x7632 keeps the two high halves).
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < 2; ++ii)
        {
            bf16_result_ptr[ii]
                = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
        result.clear(); // Suppress compiler warning
        arch::device_breakpoint();
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
 | |
| 
 | |
| template <int N>
 | |
| struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N>
 | |
| {
 | |
|     static constexpr int VEC_WIDTH = 4;
 | |
|     static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
 | |
| 
 | |
|     using result_type = Array<bfloat16_t, N>;
 | |
|     using source_type = Array<uint8_t, N>;
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     static result_type convert(source_type const& source)
 | |
|     {
 | |
|         using scalar_result_type = typename result_type::Element;
 | |
|         using scalar_source_type = typename source_type::Element;
 | |
|         FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
 | |
|             convert_vector_;
 | |
| 
 | |
|         result_type result;
 | |
|         using vec_result = Array<scalar_result_type, VEC_WIDTH>;
 | |
|         using vec_source = Array<scalar_source_type, VEC_WIDTH>;
 | |
| 
 | |
|         vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
 | |
|         vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
 | |
| 
 | |
|         CUTLASS_PRAGMA_UNROLL
 | |
|         for (int i = 0; i < N / VEC_WIDTH; ++i)
 | |
|         {
 | |
|             result_ptr[i] = convert_vector_(source_ptr[i]);
 | |
|         }
 | |
| 
 | |
|         return result;
 | |
|     }
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     result_type operator()(source_type const& s)
 | |
|     {
 | |
|         return convert(s);
 | |
|     }
 | |
| };
 | |
| 
 | |
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8>
{
    using result_type = Array<half_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    /// Converts 8 biased uint4 values packed in one 32-bit register to 8 half_t,
    /// un-interleaving them and removing the +8 bias.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        // immLut encodes (a & b) | c for the lop3 mask-and-merge below.
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
        static constexpr uint32_t TOP_MASK = 0x00f000f0;
        static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

        // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
        // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
        // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
        // elt_67 to fp16 without having to shift them to the bottom bits before hand.

        // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
        // immediately before required.
        const uint32_t top_i4s = i4s >> 8;
        // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[1])
                     : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[2])
                     : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
        // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[3])
                     : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

        // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
        // half2 ctor. In this case, I chose performance reliability over code readability.

        // This is the half2 {1032, 1032} represented as an integer.
        // 1032 = 1024 + 8, removing the exponent offset and the 4-bit bias at once.
        static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
        // This is the half2 {1 / 16, 1 / 16} represented as an integer.
        static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
        // This is the half2 {-72, -72} represented as an integer.
        // (16e + 1024) / 16 - 72 = e + 64 - 72 = e - 8, i.e. the de-biased value.
        static constexpr uint32_t NEG_72 = 0xd480d480;

        // Finally, we construct the output numbers.
        // Convert elt_01
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_23
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
        // Convert elt_45
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
        // Convert elt_67
        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));

        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
 | |
| 
 | |
| template <int N>
 | |
| struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N>
 | |
| {
 | |
|     static constexpr int VEC_WIDTH = 8;
 | |
|     static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
 | |
| 
 | |
|     using result_type = Array<half_t, N>;
 | |
|     using source_type = Array<uint4b_t, N>;
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     static result_type convert(source_type const& source)
 | |
|     {
 | |
|         using scalar_result_type = typename result_type::Element;
 | |
|         using scalar_source_type = typename source_type::Element;
 | |
|         FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
 | |
|             convert_vector_;
 | |
| 
 | |
|         result_type result;
 | |
|         using vec_result = Array<scalar_result_type, VEC_WIDTH>;
 | |
|         using vec_source = Array<scalar_source_type, VEC_WIDTH>;
 | |
| 
 | |
|         vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
 | |
|         vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
 | |
| 
 | |
|         CUTLASS_PRAGMA_UNROLL
 | |
|         for (int i = 0; i < N / VEC_WIDTH; ++i)
 | |
|         {
 | |
|             result_ptr[i] = convert_vector_(source_ptr[i]);
 | |
|         }
 | |
| 
 | |
|         return result;
 | |
|     }
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     result_type operator()(source_type const& s)
 | |
|     {
 | |
|         return convert(s);
 | |
|     }
 | |
| };
 | |
| 
 | |
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8>
{
    using result_type = Array<bfloat16_t, 8>;
    using source_type = Array<uint4b_t, 8>;

    /// Converts 8 biased uint4 values to 8 bfloat16_t, removing the +8 bias.
    /// Requires SM80+; on older architectures the result is cleared and a breakpoint is raised.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);
        uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);

        // First, we extract the i4s and construct an intermediate fp16 number.
        // immLut encodes (a & b) | c; 0x4300 is bf16 128.0, so OR-ing a 4-bit field e
        // into the low mantissa bits yields the bf16 value (128 + e).
        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
        static constexpr uint32_t MASK = 0x000f000f;
        static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

        // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop.
        // No shift needed for first item.
        uint32_t i4s = source_i4s;
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(h[0])
                     : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 1; ii < result_type::kElements / 2; ++ii)
        {
            // Shift the next pair of nibbles down to the mask positions (4 bits per element).
            i4s >>= sizeof_bits<typename source_type::Element>::value;
            // (i4s & 0x000f000f) | 0x43004300
            asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                         : "=r"(h[ii])
                         : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
        }

        // This is the BF16 {-136, -136} represented as an integer.
        // -136 = -(128 + 8): removes the exponent offset and the 4-bit bias in one fma.
        static constexpr uint32_t BF16_BIAS = 0xC308C308;
        static constexpr uint32_t BF16_ONE = 0x3F803F80;

        // Finally, we construct the output numbers: h[ii] = h[ii] * 1.0 + (-136) = e - 8.
        CUTLASS_PRAGMA_UNROLL
        for (int ii = 0; ii < result_type::kElements / 2; ++ii)
        {
            // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction
            asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
        }
#else
        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
        // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters.
        arch::device_breakpoint();
        result.clear(); // Suppress compiler warning.
#endif
        return result;
    }

    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};
 | |
| 
 | |
| template <int N>
 | |
| struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N>
 | |
| {
 | |
|     static constexpr int VEC_WIDTH = 8;
 | |
|     static_assert(!(N % VEC_WIDTH), "N must be multiple of 8.");
 | |
| 
 | |
|     using result_type = Array<bfloat16_t, N>;
 | |
|     using source_type = Array<uint4b_t, N>;
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     static result_type convert(source_type const& source)
 | |
|     {
 | |
|         using scalar_result_type = typename result_type::Element;
 | |
|         using scalar_source_type = typename source_type::Element;
 | |
|         FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, VEC_WIDTH>
 | |
|             convert_vector_;
 | |
| 
 | |
|         result_type result;
 | |
|         using vec_result = Array<scalar_result_type, VEC_WIDTH>;
 | |
|         using vec_source = Array<scalar_source_type, VEC_WIDTH>;
 | |
| 
 | |
|         vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
 | |
|         vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
 | |
| 
 | |
|         CUTLASS_PRAGMA_UNROLL
 | |
|         for (int i = 0; i < N / VEC_WIDTH; ++i)
 | |
|         {
 | |
|             result_ptr[i] = convert_vector_(source_ptr[i]);
 | |
|         }
 | |
| 
 | |
|         return result;
 | |
|     }
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     result_type operator()(source_type const& s)
 | |
|     {
 | |
|         return convert(s);
 | |
|     }
 | |
| };
 | |
| 
 | |
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint2b_t, 16>
{
    using result_type = Array<half_t, 16>;
    using source_type = Array<uint2b_t, 16>;

    using ScaleComputeT = float;
    using code_type = Array<ScaleComputeT, 4>;

    /// Converts 16 2-bit weights (one 32-bit register, viewed as 4 bytes) to 16 half_t.
    /// Each byte is first mapped through an affine decode (code_scale, code_zp) into a
    /// 16-bit integer code; convert_impl then unpacks 6-bit fields from those codes.
    /// This overload applies a single scalar scale/zero-point to all 4 bytes.
    CUTLASS_DEVICE
    static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
    {
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        // 2^23 = 8388608
        static constexpr uint32_t FP32_BASE = 0x4B000000;

        // Splat each source byte into the mantissa of (2^23 + b), one fp32 per byte.
        float fp32_intermediates[4];
        uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
        fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
        fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
        fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
        fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);

        // Subtract 2^23 in fp32 (the "r" constraints carry fp32 bit patterns through b32
        // registers), leaving the plain byte value b as a float in each intermediate.
        asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
        asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
        asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
        asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));

        int32_t decode_value[4];
        // Adding 0.5 before the floor (__float2int_rd) makes the decode round-half-up.
        ScaleComputeT new_code_zp = code_zp + 0.5f;

        decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
        decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
        decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
        decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));

        return convert_impl(decode_value);
    }

    /// Same as above, but with a distinct scale/zero-point per source byte
    /// (code_scale[i]/code_zp[i] applies to byte i).
    CUTLASS_DEVICE
    static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
    {
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        // 2^23 = 8388608
        static constexpr uint32_t FP32_BASE = 0x4B000000;

        // Splat each source byte into the mantissa of (2^23 + b), one fp32 per byte.
        float fp32_intermediates[4];
        uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
        fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
        fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
        fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
        fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);

        // Subtract 2^23 in fp32, leaving each byte value as a float.
        asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
        asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
        asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
        asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));

        int32_t decode_value[4];

        // Per-byte affine decode; +0.5 before the floor gives round-half-up.
        decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
        decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
        decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
        decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));

        return convert_impl(decode_value);
    }

    /// Unpacks the 4 decoded integer codes into 16 half_t values.
    /// The low 16 bits of each pair of codes are packed into a register, then
    /// overlapping 6-bit windows (shifts 9/6/3/0, mask 0x3F per half) are OR-ed with
    /// the fp16 exponent 0x6400 (1024 + field) and de-biased by subtracting 1056,
    /// yielding (field - 32) per element.
    CUTLASS_DEVICE
    static result_type convert_impl(int32_t* decode_value)
    {
        result_type result;
        // immLut encodes (a & b) | c for the mask-and-merge lop3 below.
        static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;

        static constexpr uint32_t MASK = 0x003F003F;
        // 2^10 = 1024
        static constexpr uint32_t EX = 0x64006400;

        uint32_t* h = reinterpret_cast<uint32_t*>(&result);

        // 0x5410 packs the low halfwords of two decode values side by side.
        int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
        int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);

        h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
        h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
        h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
        h[3] = lop3<immLut>(q0, MASK, EX);

        h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
        h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
        h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
        h[7] = lop3<immLut>(q1, MASK, EX);

        // 1024 + 32 = 1056
        static constexpr uint32_t SUB = 0x64206420;

        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));

        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
        asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));

        return result;
    }

    // NOTE(review): only the scalar scale/zp path has an operator(); callers of the
    // per-byte code_type path invoke the static convert(...) directly — confirm intended.
    CUTLASS_DEVICE
    result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
    {
        return convert(s, code_scale, code_zp);
    }
};
 | |
| 
 | |
| template <>
 | |
| struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint2b_t, 16>
 | |
| {
 | |
|     using result_type = Array<bfloat16_t, 16>;
 | |
|     using source_type = Array<uint2b_t, 16>;
 | |
| 
 | |
|     using ScaleComputeT = float;
 | |
|     using code_type = Array<ScaleComputeT, 4>;
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     static result_type convert(source_type const& source, ScaleComputeT code_scale, ScaleComputeT code_zp)
 | |
|     {
 | |
|         uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
 | |
| 
 | |
|         // 2^23 = 8388608
 | |
|         static constexpr uint32_t FP32_BASE = 0x4B000000;
 | |
| 
 | |
|         float fp32_intermediates[4];
 | |
|         uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
 | |
|         fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
 | |
|         fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
 | |
|         fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
 | |
|         fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
 | |
| 
 | |
|         asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
 | |
|         asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
 | |
|         asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
 | |
|         asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
 | |
| 
 | |
|         int32_t decode_value[4];
 | |
|         ScaleComputeT new_code_zp = code_zp + 0.5f;
 | |
| 
 | |
|         decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale, new_code_zp));
 | |
|         decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale, new_code_zp));
 | |
|         decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale, new_code_zp));
 | |
|         decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale, new_code_zp));
 | |
| 
 | |
|         return convert_impl(decode_value);
 | |
|     }
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
 | |
|     {
 | |
|         uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
 | |
| 
 | |
|         // 2^23 = 8388608
 | |
|         static constexpr uint32_t FP32_BASE = 0x4B000000;
 | |
| 
 | |
|         float fp32_intermediates[4];
 | |
|         uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
 | |
|         fp32_intermediates_casted[0] = __byte_perm(i8s, FP32_BASE, 0x7650);
 | |
|         fp32_intermediates_casted[1] = __byte_perm(i8s, FP32_BASE, 0x7651);
 | |
|         fp32_intermediates_casted[2] = __byte_perm(i8s, FP32_BASE, 0x7652);
 | |
|         fp32_intermediates_casted[3] = __byte_perm(i8s, FP32_BASE, 0x7653);
 | |
| 
 | |
|         asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[0]) : "r"(fp32_intermediates_casted[0]), "r"(FP32_BASE));
 | |
|         asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[1]) : "r"(fp32_intermediates_casted[1]), "r"(FP32_BASE));
 | |
|         asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[2]) : "r"(fp32_intermediates_casted[2]), "r"(FP32_BASE));
 | |
|         asm volatile("sub.f32 %0, %1, %2;\n" : "=r"(fp32_intermediates_casted[3]) : "r"(fp32_intermediates_casted[3]), "r"(FP32_BASE));
 | |
| 
 | |
|         int32_t decode_value[4];
 | |
| 
 | |
|         decode_value[0] = __float2int_rd(fmaf(fp32_intermediates[0], code_scale[0], code_zp[0] + 0.5f));
 | |
|         decode_value[1] = __float2int_rd(fmaf(fp32_intermediates[1], code_scale[1], code_zp[1] + 0.5f));
 | |
|         decode_value[2] = __float2int_rd(fmaf(fp32_intermediates[2], code_scale[2], code_zp[2] + 0.5f));
 | |
|         decode_value[3] = __float2int_rd(fmaf(fp32_intermediates[3], code_scale[3], code_zp[3] + 0.5f));
 | |
| 
 | |
|         return convert_impl(decode_value);
 | |
|     }
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     static result_type convert_impl(int32_t* decode_value)
 | |
|     {
 | |
|         result_type result;
 | |
| 
 | |
|         static constexpr uint32_t immLut = (0xF0 & 0xCC) | 0xAA;
 | |
|         static constexpr uint32_t MASK = 0x003F003F;
 | |
|         // 2^7 = 128
 | |
|         static constexpr uint32_t EX = 0x43004300;
 | |
| 
 | |
|         uint32_t* h = reinterpret_cast<uint32_t*>(&result);
 | |
| 
 | |
|         int32_t q0 = __byte_perm(decode_value[0], decode_value[1], 0x5410);
 | |
|         int32_t q1 = __byte_perm(decode_value[2], decode_value[3], 0x5410);
 | |
| 
 | |
|         h[0] = lop3<immLut>(q0 >> 9, MASK, EX);
 | |
|         h[1] = lop3<immLut>(q0 >> 6, MASK, EX);
 | |
|         h[2] = lop3<immLut>(q0 >> 3, MASK, EX);
 | |
|         h[3] = lop3<immLut>(q0, MASK, EX);
 | |
| 
 | |
|         h[4] = lop3<immLut>(q1 >> 9, MASK, EX);
 | |
|         h[5] = lop3<immLut>(q1 >> 6, MASK, EX);
 | |
|         h[6] = lop3<immLut>(q1 >> 3, MASK, EX);
 | |
|         h[7] = lop3<immLut>(q1, MASK, EX);
 | |
| 
 | |
| #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(ENABLE_BF16))
 | |
|         // 128 + 32 = 160
 | |
|         static constexpr uint32_t SUB = 0x43204320;
 | |
| 
 | |
|         asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(SUB));
 | |
|         asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(SUB));
 | |
|         asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(SUB));
 | |
|         asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(SUB));
 | |
| 
 | |
|         asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[4]) : "r"(h[4]), "r"(SUB));
 | |
|         asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[5]) : "r"(h[5]), "r"(SUB));
 | |
|         asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[6]) : "r"(h[6]), "r"(SUB));
 | |
|         asm volatile("sub.bf16x2 %0, %1, %2;\n" : "=r"(h[7]) : "r"(h[7]), "r"(SUB));
 | |
| #else
 | |
|         // 1.0
 | |
|         static constexpr uint32_t MUL = 0x3F803F80;
 | |
|         // -160
 | |
|         static constexpr uint32_t ADD = 0xC320C320;
 | |
| 
 | |
|         asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MUL), "r"(ADD));
 | |
|         asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MUL), "r"(ADD));
 | |
|         asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MUL), "r"(ADD));
 | |
|         asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MUL), "r"(ADD));
 | |
| 
 | |
|         asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[4]) : "r"(h[4]), "r"(MUL), "r"(ADD));
 | |
|         asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[5]) : "r"(h[5]), "r"(MUL), "r"(ADD));
 | |
|         asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[6]) : "r"(h[6]), "r"(MUL), "r"(ADD));
 | |
|         asm volatile("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[7]) : "r"(h[7]), "r"(MUL), "r"(ADD));
 | |
| #endif
 | |
| 
 | |
|         return result;
 | |
|     }
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     result_type operator()(source_type const& s, ScaleComputeT code_scale, ScaleComputeT code_zp)
 | |
|     {
 | |
|         return convert(s, code_scale, code_zp);
 | |
|     }
 | |
| };
 | |
| 
 | |
| template <typename T, int N>
 | |
| struct FastInterleavedAndBiasedNumericArrayConverter<T, uint2b_t, N>
 | |
| {
 | |
|     static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
 | |
|         "T must be fp16 or bf16");
 | |
| 
 | |
|     static constexpr int kVecWidth = 16;
 | |
|     static_assert(!(N % kVecWidth), "N must be multiple of 16.");
 | |
| 
 | |
|     using result_type = Array<T, N>;
 | |
|     using source_type = Array<uint2b_t, N>;
 | |
|     using code_type = Array<float, N / kVecWidth>;
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     static result_type convert(source_type const& source, code_type const& code_scale, code_type const& code_zp)
 | |
|     {
 | |
|         using scalar_result_type = typename result_type::Element;
 | |
|         using scalar_source_type = typename source_type::Element;
 | |
|         FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>
 | |
|             convert_vector_;
 | |
| 
 | |
|         result_type result;
 | |
|         using vec_result = Array<scalar_result_type, kVecWidth>;
 | |
|         using vec_source = Array<scalar_source_type, kVecWidth>;
 | |
| 
 | |
|         vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
 | |
|         vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
 | |
| 
 | |
|         CUTLASS_PRAGMA_UNROLL
 | |
|         for (int i = 0; i < N / kVecWidth; ++i)
 | |
|         {
 | |
|             result_ptr[i] = convert_vector_(source_ptr[i], code_scale[i], code_zp[i]);
 | |
|         }
 | |
| 
 | |
|         return result;
 | |
|     }
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     static result_type convert(source_type const& source, Array<float, N / 4> const& code_scale, Array<float, N / 4> const& code_zp)
 | |
|     {
 | |
|         using scalar_result_type = typename result_type::Element;
 | |
|         using scalar_source_type = typename source_type::Element;
 | |
|         using Converter = FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type, scalar_source_type, kVecWidth>;
 | |
| 
 | |
|         result_type result;
 | |
|         using vec_result = typename Converter::result_type;
 | |
|         using vec_source = typename Converter::source_type;
 | |
|         using vec_code = typename Converter::code_type;
 | |
| 
 | |
|         vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
 | |
|         vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
 | |
|         vec_code const* code_scale_ptr = reinterpret_cast<vec_code const*>(&code_scale);
 | |
|         vec_code const* code_zp_ptr = reinterpret_cast<vec_code const*>(&code_zp);
 | |
| 
 | |
|         CUTLASS_PRAGMA_UNROLL
 | |
|         for (int i = 0; i < N / kVecWidth; ++i)
 | |
|         {
 | |
|             result_ptr[i] = Converter::convert(source_ptr[i], code_scale_ptr[i], code_zp_ptr[i]);
 | |
|         }
 | |
| 
 | |
|         return result;
 | |
|     }
 | |
| 
 | |
|     CUTLASS_DEVICE
 | |
|     result_type operator()(source_type const& s, code_type const& code_scale, code_type const& code_zp)
 | |
|     {
 | |
|         return convert(s, code_scale, code_zp);
 | |
|     }
 | |
| };
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | |
| 
 | |
| } // namespace cutlass
 | |
| 
 | |
| /////////////////////////////////////////////////////////////////////////////////////////////////
 |