// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fast_hardamard_kernel.h"

// All 32 bits set; presumably the full-warp mask passed to *_sync intrinsics
// elsewhere in this file.
#define FULL_MASK 0xffffffff

// 32-byte trivially-copyable payload made of two uint4 halves; this is the
// type BytesToType<32> resolves to.
struct uint8 {
  uint4 u;
  uint4 v;
};

// Maps a byte count to a type of exactly that many bytes (each specialization
// static_asserts its size). The primary template is deliberately empty, so
// naming BytesToType<N>::Type for any other N fails to compile.
template <int BYTES>
struct BytesToType {};

template <>
struct BytesToType<32> {
  using Type = uint8;
  static_assert(sizeof(Type) == 32);
};

template <>
struct BytesToType<16> {
  using Type = uint4;
  static_assert(sizeof(Type) == 16);
};

template <>
struct BytesToType<8> {
  using Type = uint64_t;
  static_assert(sizeof(Type) == 8);
};

template <>
struct BytesToType<4> {
  using Type = uint32_t;
  static_assert(sizeof(Type) == 4);
};

template <>
struct BytesToType<2> {
  using Type = uint16_t;
  static_assert(sizeof(Type) == 2);
};

template <>
struct BytesToType<1> {
  using Type = uint8_t;
  static_assert(sizeof(Type) == 1);
};

// Maps an element type to the type used for device arithmetic; any type
// without a specialization maps to itself.
// NOTE(review): the specialization arguments below are assumed from the
// mapped-to types (phi::dtype::float16 -> half, phi::dtype::bfloat16 ->
// __nv_bfloat16, int8_t -> int8_t) -- confirm against version control.
template <typename T>
struct nv_type_traits {
  using type = T;
};

template <>
struct nv_type_traits<phi::dtype::float16> {
  using type = half;
};

template <>
struct nv_type_traits<phi::dtype::bfloat16> {
  using type = __nv_bfloat16;
};

template <>
struct nv_type_traits<int8_t> {
  using type = int8_t;
};

// Turns a runtime logN in {10, 9, 8, 7} into the compile-time constant kLogN
// and expands __VA_ARGS__ with kLogN in scope. Any other value throws
// phi::errors::Unimplemented.
#define DISPATCH_SP_logN(logN, kLogN, ...) \
  if ((logN) == 10) { \
    constexpr int kLogN = 10; \
    __VA_ARGS__ \
  } else if ((logN) == 9) { \
    constexpr int kLogN = 9; \
    __VA_ARGS__ \
  } else if ((logN) == 8) { \
    constexpr int kLogN = 8; \
    __VA_ARGS__ \
  } else if ((logN) == 7) { \
    constexpr int kLogN = 7; \
    __VA_ARGS__ \
  } else { \
    PADDLE_THROW( \
        phi::errors::Unimplemented("logN = %d is unsupported!", (logN))); \
  }

// Turns a runtime vec_size in {16, 8, 4, 2, 1} into the compile-time constant
// VEC_SIZE and expands __VA_ARGS__ with VEC_SIZE in scope. Any other value
// throws phi::errors::Unimplemented.
#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...) \
  if ((vec_size) == 16) { \
    constexpr int VEC_SIZE = 16; \
    __VA_ARGS__ \
  } else if ((vec_size) == 8) { \
    constexpr int VEC_SIZE = 8; \
    __VA_ARGS__ \
  } else if ((vec_size) == 4) { \
    constexpr int VEC_SIZE = 4; \
    __VA_ARGS__ \
  } else if ((vec_size) == 2) { \
    constexpr int VEC_SIZE = 2; \
    __VA_ARGS__ \
  } else if ((vec_size) == 1) { \
    constexpr int VEC_SIZE = 1; \
    __VA_ARGS__ \
  } else { \
    PADDLE_THROW(phi::errors::Unimplemented( \
        "vec_size = %d is unsupported!", (vec_size))); \
  }

// Turns a runtime logN in {11, 12, 13, 14} into the compile-time constant
// kLogN and expands __VA_ARGS__ with kLogN in scope. Any other value throws
// phi::errors::Unimplemented, reporting the offending value like the
// DISPATCH_SP_* macros do.
#define DISPATCH_logN(logN, kLogN, ...) \
  if ((logN) == 11) { \
    constexpr int kLogN = 11; \
    __VA_ARGS__ \
  } else if ((logN) == 12) { \
    constexpr int kLogN = 12; \
    __VA_ARGS__ \
  } else if ((logN) == 13) { \
    constexpr int kLogN = 13; \
    __VA_ARGS__ \
  } else if ((logN) == 14) { \
    constexpr int kLogN = 14; \
    __VA_ARGS__ \
  } else { \
    PADDLE_THROW( \
        phi::errors::Unimplemented("logN = %d is unsupported!", (logN))); \
  }

// Sign pattern of the 28x28 Hadamard matrix applied by
// hadamard_mult_thread_28 and hadamard_mult_thread_28_transpose.
// row(r)[c] is '+' when input c is added into output r and '-' when it is
// subtracted. Each row is written as two 14-column halves [A_r | B_r]; rows
// 14..27 are [B_r | -A_r] for rows 0..13.
struct HadamardSigns28 {
  static constexpr int kN = 28;

  // Sign string of row r; only valid for 0 <= r < kN.
  static __host__ __device__ constexpr const char* row(int r) {
    switch (r) {
      case 0:  return "++++++++++++++" "-+++++++++++++";
      case 1:  return "+++-++----++-+" "+-+-++----++-+";
      case 2:  return "++++-++----++-" "++-+-++----++-";
      case 3:  return "+-+++-++----++" "+-+-+-++----++";
      case 4:  return "++-+++-++----+" "++-+-+-++----+";
      case 5:  return "+++-+++-++----" "+++-+-+-++----";
      case 6:  return "+-++-+++-++---" "+-++-+-+-++---";
      case 7:  return "+--++-+++-++--" "+--++-+-+-++--";
      case 8:  return "+---++-+++-++-" "+---++-+-+-++-";
      case 9:  return "+----++-+++-++" "+----++-+-+-++";
      case 10: return "++----++-+++-+" "++----++-+-+-+";
      case 11: return "+++----++-+++-" "+++----++-+-+-";
      case 12: return "+-++----++-+++" "+-++----++-+-+";
      case 13: return "++-++----++-++" "++-++----++-+-";
      case 14: return "-+++++++++++++" "--------------";
      case 15: return "+-+-++----++-+" "---+--++++--+-";
      case 16: return "++-+-++----++-" "----+--++++--+";
      case 17: return "+-+-+-++----++" "-+---+--++++--";
      case 18: return "++-+-+-++----+" "--+---+--++++-";
      case 19: return "+++-+-+-++----" "---+---+--++++";
      case 20: return "+-++-+-+-++---" "-+--+---+--+++";
      case 21: return "+--++-+-+-++--" "-++--+---+--++";
      case 22: return "+---++-+-+-++-" "-+++--+---+--+";
      case 23: return "+----++-+-+-++" "-++++--+---+--";
      case 24: return "++----++-+-+-+" "--++++--+---+-";
      case 25: return "+++----++-+-+-" "---++++--+---+";
      case 26: return "+-++----++-+-+" "-+--++++--+---";
      case 27: return "++-++----++-+-" "--+--++++--+--";
      default: return nullptr;  // unreachable for 0 <= r < kN
    }
  }

  // True when entry (row, col) is -1. Meant to be called with compile-time
  // constant arguments inside fully unrolled loops, where the lookup is
  // expected to constant-fold away.
  static __host__ __device__ constexpr bool is_negative(int r, int c) {
    return row(r)[c] == '-';
  }
};

// Sign pattern of the 36x36 Hadamard matrix applied by
// hadamard_mult_thread_36_transpose, laid out like HadamardSigns28 with two
// 18-column halves per row: rows 18..35 are [B_r | -A_r] for rows 0..17.
struct HadamardSigns36 {
  static constexpr int kN = 36;

  // Sign string of row r; only valid for 0 <= r < kN.
  static __host__ __device__ constexpr const char* row(int r) {
    switch (r) {
      case 0:  return "++++++++++++++++++" "-+++++++++++++++++";
      case 1:  return "++++-+---++---+-++" "+-++-+---++---+-++";
      case 2:  return "+++++-+---++---+-+" "++-++-+---++---+-+";
      case 3:  return "++++++-+---++---+-" "+++-++-+---++---+-";
      case 4:  return "+-+++++-+---++---+" "+-++-++-+---++---+";
      case 5:  return "++-+++++-+---++---" "++-++-++-+---++---";
      case 6:  return "+-+-+++++-+---++--" "+-+-++-++-+---++--";
      case 7:  return "+--+-+++++-+---++-" "+--+-++-++-+---++-";
      case 8:  return "+---+-+++++-+---++" "+---+-++-++-+---++";
      case 9:  return "++---+-+++++-+---+" "++---+-++-++-+---+";
      case 10: return "+++---+-+++++-+---" "+++---+-++-++-+---";
      case 11: return "+-++---+-+++++-+--" "+-++---+-++-++-+--";
      case 12: return "+--++---+-+++++-+-" "+--++---+-++-++-+-";
      case 13: return "+---++---+-+++++-+" "+---++---+-++-++-+";
      case 14: return "++---++---+-+++++-" "++---++---+-++-++-";
      case 15: return "+-+---++---+-+++++" "+-+---++---+-++-++";
      case 16: return "++-+---++---+-++++" "++-+---++---+-++-+";
      case 17: return "+++-+---++---+-+++" "+++-+---++---+-++-";
      case 18: return "-+++++++++++++++++" "------------------";
      case 19: return "+-++-+---++---+-++" "----+-+++--+++-+--";
      case 20: return "++-++-+---++---+-+" "-----+-+++--+++-+-";
      case 21: return "+++-++-+---++---+-" "------+-+++--+++-+";
      case 22: return "+-++-++-+---++---+" "-+-----+-+++--+++-";
      case 23: return "++-++-++-+---++---" "--+-----+-+++--+++";
      case 24: return "+-+-++-++-+---++--" "-+-+-----+-+++--++";
      case 25: return "+--+-++-++-+---++-" "-++-+-----+-+++--+";
      case 26: return "+---+-++-++-+---++" "-+++-+-----+-+++--";
      case 27: return "++---+-++-++-+---+" "--+++-+-----+-+++-";
      case 28: return "+++---+-++-++-+---" "---+++-+-----+-+++";
      case 29: return "+-++---+-++-++-+--" "-+--+++-+-----+-++";
      case 30: return "+--++---+-++-++-+-" "-++--+++-+-----+-+";
      case 31: return "+---++---+-++-++-+" "-+++--+++-+-----+-";
      case 32: return "++---++---+-++-++-" "--+++--+++-+-----+";
      case 33: return "+-+---++---+-++-++" "-+-+++--+++-+-----";
      case 34: return "++-+---++---+-++-+" "--+-+++--+++-+----";
      case 35: return "+++-+---++---+-++-" "---+-+++--+++-+---";
      default: return nullptr;  // unreachable for 0 <= r < kN
    }
  }

  // True when entry (row, col) is -1; see HadamardSigns28::is_negative.
  static __host__ __device__ constexpr bool is_negative(int r, int c) {
    return row(r)[c] == '-';
  }
};

// Compile-time guard for the sign tables: every row must contain exactly
// Signs::kN characters, each '+' or '-'.
template <typename Signs>
__host__ __device__ constexpr bool hadamard_signs_well_formed() {
  for (int r = 0; r < Signs::kN; ++r) {
    const char* s = Signs::row(r);
    int len = 0;
    for (; s[len] != '\0'; ++len) {
      if (s[len] != '+' && s[len] != '-') return false;
    }
    if (len != Signs::kN) return false;
  }
  return true;
}
static_assert(hadamard_signs_well_formed<HadamardSigns28>(),
              "HadamardSigns28 rows must hold 28 '+'/'-' entries");
static_assert(hadamard_signs_well_formed<HadamardSigns36>(),
              "HadamardSigns36 rows must hold 36 '+'/'-' entries");

// In place, for every column vi < VecSize: x[:, vi] <- H * x[:, vi], where H
// is the Signs::kN x Signs::kN sign matrix. x must have at least Signs::kN
// rows. Each output row is accumulated left to right in column order,
// starting from +x[0][vi] or -x[0][vi], i.e. the same operation sequence as a
// fully expanded `out[r] = +/- x[0][vi] +/- x[1][vi] ...` expression.
// All loops have compile-time trip counts and are fully unrolled.
template <typename Signs, typename T, int VecSize>
__device__ __forceinline__ void hadamard_mult_thread_signs_transpose(
    T x[][VecSize]) {
  constexpr int kN = Signs::kN;
  T out[kN];
#pragma unroll
  for (int vi = 0; vi < VecSize; vi++) {
#pragma unroll
    for (int r = 0; r < kN; r++) {
      // Start from the signed first term (unary + / -) so the accumulator
      // type is exactly that of the expanded expression.
      auto acc = Signs::is_negative(r, 0) ? -x[0][vi] : +x[0][vi];
#pragma unroll
      for (int c = 1; c < kN; c++) {
        acc = Signs::is_negative(r, c) ? acc - x[c][vi] : acc + x[c][vi];
      }
      out[r] = acc;
    }
    // Write back only after every output row has read the original column.
#pragma unroll
    for (int i = 0; i < kN; i++) {
      x[i][vi] = out[i];
    }
  }
}

// Scalar counterpart of hadamard_mult_thread_signs_transpose:
// x[0..kN) <- H * x[0..kN), same accumulation order.
template <typename Signs, typename T>
__device__ __forceinline__ void hadamard_mult_thread_signs(T x[]) {
  constexpr int kN = Signs::kN;
  T out[kN];
#pragma unroll
  for (int r = 0; r < kN; r++) {
    auto acc = Signs::is_negative(r, 0) ? -x[0] : +x[0];
#pragma unroll
    for (int c = 1; c < kN; c++) {
      acc = Signs::is_negative(r, c) ? acc - x[c] : acc + x[c];
    }
    out[r] = acc;
  }
#pragma unroll
  for (int i = 0; i < kN; i++) {
    x[i] = out[i];
  }
}

// In-place multiply of each of the VecSize columns of x[28][VecSize] by the
// 28x28 sign matrix in HadamardSigns28 (unnormalized).
template <typename T, int VecSize>
__device__ __forceinline__ void hadamard_mult_thread_28_transpose(
    T x[28][VecSize]) {
  hadamard_mult_thread_signs_transpose<HadamardSigns28, T, VecSize>(x);
}

// In-place multiply of each of the VecSize columns of x[36][VecSize] by the
// 36x36 sign matrix in HadamardSigns36 (unnormalized).
template <typename T, int VecSize>
__device__ __forceinline__ void hadamard_mult_thread_36_transpose(
    T x[36][VecSize]) {
  hadamard_mult_thread_signs_transpose<HadamardSigns36, T, VecSize>(x);
}

// In-place multiply of x[28] by the 28x28 sign matrix in HadamardSigns28
// (unnormalized); the same matrix as hadamard_mult_thread_28_transpose.
template <typename T>
__device__ __forceinline__ void hadamard_mult_thread_28(T x[28]) {
  hadamard_mult_thread_signs<HadamardSigns28, T>(x);
}

// Scalar (non-vectorized) order-36 multiply; its remaining rows follow below.
// NOTE(review): out[0], out[1] and the start of out[2] match HadamardSigns36
// rows 0-2; once the full body is confirmed to match, it could become
// hadamard_mult_thread_signs<HadamardSigns36, T>(x) like the order-28 variant.
template <typename T>
__device__ __forceinline__ void hadamard_mult_thread_36(T x[36]) {  // 4t
  T out[36];
  out[0] = + x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8]
           + x[9] + x[10] + x[11] + x[12] + x[13] + x[14] + x[15] + x[16] + x[17]
           - x[18] + x[19] + x[20] + x[21] + x[22] + x[23] + x[24] + x[25] + x[26]
           + x[27] + x[28] + x[29] + x[30] + x[31] + x[32] + x[33] + x[34] + x[35];
  out[1] = + x[0] + x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8]
           + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + x[17]
           + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - x[25] - x[26]
           + x[27] + x[28] - x[29] - x[30] - x[31] + x[32] - x[33] + x[34] + x[35];
  out[2] = + x[0] + x[1] + x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8]
           - x[9] + x[10] + x[11] -
x[12] - x[13] - x[14] + x[15] - x[16] + x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] + x[24] - x[25] - x[26] - x[27] + x[28] + x[29] - x[30] - x[31] - x[32] + x[33] - x[34] + x[35]; out[3] = + x[0] + x[1] + x[2] + x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - x[17] + x[18] + x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + x[25] - x[26] - x[27] - x[28] + x[29] + x[30] - x[31] - x[32] - x[33] + x[34] - x[35]; out[4] = + x[0] - x[1] + x[2] + x[3] + x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] + x[24] - x[25] + x[26] - x[27] - x[28] - x[29] + x[30] + x[31] - x[32] - x[33] - x[34] + x[35]; out[5] = + x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] + x[7] - x[8] + x[9] - x[10] - x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] + x[24] + x[25] - x[26] + x[27] - x[28] - x[29] - x[30] + x[31] + x[32] - x[33] - x[34] - x[35]; out[6] = + x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] - x[17] + x[18] - x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + x[25] + x[26] - x[27] + x[28] - x[29] - x[30] - x[31] + x[32] + x[33] - x[34] - x[35]; out[7] = + x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - x[17] + x[18] - x[19] - x[20] + x[21] - x[22] + x[23] + x[24] - x[25] + x[26] + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] + x[33] + x[34] - x[35]; out[8] = + x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] + x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] + x[25] - x[26] + x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - x[33] + x[34] + x[35]; out[9] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + x[9] + 
x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] - x[24] + x[25] + x[26] - x[27] + x[28] + x[29] - x[30] + x[31] - x[32] - x[33] - x[34] + x[35]; out[10] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] + x[9] + x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - x[25] + x[26] + x[27] - x[28] + x[29] + x[30] - x[31] + x[32] - x[33] - x[34] - x[35]; out[11] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + x[9] + x[10] + x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] + x[25] - x[26] + x[27] + x[28] - x[29] + x[30] + x[31] - x[32] + x[33] - x[34] - x[35]; out[12] = + x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] + x[11] + x[12] + x[13] + x[14] - x[15] + x[16] - x[17] + x[18] - x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - x[25] + x[26] - x[27] + x[28] + x[29] - x[30] + x[31] + x[32] - x[33] + x[34] - x[35]; out[13] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] + x[12] + x[13] + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - x[25] - x[26] + x[27] - x[28] + x[29] + x[30] - x[31] + x[32] + x[33] - x[34] + x[35]; out[14] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] + x[13] + x[14] + x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - x[25] - x[26] - x[27] + x[28] - x[29] + x[30] + x[31] - x[32] + x[33] + x[34] - x[35]; out[15] = + x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] + x[14] + x[15] + x[16] + x[17] + x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] + x[25] - x[26] - x[27] - x[28] + x[29] - x[30] + x[31] + x[32] - x[33] + x[34] + x[35]; out[16] = + x[0] + x[1] - x[2] + x[3] - x[4] - x[5] - x[6] 
+ x[7] + x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] + x[15] + x[16] + x[17] + x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + x[25] + x[26] - x[27] - x[28] - x[29] + x[30] - x[31] + x[32] + x[33] - x[34] + x[35]; out[17] = + x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] + x[17] + x[18] + x[19] + x[20] - x[21] + x[22] - x[23] - x[24] - x[25] + x[26] + x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + x[33] + x[34] - x[35]; out[18] = - x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + x[9] + x[10] + x[11] + x[12] + x[13] + x[14] + x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] - x[27] - x[28] - x[29] - x[30] - x[31] - x[32] - x[33] - x[34] - x[35]; out[19] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] + x[25] + x[26] - x[27] - x[28] + x[29] + x[30] + x[31] - x[32] + x[33] - x[34] - x[35]; out[20] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] + x[25] + x[26] + x[27] - x[28] - x[29] + x[30] + x[31] + x[32] - x[33] + x[34] - x[35]; out[21] = + x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] - x[25] + x[26] + x[27] + x[28] - x[29] - x[30] + x[31] + x[32] + x[33] - x[34] + x[35]; out[22] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] + x[17] - x[18] + x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + x[25] - x[26] + x[27] + x[28] + x[29] - x[30] - x[31] + x[32] + x[33] + x[34] - x[35]; out[23] = + x[0] + x[1] - x[2] + 
x[3] + x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] - x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] + x[26] - x[27] + x[28] + x[29] + x[30] - x[31] - x[32] + x[33] + x[34] + x[35]; out[24] = + x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] - x[25] - x[26] + x[27] - x[28] + x[29] + x[30] + x[31] - x[32] - x[33] + x[34] + x[35]; out[25] = + x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] - x[24] - x[25] - x[26] - x[27] + x[28] - x[29] + x[30] + x[31] + x[32] - x[33] - x[34] + x[35]; out[26] = + x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] + x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] + x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - x[25] - x[26] - x[27] - x[28] + x[29] - x[30] + x[31] + x[32] + x[33] - x[34] - x[35]; out[27] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] + x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] + x[24] - x[25] - x[26] - x[27] - x[28] - x[29] + x[30] - x[31] + x[32] + x[33] + x[34] - x[35]; out[28] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] + x[25] - x[26] - x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + x[33] + x[34] + x[35]; out[29] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] - x[25] + x[26] - x[27] - x[28] - x[29] - x[30] - x[31] + x[32] - x[33] + x[34] + x[35]; out[30] = 
+ x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] - x[17] - x[18] + x[19] + x[20] - x[21] - x[22] + x[23] + x[24] + x[25] - x[26] + x[27] - x[28] - x[29] - x[30] - x[31] - x[32] + x[33] - x[34] + x[35]; out[31] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + x[17] - x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] + x[25] + x[26] - x[27] + x[28] - x[29] - x[30] - x[31] - x[32] - x[33] + x[34] - x[35]; out[32] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] - x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] + x[26] + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] - x[33] - x[34] + x[35]; out[33] = + x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] + x[17] - x[18] + x[19] - x[20] + x[21] + x[22] + x[23] - x[24] - x[25] + x[26] + x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - x[33] - x[34] - x[35]; out[34] = + x[0] + x[1] - x[2] + x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] + x[23] + x[24] - x[25] - x[26] + x[27] + x[28] + x[29] - x[30] + x[31] - x[32] - x[33] - x[34] - x[35]; out[35] = + x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] + x[23] + x[24] + x[25] - x[26] - x[27] + x[28] + x[29] + x[30] - x[31] + x[32] - x[33] - x[34] - x[35]; #pragma unroll for (int i = 0; i < 36; i++) { x[i] = out[i]; } } template __device__ __forceinline__ void hadamard_mult_thread_chunk_28(T x[kNChunks][28]) { #pragma unroll for (int c = 0; c < kNChunks; ++c) { hadamard_mult_thread_28(x[c]); } } template __device__ 
__forceinline__ void hadamard_mult_thread_chunk_36(T x[kNChunks][36]) {
  // Applies hadamard_mult_thread_36 to each of the kNChunks rows of `x`;
  // every row is transformed in place.
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    hadamard_mult_thread_36(x[c]);
  }
}

// Loads this thread's share of one token from global memory into registers.
//
// Chunk c of thread t reads vector number `c * blockDim.x + t`. In
// diagonal-block mode (UseDiagonalBlockMatrix, which requires kNChunks == 1)
// it reads vector number `blockIdx.y * blockDim.x + t` instead, so blockIdx.y
// selects which block of the token this CTA handles. Each read is a single
// vec_t access covering VecSize elements (vec_t comes from BytesToType and is
// presumably sizeof(T) * VecSize bytes wide -- confirm).
//
// Vectors whose first element index is >= dim are skipped and keep the
// caller's values (the kernel zero-initializes x_vals, so the tail is
// zero-padded). Only the first element of each vector is bounds-checked, so
// dim is presumably a multiple of VecSize -- TODO confirm.
template inline __device__ void load_input(const T *x, T x_vals[kNChunks][VecSize], int dim) {
  using vec_t = typename BytesToType::Type;
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    int offset;
    if constexpr (UseDiagonalBlockMatrix) {
      static_assert(kNChunks == 1);
      offset = blockIdx.y * blockDim.x + threadIdx.x;
    } else {
      offset = c * blockDim.x + threadIdx.x;
    }
    if (offset * VecSize < dim) {
      reinterpret_cast(x_vals)[c] = reinterpret_cast(x)[offset];
    }
  }
}

// Quantizes a single value.
//
// Computes max_bound * scale * input, rounds it (rint -- round to nearest,
// ties to even under the default rounding mode -- when round_type == 0,
// otherwise round -- ties away from zero), clamps the result to
// [min_bound, max_bound] with ClipFunc and converts it to OutType.
template __forceinline__ __device__ OutType QuantHelperFunc(const InType input,
                                                            const float scale,
                                                            const int round_type,
                                                            const float max_bound,
                                                            const float min_bound) {
  float quant_value = max_bound * scale * static_cast(input);
  if (round_type == 0) {
    quant_value = static_cast(rint(quant_value));
  } else {
    quant_value = static_cast(round(quant_value));
  }
  return static_cast(ClipFunc(quant_value, min_bound, max_bound));
}

// Applies the per-channel shift/smooth transform, quantizes, and stores.
//
// For every in-bounds vector (same index selection and `idx < dim` check as
// load_input) each element becomes (value + shift[i]) * smooth[i]; out_vals
// is overwritten with these smoothed values. Each element is then quantized
// with QuantHelperFunc, and the VecSize results are written with one Store.
// NOTE(review): `smooth` is dereferenced unconditionally while the kernel
// only tests `shift`, so callers presumably pass both or neither -- confirm.
template inline __device__ void smooth_quant_store_output(
    OutT *out,
    const T *shift,
    const T *smooth,
    T out_vals[kNChunks][VecSize],
    const float quant_scale,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int dim) {
  using DstVec = AlignedVector;
  using Vec = AlignedVector;
  DstVec dst_vec;
  Vec shift_vec;
  Vec smooth_vec;
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    int base_idx;
    if constexpr (UseDiagonalBlockMatrix) {
      base_idx = blockIdx.y * blockDim.x + threadIdx.x;
    } else {
      base_idx = c * blockDim.x + threadIdx.x;
    }
    const int idx = base_idx * VecSize;
    if (idx < dim) {
      Load(shift + idx, &shift_vec);
      Load(smooth + idx, &smooth_vec);
#pragma unroll
      for (int vi = 0; vi < VecSize; ++vi) {
        out_vals[c][vi] = (out_vals[c][vi] + shift_vec[vi]) * smooth_vec[vi];
        dst_vec[vi] = QuantHelperFunc(
            static_cast(out_vals[c][vi]), quant_scale, quant_round_type,
            quant_max_bound, quant_min_bound);
      }
      Store(dst_vec, out + idx);
    }
  }
}

// Quantizes out_vals (no shift/smooth step) and stores VecSize OutT values per
// in-bounds vector with one Store. Index selection and the `idx < dim` check
// are the same as in load_input.
template inline __device__ void quant_store_output(
    OutT *out,
    T out_vals[kNChunks][VecSize],
    const float quant_scale,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int dim) {
  using DstVec = AlignedVector;
  DstVec dst_vec;
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    int base_idx;
    if constexpr (UseDiagonalBlockMatrix) {
      base_idx = blockIdx.y * blockDim.x + threadIdx.x;
    } else {
      base_idx = c * blockDim.x + threadIdx.x;
    }
    const int idx = base_idx * VecSize;
    if (idx < dim) {
#pragma unroll
      for (int vi = 0; vi < VecSize; ++vi) {
        dst_vec[vi] = QuantHelperFunc(
            static_cast(out_vals[c][vi]), quant_scale, quant_round_type,
            quant_max_bound, quant_min_bound);
      }
      Store(dst_vec, out + idx);
    }
  }
}

// Stores out_vals without quantization, one vec_t access per in-bounds vector
// (same index selection and bounds check as load_input).
// NOTE(review): if vec_t is sized from T, an int8 OutT combined with
// quant_scales == nullptr would write T-sized vectors into an int8 buffer;
// presumably int8 output always comes with quant_scales -- confirm.
template inline __device__ void store_output(OutT *out, T out_vals[kNChunks][VecSize], int dim) {
  using vec_t = typename BytesToType::Type;
#pragma unroll
  for (int c = 0; c < kNChunks; ++c) {
    int offset;
    if constexpr (UseDiagonalBlockMatrix) {
      offset = blockIdx.y * blockDim.x + threadIdx.x;
    } else {
      offset = c * blockDim.x + threadIdx.x;
    }
    if (offset * VecSize < dim) {
      reinterpret_cast(out)[offset] = reinterpret_cast(out_vals)[c];
    }
  }
}

// In-register radix-2 butterfly (fast Walsh-Hadamard) transform of length
// N = 2^kLogN along the FIRST axis of x, applied independently to each of the
// kNChunks columns. Unnormalized: every butterfly stores (a + b, a - b).
template __device__ __forceinline__ void hadamard_mult_thread_transpose(T x[1 << kLogN][kNChunks]) {
  constexpr int N = 1 << kLogN;
#pragma unroll
  for (int i = 0; i < kLogN; ++i) {
    const int stride = 1 << i;
#pragma unroll
    for (int j = 0; j < N / 2; ++j) {
      // Map butterfly j of this stage to its lower index `idx`; its partner
      // is idx + stride.
      const int lo = j & (stride - 1);
      const int idx = (j - lo) * 2 + lo;
#pragma unroll
      for (int c = 0; c < kNChunks; ++c) {
        const T a = x[idx][c];
        const T b = x[idx + stride][c];
        x[idx][c] = a + b;
        x[idx + stride][c] = a - b;
      }
    }
  }
}

// Same butterfly transform as hadamard_mult_thread_transpose, but along the
// LAST axis: each of the kNChunks rows of length 2^kLogN is transformed
// independently, in place and unnormalized.
template __device__ __forceinline__ void hadamard_mult_thread(T x[kNChunks][1 << kLogN]) {
  constexpr int N = 1 << kLogN;
#pragma unroll
  for (int i = 0; i < kLogN; ++i) {
    const int stride = 1 << i;
#pragma unroll
    for (int j = 0; j < N / 2; ++j) {
      const int lo = j & (stride - 1);
      const int idx = (j - lo) * 2 + lo;
#pragma unroll
      for (int c = 0; c < kNChunks; ++c) {
        const T a = x[c][idx];
        const T b = x[c][idx + stride];
        x[c][idx] = a + b;
        x[c][idx + stride] = a - b;
      }
    }
  }
}

// Butterfly stages across lanes using __shfl_xor_sync(FULL_MASK, ...).
//
// Runs stages kStepStart .. kLogWarpSize - 1. At stage s every element is
// exchanged with the lane whose index differs in bit s; the lane keeps
// (self + partner) when that bit is clear and (partner - self) when it is
// set. Lane ids are taken modulo 2^kLogWarpSize. Because the mask is
// FULL_MASK, all 32 lanes of the warp must execute this call.
template __device__ __forceinline__ void hadamard_mult_warp(T x[kNChunks][kNItems]) {
  constexpr int N = 1 << kLogWarpSize;
  int lane_id = threadIdx.x % N;
#pragma unroll
  for (int step = kStepStart; step < kLogWarpSize; ++step) {
    const int lane_mask = 1 << step;
    const T sign = (lane_id & lane_mask) ? -1.f : 1.f;
#pragma unroll
    for (int c = 0; c < kNChunks; ++c) {
#pragma unroll
      for (int i = 0; i < kNItems; ++i) {
        T x_val_other = __shfl_xor_sync(FULL_MASK, x[c][i], lane_mask);
        x[c][i] = sign * x[c][i] + x_val_other;
      }
    }
  }
}

// Re-distributes register data across warps through shared memory.
//
// With Pre == true, thread (warp_id, lane_id) writes each chunk to slot
// warp_id * kWarpSize + lane_id and then reads slot row_t * kWarpSize + col_t
// (row_t = tid % kNWarps, col_t = tid / kNWarps). Afterwards each group of
// kNWarps consecutive threads holds lane col_t's data from every warp, so a
// shuffle butterfly over the low log2(kNWarps) lane bits mixes data across
// warps. Pre == false applies the inverse mapping.
//
// smem must hold kChunksPerExchange * kWarpSize * kNWarps vec_t elements.
// The function contains __syncthreads(), so every thread of the block must
// call it.
template inline __device__ void exchange_smem_pre(T x_vals[kNChunks][kNElts], vec_t *smem) {
  // kNChunks: total number of chunks each thread holds.
  // kChunksPerExchange: number of chunks moved per shared-memory round.
  // Rounds needed: kNChunks / kChunksPerExchange (integer division, so
  // kNChunks is expected to be a multiple of kChunksPerExchange).
  constexpr int kNThreads = kWarpSize * kNWarps;
  const int warp_id = threadIdx.x / kWarpSize;
  const int lane_id = threadIdx.x % kWarpSize;
  const int row_t = threadIdx.x % kNWarps;
  const int col_t = threadIdx.x / kNWarps;
#pragma unroll
  for (int c0 = 0; c0 < kNChunks / kChunksPerExchange; ++c0) {  // one round per group of chunks
    // Wait until every thread has finished reading smem from the previous
    // round (or previous call) before overwriting it.
    __syncthreads();
#pragma unroll
    for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {  // chunks written to fill smem this round
      // smem[c1 * kNThreads + (Pre ? warp_id * kWarpSize + lane_id ^ warp_id : row_t * kWarpSize + col_t ^ row_t)] = *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]);
      smem[c1 * kNThreads + (Pre ? warp_id * kWarpSize + lane_id : row_t * kWarpSize + col_t)] = *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]);
    }
    __syncthreads();
#pragma unroll
    for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {
      // *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]) = smem[c1 * kNThreads + (Pre ? row_t * kWarpSize + col_t ^ row_t : warp_id * kWarpSize + lane_id ^ warp_id)];
      *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]) = smem[c1 * kNThreads + (Pre ? row_t * kWarpSize + col_t : warp_id * kWarpSize + lane_id)];
    }
  }
}

// Compile-time floor(log2(val)) for val >= 1; returns -1 for val <= 0.
constexpr int cilog2(int val) { return val > 0 ? 1 + cilog2(val >> 1) : -1; }

// MoE fast Hadamard transform kernel with an optional quantization epilogue.
//
// Launch layout (as set up by MoeFastHardamardImplWrapper):
//   * blockDim.x == kThreads. Each CTA processes tokens blockIdx.x,
//     blockIdx.x + gridDim.x, ..., so any grid.x >= 1 covers all tokens.
//   * In diagonal-block mode, grid.y splits each token into independent
//     blocks of kThreads * VecSize elements (indexed with blockIdx.y by the
//     load/store helpers).
//   * Dynamic shared memory is reinterpreted as a vec_t array for
//     exchange_smem_pre. Its size is presumably kSmemSize from the wrapper
//     (the launch configuration is not visible here -- confirm).
//
// Per token: load -> in-thread butterflies -> warp-shuffle butterflies ->
// (kNWarps > 1) cross-warp butterflies through a shared-memory exchange ->
// (kNChunks > 1) cross-chunk transform of size 28, 36 or a power of two ->
// store. No normalization factor is applied in this kernel. When
// quant_scales is non-null, the result is quantized with the per-expert scale
// quant_scales[expert_idx_per_token[token_id]], after an extra
// (x + shift) * smooth step when shift is non-null.
template __global__ __launch_bounds__(kThreads) void moe_fast_hardamard_kernel(const T *x,
                                                                      const int64_t *expert_idx_per_token,
                                                                      const T *shift,
                                                                      const T *smooth,
                                                                      const float* quant_scales,
                                                                      const int quant_round_type,
                                                                      const float quant_max_bound,
                                                                      const float quant_min_bound,
                                                                      const int64_t token_num,
                                                                      const int64_t dim,
                                                                      OutT *out) {
  using vec_t = typename BytesToType::Type;
  constexpr int kLogVecSize = cilog2(VecSize);
  constexpr int kLogWarpSize = cilog2(32);
  constexpr int kWarpSize = 32;
  constexpr int kNWarps = kThreads / kWarpSize;
  constexpr int kLogNWarps = cilog2(kNWarps);
  constexpr int kLogNChunks = cilog2(kNChunks);
  extern __shared__ char smem_[];
  vec_t *smem_exchange = reinterpret_cast(smem_);
  // int64_t index: token_num is int64_t, so an int counter could overflow.
  for (int64_t token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) {
    const T *x_now = x + token_id * dim;
    OutT *out_now = out + token_id * dim;
    T init_value = static_cast(0.f);
    T x_vals[kNChunks][VecSize] = {init_value};
    load_input(x_now, x_vals, dim);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id0: %d\n", i);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif
    hadamard_mult_thread(x_vals);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif
    hadamard_mult_warp(x_vals);
#ifdef DEBUG_HARDAMARD
    if (blockIdx.x == 0 && threadIdx.x == 0) {
      for (int i = 0; i < 1; ++i) {
        printf("chunk_id2: %d\n", i);
        for (int j = 0; j < VecSize; ++j) {
          printf("%f ", (float)x_vals[i][j]);
        }
        printf("\n");
      }
    }
    __syncthreads();
#endif
    if constexpr (kNWarps > 1) {
      // First let each group of kNWarps consecutive threads gather the data
      // that the other warps hold.
      exchange_smem_pre(x_vals, smem_exchange);
      // Butterfly stages across the gathered (cross-warp) values.
      hadamard_mult_warp(x_vals);
      // Swap back to the original layout.
      exchange_smem_pre(x_vals, smem_exchange);
    }
    if constexpr (kNChunks > 1) {
      // Mix values across the chunk axis of x_vals[kNChunks][VecSize] in
      // place: hadamard_mult_thread_28_transpose and
      // hadamard_mult_thread_transpose both index the chunk axis first (the
      // 36 variant is defined elsewhere and presumably does the same).
      if constexpr (kNChunks == 28) {
        hadamard_mult_thread_28_transpose(x_vals);
      } else if constexpr (kNChunks == 36) {
        hadamard_mult_thread_36_transpose(x_vals);
      } else {
        constexpr int kLogNChunks = cilog2(kNChunks);
        static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
        hadamard_mult_thread_transpose(x_vals);
      }
    }
    if (quant_scales) {
      int64_t expert_id = expert_idx_per_token[token_id];
      float quant_scale = quant_scales[expert_id];
      if (shift) {
        smooth_quant_store_output(
            out_now, shift, smooth, x_vals, quant_scale, quant_round_type,
            quant_max_bound, quant_min_bound, dim);
      } else {
        quant_store_output(
            out_now, x_vals, quant_scale, quant_round_type, quant_max_bound,
            quant_min_bound, dim);
      }
    } else {
      store_output(out_now, x_vals, dim);
    }
  }
}

// Host-side launcher for one compile-time configuration of
// moe_fast_hardamard_kernel.
//
// Derives the shared-memory budget from N = 2^kLogN (capped at 32 KiB). It
// sizes grid.x to one wave of resident blocks, since the kernel strides over
// tokens. In diagonal-block mode it sizes grid.y to
// ceil(dim / (kThreads * VecSize)) blocks. It returns without launching when
// token_num or dim is not positive. CUDA API and launch errors are reported
// through CUDA_CHECK, and the call blocks until the device is idle.
template void MoeFastHardamardImplWrapper(const T *x,
                                          const int64_t *expert_idx_per_token,
                                          const T *shift,
                                          const T *smooth,
                                          const float* quant_scales,
                                          const int quant_round_type,
                                          const float quant_max_bound,
                                          const float quant_min_bound,
                                          const int64_t token_num,
                                          const int64_t dim,
                                          OutT* out,
                                          cudaStream_t stream) {
  using nv_type = typename nv_type_traits::type;
  using out_type = typename nv_type_traits::type;
  constexpr int kNBytes = sizeof(T);
  constexpr int N = 1 << kLogN;  // power-of-two size ("pad"); the 2^n path rounds dim up to it
  constexpr int kSmemSize = std::min(N * kNBytes, 32 * 1024);
  constexpr int kRounds = N * kNBytes / kSmemSize;
  constexpr int kChunksPerSmemSize = kSmemSize / (kThreads * VecSize * kNBytes);
  VLOG(1) << "real_dim: " << dim << ", N: " << N;
  VLOG(1) << "kNChunks: " << kNChunks;
  VLOG(1) << "kNBytes: " << kNBytes;
  VLOG(1) << "kSmemSize: " << kSmemSize;
  VLOG(1) << "kRounds: " << kRounds;
  VLOG(1) << "kChunksPerSmemSize: " << kChunksPerSmemSize;

  // Empty input: nothing to write, and a zero grid dimension would be an
  // invalid launch configuration.
  if (token_num <= 0 || dim <= 0) {
    return;
  }

  // Query the device that is current for this thread instead of assuming 0.
  int dev_id = 0;
  CUDA_CHECK(cudaGetDevice(&dev_id));
  int sm_count = 0;
  int act_blocks_per_sm = 0;
  CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id));
  auto kernel = moe_fast_hardamard_kernel;
  CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &act_blocks_per_sm, kernel, kThreads, kSmemSize));
  // The kernel loops over tokens with stride gridDim.x, so one wave of
  // resident blocks is enough; never let grid.x drop to 0.
  const int num_blocks_per_wave = std::max(sm_count * act_blocks_per_sm, 1);
  dim3 grid;
  grid.x = min(static_cast(num_blocks_per_wave), token_num);
  if constexpr (UseDiagonalBlockMatrix) {
    // Integer ceil-division so a trailing partial block is launched as well.
    // load_input zero-pads it, and the store helpers skip indices >= dim.
    constexpr int kElemsPerBlockY = kThreads * VecSize;
    grid.y = (dim + kElemsPerBlockY - 1) / kElemsPerBlockY;
  }
  kernel<<>>(
      reinterpret_cast(x),
      expert_idx_per_token,
      reinterpret_cast(shift),
      reinterpret_cast(smooth),
      quant_scales,
      quant_round_type,
      quant_max_bound,
      quant_min_bound,
      token_num,
      dim,
      reinterpret_cast(out)
  );
  // Launch-configuration errors are reported through cudaGetLastError().
  CUDA_CHECK(cudaGetLastError());
  // NOTE(review): a device-wide sync after every launch blocks the host and
  // waits on all streams; kept to preserve the existing behavior.
  CUDA_CHECK(cudaDeviceSynchronize());
}

// Public entry point: chooses the transform layout for `dim` and dispatches to
// MoeFastHardamardImplWrapper.
//
// Diagonal-block mode is hard-coded on. Each token is split into blocks of
// FLAGS_hardamard_moe_block_size elements (environment variable, read once
// per instantiation, default 512). Each block is transformed independently by
// kThreads = 128 threads with VecSize = block_size / 128 elements per thread.
// The dispatch macros accept VecSize 1, 2, 4 or 8 (block sizes 128..1024) and
// throw Unimplemented otherwise.
// NOTE(review): the integer division means a block size that is not a
// multiple of 128 is silently rounded down (e.g. 300 -> 256).
// The else-branch (whole-token transforms for dim = 28 * 2^n, 36 * 2^n, or a
// power of two, with other sizes rounded up to the next power of two) is
// currently unreachable. Returns immediately when token_num or dim is not
// positive.
template void MoeFastHardamardWrapper(const T *x_data,
                                      const int64_t *expert_idx_per_token,
                                      const T *shift,
                                      const T *smooth,
                                      const float* quant_scales,
                                      const int quant_round_type,
                                      const float quant_max_bound,
                                      const float quant_min_bound,
                                      const int64_t token_num,
                                      const int64_t dim,
                                      OutT* out,
                                      cudaStream_t &stream) {
  // Nothing to transform; this also keeps log2(dim) below well-defined.
  if (token_num <= 0 || dim <= 0) {
    return;
  }
  bool FLAGS_hardamard_use_diagonal_block_matrix = true;
  static const char* FLAGS_hardamard_moe_block_size = std::getenv("FLAGS_hardamard_moe_block_size");
  static const int32_t hardamard_moe_block_size =
      FLAGS_hardamard_moe_block_size != nullptr ? std::stoi(std::string(FLAGS_hardamard_moe_block_size)) : 512;
  constexpr int kThreads = 128;
  // True iff d == base * 2^k for some k >= 0 (d positive and divisible by base).
  auto is_base_times_pow2 = [](int64_t d, int64_t base) {
    if (d <= 0 || d % base != 0) {
      return false;
    }
    const int64_t q = d / base;
    return (q & (q - 1)) == 0;
  };
  if (FLAGS_hardamard_use_diagonal_block_matrix) {
    const int VecSize = hardamard_moe_block_size / kThreads;  // e.g. 128 / 128 = 1
    // Avoid converting log2(0) to int; DISPATCH_SP_VS rejects VecSize <= 0.
    const int logN = VecSize > 0 ? int(ceil(std::log2(kThreads * VecSize))) : 0;
    constexpr int kNChunks = 1;
    DISPATCH_SP_VS(VecSize, VEC_SIZE, {
      DISPATCH_SP_logN(logN, kLogN, {
        MoeFastHardamardImplWrapper(
            x_data, expert_idx_per_token, shift, smooth, quant_scales,
            quant_round_type, quant_max_bound, quant_min_bound,
            token_num, dim, out, stream);
      })});
  } else {
    if (is_base_times_pow2(dim, 28)) {
      VLOG(1) << "28 * 2^n";
      const int logN = int(ceil(std::log2(dim / 28)));
      constexpr int kNChunks = 28;
      DISPATCH_SP_logN(logN, kLogN, {
        constexpr int VecSize = (1 << kLogN) / kThreads;
        MoeFastHardamardImplWrapper(
            x_data, expert_idx_per_token, shift, smooth, quant_scales,
            quant_round_type, quant_max_bound, quant_min_bound,
            token_num, dim, out, stream);
      });
    } else if (is_base_times_pow2(dim, 36)) {
      VLOG(1) << "36 * 2^n";
      const int logN = int(ceil(std::log2(dim / 36)));
      constexpr int kNChunks = 36;
      DISPATCH_SP_logN(logN, kLogN, {
        constexpr int VecSize = (1 << kLogN) / kThreads;
        MoeFastHardamardImplWrapper(
            x_data, expert_idx_per_token, shift, smooth, quant_scales,
            quant_round_type, quant_max_bound, quant_min_bound,
            token_num, dim, out, stream);
      });
    } else {
      VLOG(1) << "2^n";
      const int logN = int(ceil(std::log2(dim)));
      constexpr int VecSize = 16 / sizeof(T);
      DISPATCH_logN(logN, kLogN, {
        constexpr int kNChunks = (1 << kLogN) / (kThreads * VecSize);
        MoeFastHardamardImplWrapper(
            x_data, expert_idx_per_token, shift, smooth, quant_scales,
            quant_round_type, quant_max_bound, quant_min_bound,
            token_num, dim, out, stream);
      });
    }
  }
}

template void MoeFastHardamardWrapper(
    const phi::dtype::float16 *x_data,
    const int64_t *expert_idx_per_token,
    const phi::dtype::float16 *shift,
    const phi::dtype::float16 *smooth,
    const float* quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    phi::dtype::float16 *out,
    cudaStream_t &stream
);

template void MoeFastHardamardWrapper(
    const phi::dtype::float16 *x_data,
    const int64_t *expert_idx_per_token,
    const phi::dtype::float16 *shift,
    const phi::dtype::float16 *smooth,
    const float* quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    int8_t *out,
    cudaStream_t &stream
);

template void MoeFastHardamardWrapper(
    const phi::dtype::bfloat16 *x_data,
    const int64_t *expert_idx_per_token,
    const phi::dtype::bfloat16 *shift,
    const phi::dtype::bfloat16 *smooth,
    const float* quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    phi::dtype::bfloat16 *out,
    cudaStream_t &stream
);

template void MoeFastHardamardWrapper(
    const phi::dtype::bfloat16 *x_data,
    const int64_t *expert_idx_per_token,
    const phi::dtype::bfloat16 *shift,
    const phi::dtype::bfloat16 *smooth,
    const float* quant_scales,
    const int quant_round_type,
    const float quant_max_bound,
    const float quant_min_bound,
    const int64_t token_num,
    const int64_t dim,
    int8_t *out,
    cudaStream_t &stream
);