From 819b2dbbae87243c3db6cf2a772600a6c99fbdb7 Mon Sep 17 00:00:00 2001 From: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:48:28 +0800 Subject: [PATCH] =?UTF-8?q?Revert=20"=E3=80=90New=20Feature=E3=80=91W4afp8?= =?UTF-8?q?=20supports=20per=20group=20quantization=20(#4272)"=20(#4854)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 93fcf7e4ec23fe57f8766ffa3702bbabd570f164. --- custom_ops/gpu_ops/cpp_extensions.cc | 1 - .../moe/fast_hardmard/fast_hardamard_kernel.h | 37 - .../fast_hardmard/fast_hardamard_kernel.hpp | 1722 ----------------- .../fast_hardamard_kernel_bf16_bf16.cu | 34 - .../fast_hardamard_kernel_bf16_fp8.cu | 34 - .../fast_hardamard_kernel_bf16_int8.cu | 33 - .../fast_hardamard_kernel_fp16_fp16.cu | 33 - .../fast_hardamard_kernel_fp16_int8.cu | 33 - custom_ops/gpu_ops/moe/fused_moe_helper.h | 272 +-- custom_ops/gpu_ops/moe/fused_moe_op.h | 828 ++++---- custom_ops/gpu_ops/moe/moe_dispatch.cu | 283 +-- .../gpu_ops/moe/moe_expert_ffn_wint2.cu | 1 - custom_ops/gpu_ops/moe/moe_ffn.cu | 154 +- .../gpu_ops/w4afp8_gemm/kernel_traits.h | 212 +- custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h | 858 ++++---- custom_ops/gpu_ops/w4afp8_gemm/utils.hpp | 155 +- custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu | 418 ++-- custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h | 48 +- .../w4afp8_gemm/w4afp8_gemm_kernel.hpp | 451 ++--- .../gpu_ops/w4afp8_gemm/weight_kernel.hpp | 131 -- .../w4afp8_gemm/weight_scale_kernel.hpp | 63 - .../utils/auto_gen_w4afp8_gemm_kernel.py | 109 +- .../layers/moe/fused_moe_cutlass_backend.py | 123 +- .../layers/moe/fused_moe_wint2_backend.py | 1 - .../layers/quantization/mix_quant.py | 3 - tests/operators/test_w4afp8_gemm.py | 59 +- 26 files changed, 1718 insertions(+), 4378 deletions(-) delete mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.h delete mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp delete mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu delete mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu delete mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu delete mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu delete mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu delete mode 100644 custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp delete mode 100644 custom_ops/gpu_ops/w4afp8_gemm/weight_scale_kernel.hpp diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 2a6dea831..979a94eb7 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -304,7 +304,6 @@ paddle::Tensor MoeExpertFFNFunc( const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, - const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.h b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.h deleted file mode 100644 index 1fc6ab5b1..000000000 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include "helper.h" - -template -void MoeFastHardamardWrapper(const T *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const T *shift, - const T *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - OutT *out, - cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp deleted file mode 100644 index 78db01161..000000000 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp +++ /dev/null @@ -1,1722 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fast_hardamard_kernel.h" - -#define FULL_MASK 0xffffffff - -struct uint8 { - uint4 u; - uint4 v; -}; - -template -struct BytesToType {}; - -template <> -struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); -}; - -template <> -struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template <> -struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template <> -struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template <> -struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template <> -struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -template -struct nv_type_traits { - using type = T; -}; - -template <> -struct nv_type_traits { - using type = half; -}; - -template <> -struct nv_type_traits { - using type = __nv_bfloat16; -}; - -template <> -struct nv_type_traits { - using type = int8_t; -}; - -#define DISPATCH_SP_logN(logN, kLogN, ...) \ - if (logN == 10) { \ - constexpr int kLogN = 10; \ - __VA_ARGS__ \ - } else if (logN == 9) { \ - constexpr int kLogN = 9; \ - __VA_ARGS__ \ - } else if (logN == 8) { \ - constexpr int kLogN = 8; \ - __VA_ARGS__ \ - } else if (logN == 7) { \ - constexpr int kLogN = 7; \ - __VA_ARGS__ \ - } else { \ - PADDLE_THROW( \ - phi::errors::Unimplemented("logN = %d is unsupported!", logN)); \ - } - -#define DISPATCH_SP_VS(vec_size, VEC_SIZE, ...) \ - if (vec_size == 16) { \ - constexpr int VEC_SIZE = 16; \ - __VA_ARGS__ \ - } else if (vec_size == 8) { \ - constexpr int VEC_SIZE = 8; \ - __VA_ARGS__ \ - } else if (vec_size == 4) { \ - constexpr int VEC_SIZE = 4; \ - __VA_ARGS__ \ - } else if (vec_size == 2) { \ - constexpr int VEC_SIZE = 2; \ - __VA_ARGS__ \ - } else if (vec_size == 1) { \ - constexpr int VEC_SIZE = 1; \ - __VA_ARGS__ \ - } else { \ - PADDLE_THROW(phi::errors::Unimplemented("vec_size = %d is unsupported!", \ - vec_size)); \ - } - -#define DISPATCH_logN(logN, kLogN, ...) \ - if (logN == 11) { \ - constexpr int kLogN = 11; \ - __VA_ARGS__ \ - } else if (logN == 12) { \ - constexpr int kLogN = 12; \ - __VA_ARGS__ \ - } else if (logN == 13) { \ - constexpr int kLogN = 13; \ - __VA_ARGS__ \ - } else if (logN == 14) { \ - constexpr int kLogN = 14; \ - __VA_ARGS__ \ - } else { \ - PADDLE_THROW(phi::errors::Unimplemented("unsupported logN")); \ - } - -template -__device__ __forceinline__ void hadamard_mult_thread_28_transpose( - T x[28][VecSize]) { // 35 - T out[28]; -#pragma unroll - for (int vi = 0; vi < VecSize; vi++) { - out[0] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] + - x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] + - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] + - x[27][vi]; - out[1] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] - - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] - - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + - x[27][vi]; - out[2] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] + - x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] - - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - - x[27][vi]; - out[3] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] - - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + x[26][vi] + - x[27][vi]; - out[4] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] - - x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] + - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] + - x[27][vi]; - out[5] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - - x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] - - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] - - x[27][vi]; - out[6] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] - - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] + - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] - - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] - - x[27][vi]; - out[7] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] + - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] - - x[27][vi]; - out[8] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] + - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] - - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - - x[27][vi]; - out[9] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] + - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] + - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] + - x[27][vi]; - out[10] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] + - x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + - x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - - x[26][vi] + x[27][vi]; - out[11] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + - x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] + - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + - x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + - x[26][vi] - x[27][vi]; - out[12] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - - x[26][vi] + x[27][vi]; - out[13] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - - x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - - x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + - x[26][vi] - x[27][vi]; - out[14] = -x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - - x[26][vi] - x[27][vi]; - out[15] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] + - x[26][vi] - x[27][vi]; - out[16] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] - - x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + - x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - - x[26][vi] + x[27][vi]; - out[17] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] - - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - - x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] - - x[26][vi] - x[27][vi]; - out[18] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] + - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - - x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] + - x[26][vi] - x[27][vi]; - out[19] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - - x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] - - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] + - x[26][vi] + x[27][vi]; - out[20] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - - x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] - - x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + - x[26][vi] + x[27][vi]; - out[21] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + - x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - - x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] + - x[26][vi] + x[27][vi]; - out[22] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - - x[26][vi] + x[27][vi]; - out[23] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + - x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - - x[26][vi] - x[27][vi]; - out[24] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] + - x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + - x[26][vi] - x[27][vi]; - out[25] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] - - x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - - x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - - x[26][vi] + x[27][vi]; - out[26] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - - x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] + - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - - x[26][vi] - x[27][vi]; - out[27] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - - x[26][vi] - x[27][vi]; -#pragma unroll - for (int i = 0; i < 28; i++) { - x[i][vi] = out[i]; - } - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_36_transpose( - T x[36][VecSize]) { // 4t - T out[36]; -#pragma unroll - for (int vi = 0; vi < VecSize; vi++) { - out[0] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] + - x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + x[16][vi] + - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + x[21][vi] + - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] + x[26][vi] + - x[27][vi] + x[28][vi] + x[29][vi] + x[30][vi] + x[31][vi] + - x[32][vi] + x[33][vi] + x[34][vi] + x[35][vi]; - out[1] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] - - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] + x[16][vi] + - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - x[26][vi] + - x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] + - x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[2] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + x[11][vi] - - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] - x[16][vi] + - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - x[26][vi] - - x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] - x[31][vi] - - x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[3] = +x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + x[11][vi] + - x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] - - x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - x[21][vi] + - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - x[26][vi] - - x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] - x[31][vi] - - x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; - out[4] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - x[11][vi] + - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + x[21][vi] - - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] - - x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] + x[31][vi] - - x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; - out[5] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] - x[11][vi] - - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] - x[16][vi] - - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + x[21][vi] + - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] + x[31][vi] + - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[6] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] + - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - x[11][vi] - - x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - x[16][vi] - - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - x[21][vi] + - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - - x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] - x[31][vi] + - x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; - out[7] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + x[11][vi] - - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + x[16][vi] - - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + x[21][vi] - - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + x[26][vi] + - x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] - x[31][vi] - - x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[8] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] - x[11][vi] + - x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + x[16][vi] + - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - x[21][vi] + - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - x[26][vi] + - x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] - x[31][vi] - - x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[9] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + x[11][vi] - - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - x[16][vi] + - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - x[21][vi] - - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + x[26][vi] - - x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] + x[31][vi] - - x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; - out[10] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + - x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - - x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + - x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] - - x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[11] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] + - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - - x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] + - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - - x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] + - x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; - out[12] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - - x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + - x[16][vi] - x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] + - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + - x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] + - x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; - out[13] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] - - x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] - x[20][vi] - - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - - x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] - - x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[14] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] - - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + - x[16][vi] - x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] - - x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - - x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] + - x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[15] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + - x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + - x[16][vi] + x[17][vi] + x[18][vi] - x[19][vi] + x[20][vi] - - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] - - x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] + - x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[16] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - - x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] + - x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] - x[20][vi] + - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + - x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] - - x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[17] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] - - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + - x[16][vi] + x[17][vi] + x[18][vi] + x[19][vi] + x[20][vi] - - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + - x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] + - x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[18] = -x[0][vi] + x[1][vi] + x[2][vi] + x[3][vi] + x[4][vi] + x[5][vi] + - x[6][vi] + x[7][vi] + x[8][vi] + x[9][vi] + x[10][vi] + - x[11][vi] + x[12][vi] + x[13][vi] + x[14][vi] + x[15][vi] + - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - - x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] - - x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[19] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] - - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - - x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] + - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - - x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] + - x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] + - x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; - out[20] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] + - x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] - - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - - x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] + - x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] + - x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; - out[21] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] - - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] + - x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] + - x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[22] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - - x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] - - x[26][vi] + x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] - - x[31][vi] + x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[23] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] - - x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] - - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] + - x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] + x[30][vi] - - x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] + x[35][vi]; - out[24] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] - - x[11][vi] - x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + - x[21][vi] - x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - - x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] + x[30][vi] + - x[31][vi] - x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[25] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + - x[11][vi] - x[12][vi] - x[13][vi] - x[14][vi] + x[15][vi] + - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] - x[25][vi] - - x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] + - x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; - out[26] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - - x[11][vi] + x[12][vi] - x[13][vi] - x[14][vi] - x[15][vi] + - x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + - x[21][vi] - x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] - - x[26][vi] - x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] + - x[31][vi] + x[32][vi] + x[33][vi] - x[34][vi] - x[35][vi]; - out[27] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] - x[15][vi] - - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + - x[21][vi] + x[22][vi] - x[23][vi] + x[24][vi] - x[25][vi] - - x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] + x[30][vi] - - x[31][vi] + x[32][vi] + x[33][vi] + x[34][vi] - x[35][vi]; - out[28] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] + - x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] - x[15][vi] - - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] + x[25][vi] - - x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] + - x[31][vi] - x[32][vi] + x[33][vi] + x[34][vi] + x[35][vi]; - out[29] = +x[0][vi] - x[1][vi] + x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - - x[6][vi] + x[7][vi] - x[8][vi] + x[9][vi] + x[10][vi] - - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] - - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] - - x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] + - x[26][vi] - x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] - - x[31][vi] + x[32][vi] - x[33][vi] + x[34][vi] + x[35][vi]; - out[30] = +x[0][vi] - x[1][vi] - x[2][vi] + x[3][vi] + x[4][vi] - x[5][vi] - - x[6][vi] - x[7][vi] + x[8][vi] - x[9][vi] + x[10][vi] + - x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + - x[16][vi] - x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] - - x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] - - x[26][vi] + x[27][vi] - x[28][vi] - x[29][vi] - x[30][vi] - - x[31][vi] - x[32][vi] + x[33][vi] - x[34][vi] + x[35][vi]; - out[31] = +x[0][vi] - x[1][vi] - x[2][vi] - x[3][vi] + x[4][vi] + x[5][vi] - - x[6][vi] - x[7][vi] - x[8][vi] + x[9][vi] - x[10][vi] + - x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - - x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] + x[20][vi] + - x[21][vi] - x[22][vi] - x[23][vi] + x[24][vi] + x[25][vi] + - x[26][vi] - x[27][vi] + x[28][vi] - x[29][vi] - x[30][vi] - - x[31][vi] - x[32][vi] - x[33][vi] + x[34][vi] - x[35][vi]; - out[32] = +x[0][vi] + x[1][vi] - x[2][vi] - x[3][vi] - x[4][vi] + x[5][vi] + - x[6][vi] - x[7][vi] - x[8][vi] - x[9][vi] + x[10][vi] - - x[11][vi] + x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] + - x[21][vi] + x[22][vi] - x[23][vi] - x[24][vi] + x[25][vi] + - x[26][vi] + x[27][vi] - x[28][vi] + x[29][vi] - x[30][vi] - - x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] + x[35][vi]; - out[33] = +x[0][vi] - x[1][vi] + x[2][vi] - x[3][vi] - x[4][vi] - x[5][vi] + - x[6][vi] + x[7][vi] - x[8][vi] - x[9][vi] - x[10][vi] + - x[11][vi] - x[12][vi] + x[13][vi] + x[14][vi] - x[15][vi] + - x[16][vi] + x[17][vi] - x[18][vi] + x[19][vi] - x[20][vi] + - x[21][vi] + x[22][vi] + x[23][vi] - x[24][vi] - x[25][vi] + - x[26][vi] + x[27][vi] + x[28][vi] - x[29][vi] + x[30][vi] - - x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[34] = +x[0][vi] + x[1][vi] - x[2][vi] + x[3][vi] - x[4][vi] - x[5][vi] - - x[6][vi] + x[7][vi] + x[8][vi] - x[9][vi] - x[10][vi] - - x[11][vi] + x[12][vi] - x[13][vi] + x[14][vi] + x[15][vi] - - x[16][vi] + x[17][vi] - x[18][vi] - x[19][vi] + x[20][vi] - - x[21][vi] + x[22][vi] + x[23][vi] + x[24][vi] - x[25][vi] - - x[26][vi] + x[27][vi] + x[28][vi] + x[29][vi] - x[30][vi] + - x[31][vi] - x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; - out[35] = +x[0][vi] + x[1][vi] + x[2][vi] - x[3][vi] + x[4][vi] - x[5][vi] - - x[6][vi] - x[7][vi] + x[8][vi] + x[9][vi] - x[10][vi] - - x[11][vi] - x[12][vi] + x[13][vi] - x[14][vi] + x[15][vi] + - x[16][vi] - x[17][vi] - x[18][vi] - x[19][vi] - x[20][vi] + - x[21][vi] - x[22][vi] + x[23][vi] + x[24][vi] + x[25][vi] - - x[26][vi] - x[27][vi] + x[28][vi] + x[29][vi] + x[30][vi] - - x[31][vi] + x[32][vi] - x[33][vi] - x[34][vi] - x[35][vi]; -#pragma unroll - for (int i = 0; i < 36; i++) { - x[i][vi] = out[i]; - } - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_28(T x[28]) { // 35 - T out[28]; - out[0] = +x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + - x[9] + x[10] + x[11] + x[12] + x[13] - x[14] + x[15] + x[16] + - x[17] + x[18] + x[19] + x[20] + x[21] + x[22] + x[23] + x[24] + - x[25] + x[26] + x[27]; - out[1] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] - - x[17] + x[18] + x[19] - x[20] - x[21] - x[22] - x[23] + x[24] + - x[25] - x[26] + x[27]; - out[2] = +x[0] + x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - - x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + - x[17] - x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] + - x[25] + x[26] - x[27]; - out[3] = +x[0] - x[1] + x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - - x[9] - x[10] - x[11] + x[12] + x[13] + x[14] - x[15] + x[16] - - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - - x[25] + x[26] + x[27]; - out[4] = +x[0] + x[1] - x[2] + x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - - x[9] - x[10] - x[11] - x[12] + x[13] + x[14] + x[15] - x[16] + - x[17] - x[18] + x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - - x[25] - x[26] + x[27]; - out[5] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] + x[6] - x[7] + x[8] + - x[9] - x[10] - x[11] - x[12] - x[13] + x[14] + x[15] + x[16] - - x[17] + x[18] - x[19] + x[20] - x[21] + x[22] + x[23] - x[24] - - x[25] - x[26] - x[27]; - out[6] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] + x[7] - x[8] + - x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] + x[23] + x[24] - - x[25] - x[26] - x[27]; - out[7] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] + x[8] - - x[9] + x[10] + x[11] - x[12] - x[13] + x[14] - x[15] - x[16] + - x[17] + x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] + - x[25] - x[26] - x[27]; - out[8] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] + - x[9] - x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - - x[17] + x[18] + x[19] - x[20] + x[21] - x[22] + x[23] - x[24] + - x[25] + x[26] - x[27]; - out[9] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + - x[9] + x[10] - x[11] + x[12] + x[13] + x[14] - x[15] - x[16] - - x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - - x[25] + x[26] + x[27]; - out[10] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] + x[15] - x[16] - - x[17] - x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] + - x[25] - x[26] + x[27]; - out[11] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - - x[9] + x[10] + x[11] + x[12] - x[13] + x[14] + x[15] + x[16] - - x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] + x[24] - - x[25] + x[26] - x[27]; - out[12] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + - x[9] - x[10] + x[11] + x[12] + x[13] + x[14] - x[15] + x[16] + - x[17] - x[18] - x[19] - x[20] - x[21] + x[22] + x[23] - x[24] + - x[25] - x[26] + x[27]; - out[13] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + - x[9] + x[10] - x[11] + x[12] + x[13] + x[14] + x[15] - x[16] + - x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - - x[25] + x[26] - x[27]; - out[14] = -x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + - x[9] + x[10] + x[11] + x[12] + x[13] - x[14] - x[15] - x[16] - - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - x[24] - - x[25] - x[26] - x[27]; - out[15] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - - x[9] + x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + - x[17] - x[18] - x[19] + x[20] + x[21] + x[22] + x[23] - x[24] - - x[25] + x[26] - x[27]; - out[16] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] - x[16] - - x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - - x[25] - x[26] + x[27]; - out[17] = +x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - - x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] + - x[25] - x[26] - x[27]; - out[18] = +x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] - x[15] + x[16] - - x[17] - x[18] - x[19] + x[20] - x[21] - x[22] + x[23] + x[24] + - x[25] + x[26] - x[27]; - out[19] = +x[0] + x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + - x[9] - x[10] - x[11] - x[12] - x[13] - x[14] - x[15] - x[16] + - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + - x[25] + x[26] + x[27]; - out[20] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] + x[7] - x[8] + - x[9] + x[10] - x[11] - x[12] - x[13] - x[14] + x[15] - x[16] - - x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + - x[25] + x[26] + x[27]; - out[21] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] - - x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - - x[17] - x[18] + x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - - x[25] + x[26] + x[27]; - out[22] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] + x[15] + x[16] + - x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - - x[25] - x[26] + x[27]; - out[23] = +x[0] - x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - - x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] + - x[17] + x[18] - x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + - x[25] - x[26] - x[27]; - out[24] = +x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + - x[9] - x[10] + x[11] - x[12] + x[13] - x[14] - x[15] + x[16] + - x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - - x[25] + x[26] - x[27]; - out[25] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - - x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] - x[16] + - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] + x[23] - x[24] - - x[25] - x[26] + x[27]; - out[26] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + - x[9] - x[10] + x[11] - x[12] + x[13] - x[14] + x[15] - x[16] - - x[17] + x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] - - x[25] - x[26] - x[27]; - out[27] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + - x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - - x[17] - x[18] + x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + - x[25] - x[26] - x[27]; -#pragma unroll - for (int i = 0; i < 28; i++) { - x[i] = out[i]; - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_36(T x[36]) { // 4t - T out[36]; - out[0] = +x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + - x[9] + x[10] + x[11] + x[12] + x[13] + x[14] + x[15] + x[16] + - x[17] - x[18] + x[19] + x[20] + x[21] + x[22] + x[23] + x[24] + - x[25] + x[26] + x[27] + x[28] + x[29] + x[30] + x[31] + x[32] + - x[33] + x[34] + x[35]; - out[1] = +x[0] + x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + - x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - - x[25] - x[26] + x[27] + x[28] - x[29] - x[30] - x[31] + x[32] - - x[33] + x[34] + x[35]; - out[2] = +x[0] + x[1] + x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - - x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + - x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] + x[24] - - x[25] - x[26] - x[27] + x[28] + x[29] - x[30] - x[31] - x[32] + - x[33] - x[34] + x[35]; - out[3] = +x[0] + x[1] + x[2] + x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - - x[17] + x[18] + x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + - x[25] - x[26] - x[27] - x[28] + x[29] + x[30] - x[31] - x[32] - - x[33] + x[34] - x[35]; - out[4] = +x[0] - x[1] + x[2] + x[3] + x[4] + x[5] + x[6] - x[7] + x[8] - - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] + - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] + x[23] + x[24] - - x[25] + x[26] - x[27] - x[28] - x[29] + x[30] + x[31] - x[32] - - x[33] - x[34] + x[35]; - out[5] = +x[0] + x[1] - x[2] + x[3] + x[4] + x[5] + x[6] + x[7] - x[8] + - x[9] - x[10] - x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - - x[17] + x[18] + x[19] - x[20] + x[21] + x[22] - x[23] + x[24] + - x[25] - x[26] + x[27] - x[28] - x[29] - x[30] + x[31] + x[32] - - x[33] - x[34] - x[35]; - out[6] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] + x[7] + x[8] - - x[9] + x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] - - x[17] + x[18] - x[19] + x[20] - x[21] + x[22] + x[23] - x[24] + - x[25] + x[26] - x[27] + x[28] - x[29] - x[30] - x[31] + x[32] + - x[33] - x[34] - x[35]; - out[7] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] + - x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - - x[17] + x[18] - x[19] - x[20] + x[21] - x[22] + x[23] + x[24] - - x[25] + x[26] + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] + - x[33] + x[34] - x[35]; - out[8] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] + x[7] + x[8] + - x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] + - x[17] + x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] + - x[25] - x[26] + x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - - x[33] + x[34] + x[35]; - out[9] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] + - x[9] + x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + - x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] - x[24] + - x[25] + x[26] - x[27] + x[28] + x[29] - x[30] + x[31] - x[32] - - x[33] - x[34] + x[35]; - out[10] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] + - x[9] + x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] + x[24] - - x[25] + x[26] + x[27] - x[28] + x[29] + x[30] - x[31] + x[32] - - x[33] - x[34] - x[35]; - out[11] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + - x[9] + x[10] + x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - - x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] + - x[25] - x[26] + x[27] + x[28] - x[29] + x[30] + x[31] - x[32] + - x[33] - x[34] - x[35]; - out[12] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - - x[9] + x[10] + x[11] + x[12] + x[13] + x[14] - x[15] + x[16] - - x[17] + x[18] - x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - - x[25] + x[26] - x[27] + x[28] + x[29] - x[30] + x[31] + x[32] - - x[33] + x[34] - x[35]; - out[13] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] + - x[9] - x[10] + x[11] + x[12] + x[13] + x[14] + x[15] - x[16] + - x[17] + x[18] - x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - - x[25] - x[26] + x[27] - x[28] + x[29] + x[30] - x[31] + x[32] + - x[33] - x[34] + x[35]; - out[14] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - - x[9] + x[10] - x[11] + x[12] + x[13] + x[14] + x[15] + x[16] - - x[17] + x[18] + x[19] - x[20] - x[21] - x[22] + x[23] + x[24] - - x[25] - x[26] - x[27] + x[28] - x[29] + x[30] + x[31] - x[32] + - x[33] + x[34] - x[35]; - out[15] = +x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - - x[9] - x[10] + x[11] - x[12] + x[13] + x[14] + x[15] + x[16] + - x[17] + x[18] - x[19] + x[20] - x[21] - x[22] - x[23] + x[24] + - x[25] - x[26] - x[27] - x[28] + x[29] - x[30] + x[31] + x[32] - - x[33] + x[34] + x[35]; - out[16] = +x[0] + x[1] - x[2] + x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] + x[15] + x[16] + - x[17] + x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + - x[25] + x[26] - x[27] - x[28] - x[29] + x[30] - x[31] + x[32] + - x[33] - x[34] + x[35]; - out[17] = +x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] + - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] + - x[17] + x[18] + x[19] + x[20] - x[21] + x[22] - x[23] - x[24] - - x[25] + x[26] + x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + - x[33] + x[34] - x[35]; - out[18] = -x[0] + x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + - x[9] + x[10] + x[11] + x[12] + x[13] + x[14] + x[15] + x[16] + - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] - x[24] - - x[25] - x[26] - x[27] - x[28] - x[29] - x[30] - x[31] - x[32] - - x[33] - x[34] - x[35]; - out[19] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + - x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + - x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] + x[24] + - x[25] + x[26] - x[27] - x[28] + x[29] + x[30] + x[31] - x[32] + - x[33] - x[34] - x[35]; - out[20] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - - x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] + - x[25] + x[26] + x[27] - x[28] - x[29] + x[30] + x[31] + x[32] - - x[33] + x[34] - x[35]; - out[21] = +x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] - - x[25] + x[26] + x[27] + x[28] - x[29] - x[30] + x[31] + x[32] + - x[33] - x[34] + x[35]; - out[22] = +x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] - x[16] + - x[17] - x[18] + x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + - x[25] - x[26] + x[27] + x[28] + x[29] - x[30] - x[31] + x[32] + - x[33] + x[34] - x[35]; - out[23] = +x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] + - x[9] - x[10] - x[11] - x[12] + x[13] + x[14] - x[15] - x[16] - - x[17] - x[18] - x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - - x[25] + x[26] - x[27] + x[28] + x[29] + x[30] - x[31] - x[32] + - x[33] + x[34] + x[35]; - out[24] = +x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] + x[8] - - x[9] + x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] - - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] - x[24] - - x[25] - x[26] + x[27] - x[28] + x[29] + x[30] + x[31] - x[32] - - x[33] + x[34] + x[35]; - out[25] = +x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] + x[8] + - x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - - x[17] - x[18] + x[19] + x[20] - x[21] + x[22] - x[23] - x[24] - - x[25] - x[26] - x[27] + x[28] - x[29] + x[30] + x[31] + x[32] - - x[33] - x[34] + x[35]; - out[26] = +x[0] - x[1] - x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] + - x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] + x[16] + - x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - - x[25] - x[26] - x[27] - x[28] + x[29] - x[30] + x[31] + x[32] + - x[33] - x[34] - x[35]; - out[27] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - - x[9] + x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + - x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] + x[24] - - x[25] - x[26] - x[27] - x[28] - x[29] + x[30] - x[31] + x[32] + - x[33] + x[34] - x[35]; - out[28] = +x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] + x[8] + - x[9] - x[10] + x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - - x[17] - x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] + - x[25] - x[26] - x[27] - x[28] - x[29] - x[30] + x[31] - x[32] + - x[33] + x[34] + x[35]; - out[29] = +x[0] - x[1] + x[2] + x[3] - x[4] - x[5] - x[6] + x[7] - x[8] + - x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - - x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] + x[24] - - x[25] + x[26] - x[27] - x[28] - x[29] - x[30] - x[31] + x[32] - - x[33] + x[34] + x[35]; - out[30] = +x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] - - x[17] - x[18] + x[19] + x[20] - x[21] - x[22] + x[23] + x[24] + - x[25] - x[26] + x[27] - x[28] - x[29] - x[30] - x[31] - x[32] + - x[33] - x[34] + x[35]; - out[31] = +x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] + - x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + - x[17] - x[18] + x[19] + x[20] + x[21] - x[22] - x[23] + x[24] + - x[25] + x[26] - x[27] + x[28] - x[29] - x[30] - x[31] - x[32] - - x[33] + x[34] - x[35]; - out[32] = +x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - - x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] - - x[17] - x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + - x[25] + x[26] + x[27] - x[28] + x[29] - x[30] - x[31] - x[32] - - x[33] - x[34] + x[35]; - out[33] = +x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - - x[9] - x[10] + x[11] - x[12] + x[13] + x[14] - x[15] + x[16] + - x[17] - x[18] + x[19] - x[20] + x[21] + x[22] + x[23] - x[24] - - x[25] + x[26] + x[27] + x[28] - x[29] + x[30] - x[31] - x[32] - - x[33] - x[34] - x[35]; - out[34] = +x[0] + x[1] - x[2] + x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] + x[15] - x[16] + - x[17] - x[18] - x[19] + x[20] - x[21] + x[22] + x[23] + x[24] - - x[25] - x[26] + x[27] + x[28] + x[29] - x[30] + x[31] - x[32] - - x[33] - x[34] - x[35]; - out[35] = +x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] + - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] - - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] + x[23] + x[24] + - x[25] - x[26] - x[27] + x[28] + x[29] + x[30] - x[31] + x[32] - - x[33] - x[34] - x[35]; -#pragma unroll - for (int i = 0; i < 36; i++) { - x[i] = out[i]; - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_chunk_28( - T x[kNChunks][28]) { -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - hadamard_mult_thread_28(x[c]); - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_chunk_36( - T x[kNChunks][36]) { -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - hadamard_mult_thread_36(x[c]); - } -} - -template -inline __device__ void load_input(const T *x, - T x_vals[kNChunks][VecSize], - int dim) { - using vec_t = typename BytesToType::Type; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - int offset; - if constexpr (UseDiagonalBlockMatrix) { - static_assert(kNChunks == 1); - offset = blockIdx.y * blockDim.x + threadIdx.x; - } else { - offset = c * blockDim.x + threadIdx.x; - } - if (offset * VecSize < dim) { - reinterpret_cast(x_vals)[c] = - reinterpret_cast(x)[offset]; - } - } -} - -template -__forceinline__ __device__ OutType QuantHelperFunc(const InType input, - const float scale, - const int round_type, - const float max_bound, - const float min_bound) { - float quant_value = max_bound * scale * static_cast(input); - - if (round_type == 0) { - quant_value = static_cast(rint(quant_value)); - } else { - quant_value = static_cast(round(quant_value)); - } - return static_cast( - ClipFunc(quant_value, min_bound, max_bound)); -} - -template -inline __device__ void smooth_quant_store_output(OutT *out, - const T *shift, - const T *smooth, - T out_vals[kNChunks][VecSize], - const float quant_scale, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int dim) { - using DstVec = AlignedVector; - using Vec = AlignedVector; - DstVec dst_vec; - Vec shift_vec; - Vec smooth_vec; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - int base_idx; - if constexpr (UseDiagonalBlockMatrix) { - base_idx = blockIdx.y * blockDim.x + threadIdx.x; - } else { - base_idx = c * blockDim.x + threadIdx.x; - } - const int idx = base_idx * VecSize; - if (idx < dim) { - Load(shift + idx, &shift_vec); - Load(smooth + idx, &smooth_vec); -#pragma unroll - for (int vi = 0; vi < VecSize; ++vi) { - out_vals[c][vi] = (out_vals[c][vi] + shift_vec[vi]) * smooth_vec[vi]; - dst_vec[vi] = - QuantHelperFunc(static_cast(out_vals[c][vi]), - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound); - } - Store(dst_vec, out + idx); - } - } -} - -template -inline __device__ void quant_store_output(OutT *out, - T out_vals[kNChunks][VecSize], - const float quant_scale, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int dim) { - using DstVec = AlignedVector; - using Vec = AlignedVector; - DstVec dst_vec; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - int base_idx; - if constexpr (UseDiagonalBlockMatrix) { - base_idx = blockIdx.y * blockDim.x + threadIdx.x; - } else { - base_idx = c * blockDim.x + threadIdx.x; - } - const int idx = base_idx * VecSize; - if (idx < dim) { -#pragma unroll - for (int vi = 0; vi < VecSize; ++vi) { - // out_vals[c][vi] = (out_vals[c][vi] + shift_vec[vi]) * smooth_vec[vi]; - dst_vec[vi] = - QuantHelperFunc(static_cast(out_vals[c][vi]), - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound); - } - Store(dst_vec, out + idx); - } - } -} - -template -inline __device__ void store_output(OutT *out, - T out_vals[kNChunks][VecSize], - int dim) { - using vec_t = typename BytesToType::Type; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - int offset; - if constexpr (UseDiagonalBlockMatrix) { - offset = blockIdx.y * blockDim.x + threadIdx.x; - } else { - offset = c * blockDim.x + threadIdx.x; - } - if (offset * VecSize < dim) { - reinterpret_cast(out)[offset] = - reinterpret_cast(out_vals)[c]; - } - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread_transpose( - T x[1 << kLogN][kNChunks]) { - constexpr int N = 1 << kLogN; -#pragma unroll - for (int i = 0; i < kLogN; ++i) { - const int stride = 1 << i; -#pragma unroll - for (int j = 0; j < N / 2; ++j) { - const int lo = j & (stride - 1); - const int idx = (j - lo) * 2 + lo; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - const T a = x[idx][c]; - const T b = x[idx + stride][c]; - x[idx][c] = a + b; - x[idx + stride][c] = a - b; - } - } - } -} - -template -__device__ __forceinline__ void hadamard_mult_thread( - T x[kNChunks][1 << kLogN]) { - constexpr int N = 1 << kLogN; -#pragma unroll - for (int i = 0; i < kLogN; ++i) { - const int stride = 1 << i; -#pragma unroll - for (int j = 0; j < N / 2; ++j) { - const int lo = j & (stride - 1); - const int idx = (j - lo) * 2 + lo; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { - const T a = x[c][idx]; - const T b = x[c][idx + stride]; - x[c][idx] = a + b; - x[c][idx + stride] = a - b; - } - } - } -} - -template -__device__ __forceinline__ void hadamard_mult_warp(T x[kNChunks][kNItems]) { - constexpr int N = 1 << kLogWarpSize; - int lane_id = threadIdx.x % N; -#pragma unroll - for (int step = kStepStart; step < kLogWarpSize; ++step) { - const int lane_mask = 1 << step; - const T sign = (lane_id & lane_mask) ? -1.f : 1.f; -#pragma unroll - for (int c = 0; c < kNChunks; ++c) { -#pragma unroll - for (int i = 0; i < kNItems; ++i) { - T x_val_other = __shfl_xor_sync(FULL_MASK, x[c][i], lane_mask); - x[c][i] = sign * x[c][i] + x_val_other; - } - } - } -} - -template -inline __device__ void exchange_smem_pre(T x_vals[kNChunks][kNElts], - vec_t *smem) { - // kNChunks表示整体需要多少次循环才能处理完 - // kChunksPerExchange表示每次循环可以处理多少个chunk - // kNExchanges表示多少次循环才能处理完所有数据 - constexpr int kNThreads = kWarpSize * kNWarps; - const int warp_id = threadIdx.x / kWarpSize; - const int lane_id = threadIdx.x % kWarpSize; - const int row_t = threadIdx.x % kNWarps; - const int col_t = threadIdx.x / kNWarps; -#pragma unroll - for (int c0 = 0; c0 < kNChunks / kChunksPerExchange; ++c0) { - // 搬多少次chunk算完所有数据 - __syncthreads(); -#pragma unroll - for (int c1 = 0; c1 < kChunksPerExchange; ++c1) { - // 每次循环搬多少数据把smem塞满 - // smem[c1 * kNThreads + (Pre ? warp_id * kWarpSize + lane_id ^ warp_id : - // row_t * kWarpSize + col_t ^ row_t)] = - // *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]); - smem[c1 * kNThreads + - (Pre ? warp_id * kWarpSize + lane_id : row_t * kWarpSize + col_t)] = - *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]); - } - __syncthreads(); -#pragma unroll - for (int c1 = 0; c1 < kChunksPerExchange; ++c1) { - // *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]) = - // smem[c1 * kNThreads + (Pre ? row_t * kWarpSize + col_t ^ row_t : - // warp_id * kWarpSize + lane_id ^ warp_id)]; - *reinterpret_cast(x_vals[c0 * kChunksPerExchange + c1]) = - smem[c1 * kNThreads + (Pre ? row_t * kWarpSize + col_t - : warp_id * kWarpSize + lane_id)]; - } - } -} - -constexpr int cilog2(int val) { return val > 0 ? 1 + cilog2(val >> 1) : -1; } - -template -__global__ __launch_bounds__(kThreads) void moe_fast_hardamard_kernel( - const T *x, - const int64_t *expert_idx_per_token, - const T *shift, - const T *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - OutT *out) { - using vec_t = typename BytesToType::Type; - constexpr int kLogVecSize = cilog2(VecSize); - constexpr int kLogWarpSize = cilog2(32); - constexpr int kWarpSize = 32; - constexpr int kNWarps = kThreads / kWarpSize; - constexpr int kLogNWarps = cilog2(kNWarps); - constexpr int kLogNChunks = cilog2(kNChunks); - - extern __shared__ char smem_[]; - vec_t *smem_exchange = reinterpret_cast(smem_); - - for (int token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) { - const T *x_now = x + token_id * dim; - OutT *out_now = out + token_id * dim; - T init_value = static_cast(0.f); - T x_vals[kNChunks][VecSize] = {init_value}; - - load_input( - x_now, x_vals, dim); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id0: %d\n", i); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - - hadamard_mult_thread(x_vals); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - hadamard_mult_warp(x_vals); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id2: %d\n", i); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - if constexpr (kNWarps > 1) { - // 先让连续的NWARPS个线程拿到其余warps上的数据 - exchange_smem_pre(x_vals, smem_exchange); - // 交叉计算 - hadamard_mult_warp(x_vals); - // 再换回来 - exchange_smem_pre(x_vals, smem_exchange); - } - if constexpr (kNChunks > 1) { - if constexpr (kNChunks == 28) { - hadamard_mult_thread_28_transpose(x_vals); - } else if constexpr (kNChunks == 36) { - hadamard_mult_thread_36_transpose(x_vals); - } else { - constexpr int kLogNChunks = cilog2(kNChunks); - static_assert(1 << kLogNChunks == kNChunks, - "kNChunks must be a power of 2"); - hadamard_mult_thread_transpose(x_vals); - } - } - if (quant_scales) { - int64_t expert_id = expert_idx_per_token[token_id]; - float quant_scale = quant_scales[expert_id]; - if (shift) { - smooth_quant_store_output(out_now, - shift, - smooth, - x_vals, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, - dim); - } else { - quant_store_output( - out_now, - x_vals, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, - dim); - } - } else { - store_output( - out_now, x_vals, dim); - } - } -} - -template -__global__ __launch_bounds__(kThreads) void masked_moe_fast_hardamard_kernel( - const T *x, - const int64_t *recv_expert_count, - const T *shift, - const T *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - OutT *out) { - using vec_t = typename BytesToType::Type; - constexpr int kLogVecSize = cilog2(VecSize); - constexpr int kLogWarpSize = cilog2(32); - constexpr int kWarpSize = 32; - constexpr int kNWarps = kThreads / kWarpSize; - constexpr int kLogNWarps = cilog2(kNWarps); - constexpr int kLogNChunks = cilog2(kNChunks); - - extern __shared__ char smem_[]; - vec_t *smem_exchange = reinterpret_cast(smem_); - - for (int token_id = blockIdx.x; token_id < token_num; token_id += gridDim.x) { - const auto token_idx_in_expert = token_id % num_max_tokens_per_expert; - const auto expert_id = token_id / num_max_tokens_per_expert; - if (token_idx_in_expert >= recv_expert_count[expert_id]) { - auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert; - auto num_iters_to_next_expert = - (next_expert_start_idx - token_id - 1) / gridDim.x; - token_id += num_iters_to_next_expert * gridDim.x; - continue; - } - const T *x_now = x + token_id * dim; - OutT *out_now = out + token_id * dim; - T init_value = static_cast(0.f); - T x_vals[kNChunks][VecSize] = {init_value}; - - load_input( - x_now, x_vals, dim); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id0: %d\n", i); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - - hadamard_mult_thread(x_vals); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id1: %d, kLogVecSize: %d\n", i, kLogVecSize); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - hadamard_mult_warp(x_vals); -#ifdef DEBUG_HARDAMARD - if (blockIdx.x == 0 && threadIdx.x == 0) { - for (int i = 0; i < 1; ++i) { - printf("chunk_id2: %d\n", i); - for (int j = 0; j < VecSize; ++j) { - printf("%f ", (float)x_vals[i][j]); - } - printf("\n"); - } - } - __syncthreads(); -#endif - if constexpr (kNWarps > 1) { - // 先让连续的NWARPS个线程拿到其余warps上的数据 - exchange_smem_pre(x_vals, smem_exchange); - // 交叉计算 - hadamard_mult_warp(x_vals); - // 再换回来 - exchange_smem_pre(x_vals, smem_exchange); - } - if constexpr (kNChunks > 1) { - if constexpr (kNChunks == 28) { - hadamard_mult_thread_28_transpose(x_vals); - } else if constexpr (kNChunks == 36) { - hadamard_mult_thread_36_transpose(x_vals); - } else { - constexpr int kLogNChunks = cilog2(kNChunks); - static_assert(1 << kLogNChunks == kNChunks, - "kNChunks must be a power of 2"); - hadamard_mult_thread_transpose(x_vals); - } - } - if (quant_scales) { - float quant_scale = quant_scales[expert_id]; - if (shift) { - smooth_quant_store_output(out_now, - shift, - smooth, - x_vals, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, - dim); - } else { - quant_store_output( - out_now, - x_vals, - quant_scale, - quant_round_type, - quant_max_bound, - quant_min_bound, - dim); - } - } else { - store_output( - out_now, x_vals, dim); - } - } -} - -template -void MoeFastHardamardImplWrapper(const T *x, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const T *shift, - const T *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - OutT *out, - cudaStream_t stream) { - using nv_type = typename nv_type_traits::type; - using out_type = typename nv_type_traits::type; - constexpr int kNBytes = sizeof(T); - constexpr int N = 1 << kLogN; // pad - constexpr int kSmemSize = std::min(N * kNBytes, 32 * 1024); - constexpr int kRounds = N * kNBytes / kSmemSize; - constexpr int kChunksPerSmemSize = kSmemSize / (kThreads * VecSize * kNBytes); - VLOG(1) << "real_dim: " << dim << ", N: " << N; - VLOG(1) << "kNChunks: " << kNChunks; - VLOG(1) << "kNBytes: " << kNBytes; - VLOG(1) << "kSmemSize: " << kSmemSize; - VLOG(1) << "kRounds: " << kRounds; - VLOG(1) << "kChunksPerSmemSize: " << kChunksPerSmemSize; - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - - if (used_in_ep_low_latency) { - auto masked_kernel = - masked_moe_fast_hardamard_kernel; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, masked_kernel, kThreads, kSmemSize); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - dim3 grid; - grid.x = min(static_cast(num_blocks_per_wave), token_num); - if constexpr (UseDiagonalBlockMatrix) { - grid.y = ceil(dim / (kThreads * VecSize)); - } - masked_kernel<<>>( - reinterpret_cast(x), - recv_expert_count, - reinterpret_cast(shift), - reinterpret_cast(smooth), - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - reinterpret_cast(out)); - } else { - auto kernel = moe_fast_hardamard_kernel; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, kernel, kThreads, kSmemSize); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - dim3 grid; - grid.x = min(static_cast(num_blocks_per_wave), token_num); - if constexpr (UseDiagonalBlockMatrix) { - grid.y = ceil(dim / (kThreads * VecSize)); - } - kernel<<>>( - reinterpret_cast(x), - expert_idx_per_token, - reinterpret_cast(shift), - reinterpret_cast(smooth), - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - reinterpret_cast(out)); - } -} - -template -void MoeFastHardamardWrapper(const T *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const T *shift, - const T *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - OutT *out, - cudaStream_t &stream) { - bool FLAGS_hardamard_use_diagonal_block_matrix = true; - - constexpr int kThreads = 128; - if (FLAGS_hardamard_use_diagonal_block_matrix) { - const int VecSize = hadamard_block_size / kThreads; - const int logN = int(ceil(std::log2(kThreads * VecSize))); - constexpr int kNChunks = 1; - DISPATCH_SP_VS(VecSize, VEC_SIZE, {DISPATCH_SP_logN(logN, kLogN, { - MoeFastHardamardImplWrapper( - x_data, - expert_idx_per_token, - recv_expert_count, - shift, - smooth, - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - used_in_ep_low_latency, - out, - stream); - })}); - } else { - if (!((dim / 28) & (dim / 28 - 1))) { - VLOG(1) << "28 * 2^n"; - const int logN = int(ceil(std::log2(dim / 28))); - constexpr int kNChunks = 28; - DISPATCH_SP_logN(logN, kLogN, { - constexpr int VecSize = (1 << kLogN) / kThreads; - MoeFastHardamardImplWrapper(x_data, - expert_idx_per_token, - recv_expert_count, - shift, - smooth, - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - used_in_ep_low_latency, - out, - stream); - }); - } else if (!((dim / 36) & (dim / 36 - 1))) { - VLOG(1) << "36 * 2^n"; - const int logN = int(ceil(std::log2(dim / 36))); - constexpr int kNChunks = 36; - DISPATCH_SP_logN(logN, kLogN, { - constexpr int VecSize = (1 << kLogN) / kThreads; - MoeFastHardamardImplWrapper(x_data, - expert_idx_per_token, - recv_expert_count, - shift, - smooth, - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - used_in_ep_low_latency, - out, - stream); - }); - } else { - VLOG(1) << "2^n"; - const int logN = int(ceil(std::log2(dim))); - constexpr int VecSize = 16 / sizeof(T); - DISPATCH_logN(logN, kLogN, { - constexpr int kNChunks = (1 << kLogN) / (kThreads * VecSize); - MoeFastHardamardImplWrapper(x_data, - expert_idx_per_token, - recv_expert_count, - shift, - smooth, - quant_scales, - quant_round_type, - quant_max_bound, - quant_min_bound, - token_num, - dim, - num_max_tokens_per_expert, - used_in_ep_low_latency, - out, - stream); - }); - } - } -} diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu deleted file mode 100644 index cc5e19b21..000000000 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fast_hardamard_kernel.hpp" - -template void -MoeFastHardamardWrapper( - const phi::dtype::bfloat16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::bfloat16 *shift, - const phi::dtype::bfloat16 *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - phi::dtype::bfloat16 *out, - cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu deleted file mode 100644 index 51298b87a..000000000 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fast_hardamard_kernel.hpp" - -template void -MoeFastHardamardWrapper( - const phi::dtype::bfloat16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::bfloat16 *shift, - const phi::dtype::bfloat16 *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - phi::dtype::float8_e4m3fn *out, - cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu deleted file mode 100644 index 4a2faeec7..000000000 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fast_hardamard_kernel.hpp" - -template void MoeFastHardamardWrapper( - const phi::dtype::bfloat16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::bfloat16 *shift, - const phi::dtype::bfloat16 *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - int8_t *out, - cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu deleted file mode 100644 index 9f466b204..000000000 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fast_hardamard_kernel.hpp" - -template void MoeFastHardamardWrapper( - const phi::dtype::float16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::float16 *shift, - const phi::dtype::float16 *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - phi::dtype::float16 *out, - cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu deleted file mode 100644 index e86726f86..000000000 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fast_hardamard_kernel.hpp" - -template void MoeFastHardamardWrapper( - const phi::dtype::float16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::float16 *shift, - const phi::dtype::float16 *smooth, - const float *quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - int8_t *out, - cudaStream_t &stream); diff --git a/custom_ops/gpu_ops/moe/fused_moe_helper.h b/custom_ops/gpu_ops/moe/fused_moe_helper.h index 817e9d92b..703a7c11f 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_helper.h +++ b/custom_ops/gpu_ops/moe/fused_moe_helper.h @@ -25,8 +25,7 @@ template __global__ void moe_token_type_ids_kernel(T *gating_output, const int *moe_token_type_ids_out, const int num_rows, - const int num_experts, - const int k) { + const int num_experts, const int k) { const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x; if (moe_token_index >= num_rows) { @@ -45,8 +44,7 @@ template void moe_token_type_ids_kernelLauncher(T *gating_output, const int *moe_token_type_ids_out, const int num_rows, - const int num_experts, - const int k, + const int num_experts, const int k, cudaStream_t stream) { const int blocks = num_rows * k / 512 + 1; const int threads = 512; @@ -54,35 +52,26 @@ void moe_token_type_ids_kernelLauncher(T *gating_output, gating_output, moe_token_type_ids_out, num_rows, num_experts, k); } -template -class MoeHelper { - public: - using Fp16Traits = - cutlass::WintQuantTraits; - using Int8Traits = - cutlass::WintQuantTraits; - using Int4Traits = - cutlass::WintQuantTraits; +template class MoeHelper { +public: + using Fp16Traits = cutlass::WintQuantTraits; + using Int8Traits = cutlass::WintQuantTraits; + using Int4Traits = cutlass::WintQuantTraits; - MoeHelper(const std::string gemm_method, - MoeGemmRunner *fp16_moe_gemm_runner, - MoeGemmRunner *int8_moe_gemm_runner, - MoeGemmRunner *int4_moe_gemm_runner, - int layernum = 0) - : gemm_method_(gemm_method), - fp16_moe_gemm_runner_(fp16_moe_gemm_runner), + MoeHelper( + const std::string gemm_method, + MoeGemmRunner *fp16_moe_gemm_runner, + MoeGemmRunner *int8_moe_gemm_runner, + MoeGemmRunner *int4_moe_gemm_runner, + int layernum = 0) + : gemm_method_(gemm_method), fp16_moe_gemm_runner_(fp16_moe_gemm_runner), int8_moe_gemm_runner_(int8_moe_gemm_runner), - int4_moe_gemm_runner_(int4_moe_gemm_runner), - layernum_(layernum) {} + int4_moe_gemm_runner_(int4_moe_gemm_runner), layernum_(layernum) {} // -------- getWorkspaceSize -------- // template - size_t getWorkspaceSize(const int64_t num_rows, - const int64_t hidden_size, - const int64_t inter_size, - const int64_t num_experts, + size_t getWorkspaceSize(const int64_t num_rows, const int64_t hidden_size, + const int64_t inter_size, const int64_t num_experts, const int64_t k) { const size_t buf_size = AlignTo16(k * num_rows * hidden_size); const size_t interbuf_size = AlignTo16(k * num_rows * inter_size); @@ -93,10 +82,10 @@ class MoeHelper { // FfnLayer forward. size_t total_ws_bytes = 5 * num_moe_inputs * - sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ - total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data total_ws_bytes += - padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ + padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT); const size_t sorter_ws_size_bytes = @@ -111,8 +100,8 @@ class MoeHelper { } total_ws_bytes += - bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub - // sorting workspace + bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub + // sorting workspace int64_t num_softmax_outs = 0; const bool is_pow_2 = @@ -126,27 +115,20 @@ class MoeHelper { return total_ws_bytes; } - void ComputeFFN(const paddle::Tensor *input, - const paddle::Tensor *gate_weight, - const paddle::Tensor *up_gate_proj_weight, - const paddle::Tensor *up_gate_proj_scale, - const paddle::Tensor *up_gate_proj_bias, - const paddle::Tensor *down_proj_weight, - const paddle::Tensor *down_proj_scale, - const paddle::Tensor *down_proj_bias, - const paddle::Tensor *moe_token_type_ids, - const int moe_topk, - const bool group_moe, - const bool norm_topk_prob, - const float routed_scaling_factor, - const std::string moe_type, - paddle::Tensor *output) { + void + ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight, + const paddle::Tensor *up_gate_proj_weight, + const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias, + const paddle::Tensor *down_proj_weight, + const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias, + const paddle::Tensor *moe_token_type_ids, const int moe_topk, + const bool group_moe, const bool norm_topk_prob, + const float routed_scaling_factor, const std::string moe_type, + paddle::Tensor *output) { auto *input_activations = input->data(); auto *gating_weights = gate_weight->data(); - const T *fc1_expert_biases = - up_gate_proj_bias ? up_gate_proj_bias->data() : nullptr; - const T *fc2_expert_biases = - down_proj_bias ? down_proj_bias->data() : nullptr; + const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data() : nullptr; + const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data() : nullptr; auto *output_ = output->data(); auto stream = input->stream(); @@ -166,8 +148,7 @@ class MoeHelper { const int64_t hidden_size = up_gate_proj_dims[1]; int64_t inter_dim = 0; if (moe_type == "qkv") { - inter_dim = - up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4]; + inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4]; } else { inter_dim = up_gate_proj_dims[2]; } @@ -251,79 +232,44 @@ class MoeHelper { if (moe_token_type_ids) { auto *moe_token_type_ids_out = moe_token_type_ids->data(); moe_token_type_ids_kernelLauncher(gating_output, - moe_token_type_ids_out, - num_rows, - num_experts, - k, - stream); + moe_token_type_ids_out, num_rows, + num_experts, k, stream); } - topk_gating_softmax_kernelLauncher(gating_output, - nullptr, - expert_scales_float, - softmax_out_, - expert_for_source_row, - source_rows_, - softmax_max_prob, - num_rows, - num_experts, - k, - group_moe, - stream); + topk_gating_softmax_kernelLauncher( + gating_output, nullptr, expert_scales_float, softmax_out_, + expert_for_source_row, source_rows_, softmax_max_prob, num_rows, + num_experts, k, group_moe, stream); const int64_t sorter_ws_size_bytes = AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows))); - sorter_.run(fc1_result_, - sorter_ws_size_bytes, - expert_for_source_row, - permuted_experts_, - source_rows_, - permuted_rows_, - k * num_rows, - false, - stream); + sorter_.run(fc1_result_, sorter_ws_size_bytes, expert_for_source_row, + permuted_experts_, source_rows_, permuted_rows_, k * num_rows, + false, stream); initialize_moe_routing_kernelLauncher( - input_activations, - permuted_data_, - permuted_rows_, - nullptr, - nullptr, - expanded_source_row_to_expanded_dest_row, - nullptr, - num_rows, - num_rows, - hidden_size, - k, - stream); + input_activations, permuted_data_, permuted_rows_, nullptr, nullptr, + expanded_source_row_to_expanded_dest_row, num_rows, num_rows, + hidden_size, k, stream); const int64_t expanded_active_expert_rows = k * num_rows; compute_total_rows_before_expert(permuted_experts_, - expanded_active_expert_rows, - num_experts, - total_rows_before_expert_, - stream); + expanded_active_expert_rows, num_experts, + total_rows_before_expert_, stream); if (gemm_method_ == "weight_only_int8") { typename Int8Traits::Arguments up_gate_proj_quant_args; int8_moe_gemm_runner_->moe_gemm_bias_act( reinterpret_cast(permuted_data_), - reinterpret_cast( - up_gate_proj_weight->data()), + reinterpret_cast(up_gate_proj_weight->data()), reinterpret_cast(up_gate_proj_scale->data()), reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), - total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, - inter_size, - hidden_size, - num_experts, - up_gate_proj_quant_args, - "none", - stream); + reinterpret_cast(fc1_out), total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, inter_size, hidden_size, num_experts, + up_gate_proj_quant_args, "none", stream); } else if (gemm_method_ == "weight_only_int4") { typename Int4Traits::Arguments up_gate_proj_quant_args; int4_moe_gemm_runner_->moe_gemm_bias_act( @@ -332,33 +278,20 @@ class MoeHelper { up_gate_proj_weight->data()), reinterpret_cast(up_gate_proj_scale->data()), reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), - total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, - inter_size, - hidden_size, - num_experts, - up_gate_proj_quant_args, - "none", - stream); + reinterpret_cast(fc1_out), total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, inter_size, hidden_size, num_experts, + up_gate_proj_quant_args, "none", stream); } else { typename Fp16Traits::Arguments up_gate_proj_quant_args; fp16_moe_gemm_runner_->moe_gemm_bias_act( reinterpret_cast(permuted_data_), - reinterpret_cast(up_gate_proj_weight->data()), - nullptr, + reinterpret_cast(up_gate_proj_weight->data()), nullptr, reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), - total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, - inter_size, - hidden_size, - num_experts, - up_gate_proj_quant_args, - "none", - stream); + reinterpret_cast(fc1_out), total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, inter_size, hidden_size, num_experts, + up_gate_proj_quant_args, "none", stream); } if (moe_type == "ffn") { @@ -376,15 +309,10 @@ class MoeHelper { reinterpret_cast(act_out), reinterpret_cast(down_proj_weight->data()), reinterpret_cast(down_proj_scale->data()), - reinterpret_cast(fc2_result), - total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, - hidden_size, - inter_size / 2, - num_experts, - down_proj_quant_args, - stream); + reinterpret_cast(fc2_result), total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, hidden_size, inter_size / 2, + num_experts, down_proj_quant_args, stream); } else if (gemm_method_ == "weight_only_int4") { typename Int4Traits::Arguments down_proj_quant_args; int4_moe_gemm_runner_->moe_gemm( @@ -392,66 +320,40 @@ class MoeHelper { reinterpret_cast( down_proj_weight->data()), reinterpret_cast(down_proj_scale->data()), - reinterpret_cast(fc2_result), - total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, - hidden_size, - inter_size / 2, - num_experts, - down_proj_quant_args, - stream); + reinterpret_cast(fc2_result), total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, hidden_size, inter_size / 2, + num_experts, down_proj_quant_args, stream); } else { typename Fp16Traits::Arguments down_proj_quant_args; fp16_moe_gemm_runner_->moe_gemm( reinterpret_cast(act_out), - reinterpret_cast(down_proj_weight->data()), - nullptr, - reinterpret_cast(fc2_result), - total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, - hidden_size, - inter_size / 2, - num_experts, - down_proj_quant_args, - stream); + reinterpret_cast(down_proj_weight->data()), nullptr, + reinterpret_cast(fc2_result), total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, hidden_size, inter_size / 2, + num_experts, down_proj_quant_args, stream); } finalize_moe_routing_kernelLauncher( - fc2_result, - output_, - fc2_expert_biases, + fc2_result, output_, fc2_expert_biases, reinterpret_cast(expert_scales_float), - expanded_source_row_to_expanded_dest_row, - expert_for_source_row, - num_rows, - hidden_size, - k, - static_cast(1), - norm_topk_prob, - routed_scaling_factor, - stream); + expanded_source_row_to_expanded_dest_row, expert_for_source_row, + num_rows, hidden_size, k, static_cast(1), norm_topk_prob, + routed_scaling_factor, stream); } else { finalize_moe_routing_kernelLauncher( // fc2_result, - fc1_out, - output_, - fc1_expert_biases, // fc2_expert_biases, + fc1_out, output_, + fc1_expert_biases, // fc2_expert_biases, reinterpret_cast(expert_scales_float), - expanded_source_row_to_expanded_dest_row, - expert_for_source_row, - num_rows, - inter_size, - k, - static_cast(0), - norm_topk_prob, - routed_scaling_factor, - stream); + expanded_source_row_to_expanded_dest_row, expert_for_source_row, + num_rows, inter_size, k, static_cast(0), norm_topk_prob, + routed_scaling_factor, stream); } } - private: +private: std::string gemm_method_; MoeGemmRunner *fp16_moe_gemm_runner_; MoeGemmRunner *int8_moe_gemm_runner_; @@ -460,4 +362,4 @@ class MoeHelper { CubKeyValueSorter sorter_; }; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h index 50a75caab..eeaecb716 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_op.h +++ b/custom_ops/gpu_ops/moe/fused_moe_op.h @@ -19,9 +19,9 @@ #include #include -#include "cutlass/numeric_conversion.h" -#include "moe/fused_moe_helper.h" #include "moe/fused_moe_imp_op.h" +#include "moe/fused_moe_helper.h" +#include "cutlass/numeric_conversion.h" // Ignore CUTLASS warnings about type punning #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -37,8 +37,8 @@ namespace phi { struct GpuLaunchConfig { - dim3 block_per_grid; - dim3 thread_per_block; + dim3 block_per_grid; + dim3 thread_per_block; }; inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) { @@ -59,59 +59,55 @@ inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) { constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; template -__host__ __device__ constexpr static U arrayConvert(T const& input) { - using Type = typename U::Element; - static_assert(T::kElements == U::kElements); - U u; +__host__ __device__ constexpr static U arrayConvert(T const& input) +{ + using Type = typename U::Element; + static_assert(T::kElements == U::kElements); + U u; #pragma unroll - for (int i = 0; i < U::kElements; i++) { - u[i] = static_cast(input[i]); - } - return u; + for (int i = 0; i < U::kElements; i++) + { + u[i] = static_cast(input[i]); + } + return u; } struct uint8 { - uint4 u; - uint4 v; + uint4 u; + uint4 v; }; -template -struct BytesToType {}; +template struct BytesToType {}; -template <> +template<> struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); + using Type = uint8; + static_assert(sizeof(Type) == 32); }; -template <> -struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); }; -template <> -struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); }; -template <> -struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); }; -template <> -struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); }; -template <> -struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); }; template