From ae7bee81228b56590f0918ce0176ddb904de63c4 Mon Sep 17 00:00:00 2001 From: yangjianfengo1 <125249383+yangjianfengo1@users.noreply.github.com> Date: Thu, 13 Nov 2025 19:17:27 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90New=20Feature=E3=80=91W4afp8=20support?= =?UTF-8?q?s=20per=20group=20quantization=20(#4987)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * w4afp8 支持per group * code style * fix transpose * revert fast hardmard --------- Co-authored-by: yuanxiaolan Co-authored-by: plusNew001 <95567040+plusNew001@users.noreply.github.com> --- custom_ops/gpu_ops/cpp_extensions.cc | 1 + .../gpu_ops/moe/ep_moe_expert_dispatch.cu | 1208 +++++++++-------- custom_ops/gpu_ops/moe/fused_moe_helper.h | 272 ++-- custom_ops/gpu_ops/moe/fused_moe_op.h | 829 +++++------ custom_ops/gpu_ops/moe/moe_dispatch.cu | 283 ++-- custom_ops/gpu_ops/moe/moe_ffn.cu | 152 ++- custom_ops/gpu_ops/moe/template_config.json | 3 +- .../gpu_ops/w4afp8_gemm/kernel_traits.h | 212 +-- custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h | 858 +++++++----- custom_ops/gpu_ops/w4afp8_gemm/utils.hpp | 149 +- custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu | 420 +++--- custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h | 48 +- .../w4afp8_gemm/w4afp8_gemm_kernel.hpp | 431 +++--- .../gpu_ops/w4afp8_gemm/weight_kernel.hpp | 131 ++ .../w4afp8_gemm/weight_scale_kernel.hpp | 63 + .../utils/auto_gen_w4afp8_gemm_kernel.py | 104 +- fastdeploy/model_executor/layers/moe/ep.py | 2 + .../layers/moe/fused_moe_cutlass_backend.py | 134 +- .../layers/moe/fused_moe_wint2_backend.py | 1 + .../layers/quantization/mix_quant.py | 4 + tests/operators/test_w4afp8_gemm.py | 57 +- 21 files changed, 3114 insertions(+), 2248 deletions(-) create mode 100644 custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp create mode 100644 custom_ops/gpu_ops/w4afp8_gemm/weight_scale_kernel.hpp diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 15dd61ad4..6ecc1ed14 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -304,6 +304,7 @@ paddle::Tensor MoeExpertFFNFunc( const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, diff --git a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu index 33b93c80b..8a8fb1116 100644 --- a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu +++ b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - // Ignore CUTLASS warnings about type punning #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -23,104 +22,104 @@ #include "moe/fused_moe_op.h" #pragma GCC diagnostic pop -#include "helper.h" #include +#include "helper.h" -#define DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \ - switch (num_experts_per_rank) { \ - case 2: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 2; \ - __VA_ARGS__ \ - break; \ - } \ - case 3: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 3; \ - __VA_ARGS__ \ - break; \ - } \ - case 6: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 6; \ - __VA_ARGS__ \ - break; \ - } \ - case 8: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 8; \ - __VA_ARGS__ \ - break; \ - } \ - case 9: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 9; \ - __VA_ARGS__ \ - break; \ - } \ - case 16: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 16; \ - __VA_ARGS__ \ - break; \ - } \ - case 32: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 32; \ - __VA_ARGS__ \ - break; \ - } \ - case 48: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 48; \ - __VA_ARGS__ \ - break; \ - } \ - case 64: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 64; \ - __VA_ARGS__ \ - break; \ - } \ - case 128: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 128; \ - __VA_ARGS__ \ - break; \ - } \ - case 160: { \ - constexpr size_t NUM_EXPERTS_PER_RANK = 160; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported num_experts_per_rank: " << num_experts_per_rank; \ - throw std::invalid_argument(err_msg.str()); \ - } \ +#define DISPATCH_NUM_EXPERTS_PER_RANK( \ + num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \ + switch (num_experts_per_rank) { \ + case 2: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 3: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 3; \ + __VA_ARGS__ \ + break; \ + } \ + case 6: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 6; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 9: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 9; \ + __VA_ARGS__ \ + break; \ + } \ + case 16: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 32: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 32; \ + __VA_ARGS__ \ + break; \ + } \ + case 48: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 48; \ + __VA_ARGS__ \ + break; \ + } \ + case 64: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 128: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 160: { \ + constexpr size_t NUM_EXPERTS_PER_RANK = 160; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported num_experts_per_rank: " << num_experts_per_rank; \ + throw std::invalid_argument(err_msg.str()); \ + } \ } namespace cg = cooperative_groups; -template -__device__ T warpReduceSum(T val){ - for(int lane_mask = 16; lane_mask > 0; lane_mask /=2){ - val += __shfl_down_sync(0xffffffff, val, lane_mask); - } - return val; +template +__device__ T warpReduceSum(T val) { + for (int lane_mask = 16; lane_mask > 0; lane_mask /= 2) { + val += __shfl_down_sync(0xffffffff, val, lane_mask); + } + return val; } -__global__ void get_expert_token_num( - int64_t* topk_ids, - int *out_workspace, // num_experts * 2 + 2 - const int token_num, - const int moe_topk, - const int num_experts -) { +__global__ void get_expert_token_num(int64_t* topk_ids, + int* out_workspace, // num_experts * 2 + 2 + const int token_num, + const int moe_topk, + const int num_experts) { cg::grid_group grid = cg::this_grid(); constexpr int KNWARPS = 512 / 32; __shared__ int warp_sum[KNWARPS * 2]; - int *expert_token_num = out_workspace; - int *expert_token_num_padded = out_workspace + num_experts; - int *token_num_all = out_workspace + num_experts * 2; - int *token_num_all_padded = out_workspace + num_experts * 2 + 1; + int* expert_token_num = out_workspace; + int* expert_token_num_padded = out_workspace + num_experts; + int* token_num_all = out_workspace + num_experts * 2; + int* token_num_all_padded = out_workspace + num_experts * 2 + 1; const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; for (int i = global_idx; i < num_experts; i += blockDim.x * gridDim.x) { expert_token_num[i] = 0; expert_token_num_padded[i] = 0; } grid.sync(); - for (int i = global_idx; i < token_num * moe_topk; i += blockDim.x * gridDim.x) { + for (int i = global_idx; i < token_num * moe_topk; + i += blockDim.x * gridDim.x) { const int topk_idx = topk_ids[i]; atomicAdd(&expert_token_num[topk_idx], 1); } @@ -128,7 +127,8 @@ __global__ void get_expert_token_num( for (int i = global_idx; i < num_experts; i += blockDim.x * gridDim.x) { const int token_num_per_expert = expert_token_num[i]; if (token_num_per_expert > 0) { - expert_token_num_padded[i] = 128 - token_num_per_expert % 128 + token_num_per_expert; + expert_token_num_padded[i] = + 128 - token_num_per_expert % 128 + token_num_per_expert; } } grid.sync(); @@ -163,29 +163,36 @@ __global__ void get_expert_token_num( } } -std::vector> GetExpertTokenNum( - const paddle::Tensor& topk_ids, - const int num_experts) { +std::vector> GetExpertTokenNum(const paddle::Tensor& topk_ids, + const int num_experts) { const int token_num = topk_ids.dims()[0]; const int moe_topk = topk_ids.dims()[1]; - auto out_workspace = GetEmptyTensor({num_experts * 2 + 2}, paddle::DataType::INT32, topk_ids.place()); + auto out_workspace = GetEmptyTensor( + {num_experts * 2 + 2}, paddle::DataType::INT32, topk_ids.place()); const int block_size = 512; const int grid_size = min(132 * 4, div_up(token_num * moe_topk, block_size)); - int64_t *topk_ids_ptr = const_cast(topk_ids.data()); - int *out_workspace_ptr = out_workspace.data(); - void* kernel_args[] = { - (void*)(&topk_ids_ptr), - (void*)(&out_workspace_ptr), - (void*)&token_num, - (void*)&moe_topk, - (void*)&num_experts - }; - cudaLaunchCooperativeKernel((void*)get_expert_token_num, dim3(grid_size), dim3(block_size), kernel_args, 0, topk_ids.stream()); + int64_t* topk_ids_ptr = const_cast(topk_ids.data()); + int* out_workspace_ptr = out_workspace.data(); + void* kernel_args[] = {(void*)(&topk_ids_ptr), + (void*)(&out_workspace_ptr), + (void*)&token_num, + (void*)&moe_topk, + (void*)&num_experts}; + cudaLaunchCooperativeKernel((void*)get_expert_token_num, + dim3(grid_size), + dim3(block_size), + kernel_args, + 0, + topk_ids.stream()); auto out_workspace_host = out_workspace.copy_to(paddle::CPUPlace(), true); - int *out_workspace_host_ptr = out_workspace_host.data(); - std::vector expert_token_num(out_workspace_host_ptr, out_workspace_host_ptr + num_experts); - std::vector expert_token_num_padded(out_workspace_host_ptr + num_experts, out_workspace_host_ptr + num_experts * 2); - std::vector token_num_all(out_workspace_host_ptr + num_experts * 2, out_workspace_host_ptr + num_experts * 2 + 2); + int* out_workspace_host_ptr = out_workspace_host.data(); + std::vector expert_token_num(out_workspace_host_ptr, + out_workspace_host_ptr + num_experts); + std::vector expert_token_num_padded( + out_workspace_host_ptr + num_experts, + out_workspace_host_ptr + num_experts * 2); + std::vector token_num_all(out_workspace_host_ptr + num_experts * 2, + out_workspace_host_ptr + num_experts * 2 + 2); return {expert_token_num, expert_token_num_padded, token_num_all}; } @@ -195,8 +202,9 @@ __global__ void combine_prmt_back_kernel( T* reduced_unpermuted_output, const T* bias, const float* dst_weights, - const int* expanded_source_row_to_expanded_dest_row, // permute_indices_per_token - const int* expert_for_source_row, // dst_idx + const int* + expanded_source_row_to_expanded_dest_row, // permute_indices_per_token + const int* expert_for_source_row, // dst_idx const int64_t cols, const int64_t k, const int64_t compute_bias, @@ -208,14 +216,15 @@ __global__ void combine_prmt_back_kernel( AlignedVector bias_vec; AlignedVector res_vec; const int cols_int4 = cols / VEC_SIZE; - for (int original_row = blockIdx.x; original_row < num_rows; original_row += gridDim.x) { + for (int original_row = blockIdx.x; original_row < num_rows; + original_row += gridDim.x) { T* reduced_row_ptr = reduced_unpermuted_output + original_row * cols; for (int tid = threadIdx.x; tid < cols_int4; tid += blockDim.x) { #pragma unroll for (int vid = 0; vid < VEC_SIZE; vid++) { res_vec[vid] = 0; } - for (int k_idx = 0; k_idx < k; ++k_idx) { // k is num_experts_per_rank + for (int k_idx = 0; k_idx < k; ++k_idx) { // k is num_experts_per_rank const int expanded_original_row = original_row + k_idx * num_rows; const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; @@ -223,23 +232,27 @@ __global__ void combine_prmt_back_kernel( const int64_t k_offset = original_row * k + k_idx; const float row_scale = dst_weights[expanded_permuted_row]; const T* expanded_permuted_rows_row_ptr = - expanded_permuted_rows + expanded_permuted_row * cols; // prmt后的位置对应的值 - Load(expanded_permuted_rows_row_ptr + tid * VEC_SIZE, &load_vec); - const int expert_idx = expert_for_source_row[k_offset]; // 当前位置对应的专家 - const T* bias_ptr = bias ? bias + expert_idx * cols : nullptr; // 当前专家对应的down_proj的bias + expanded_permuted_rows + + expanded_permuted_row * cols; // prmt后的位置对应的值 + Load(expanded_permuted_rows_row_ptr + tid * VEC_SIZE, + &load_vec); + const int expert_idx = + expert_for_source_row[k_offset]; // 当前位置对应的专家 + const T* bias_ptr = bias ? bias + expert_idx * cols + : nullptr; // 当前专家对应的down_proj的bias if (bias_ptr) { Load(bias_ptr + tid * VEC_SIZE, &bias_vec); #pragma unroll for (int vid = 0; vid < VEC_SIZE; vid++) { - res_vec[vid] += static_cast( - row_scale * static_cast(load_vec[vid]) + - static_cast(bias_vec[vid])); + res_vec[vid] += + static_cast(row_scale * static_cast(load_vec[vid]) + + static_cast(bias_vec[vid])); } } else { #pragma unroll for (int vid = 0; vid < VEC_SIZE; vid++) { - res_vec[vid] += static_cast( - row_scale * static_cast(load_vec[vid])); + res_vec[vid] += + static_cast(row_scale * static_cast(load_vec[vid])); } } } @@ -259,186 +272,235 @@ void MoeCombineKernel(const paddle::Tensor& ffn_out, const int num_rows, const int hidden_size, paddle::Tensor* output) { - using namespace phi; - typedef PDTraits traits_; - typedef typename traits_::DataType DataType_; - typedef typename traits_::data_t data_t; - auto stream = ffn_out.stream(); - const int threads = 1024; - const int gridx = min(132 * 8, num_rows); - const int num_experts_per_rank = top_k_indices.dims()[1]; + using namespace phi; + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + auto stream = ffn_out.stream(); + const int threads = 1024; + const int gridx = min(132 * 8, num_rows); + const int num_experts_per_rank = top_k_indices.dims()[1]; - combine_prmt_back_kernel<<>>( - ffn_out.data(), - output->data(), - down_proj_bias ? down_proj_bias->data() : nullptr, - expert_scales_float.data(), - permute_indices_per_token.data(), - top_k_indices.data(), - hidden_size, - num_experts_per_rank, - static_cast(1), // compute bias - norm_topk_prob, - routed_scaling_factor, - num_rows); + combine_prmt_back_kernel<<>>( + ffn_out.data(), + output->data(), + down_proj_bias ? down_proj_bias->data() : nullptr, + expert_scales_float.data(), + permute_indices_per_token.data(), + top_k_indices.data(), + hidden_size, + num_experts_per_rank, + static_cast(1), // compute bias + norm_topk_prob, + routed_scaling_factor, + num_rows); } std::vector EPMoeExpertCombine( const paddle::Tensor& ffn_out, - const paddle::Tensor& expert_scales_float, // dst_weights - const paddle::Tensor& permute_indices_per_token, // permute_indices_per_token - const paddle::Tensor& top_k_indices, // dst_indices + const paddle::Tensor& expert_scales_float, // dst_weights + const paddle::Tensor& + permute_indices_per_token, // permute_indices_per_token + const paddle::Tensor& top_k_indices, // dst_indices const paddle::optional& down_proj_bias, const bool norm_topk_prob, const float routed_scaling_factor) { + const auto input_type = ffn_out.dtype(); + auto place = ffn_out.place(); - const auto input_type = ffn_out.dtype(); - auto place = ffn_out.place(); + const int num_rows = top_k_indices.dims()[0]; + const int hidden_size = ffn_out.dims()[1]; - const int num_rows = top_k_indices.dims()[0]; - const int hidden_size = ffn_out.dims()[1]; + auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place); - auto output = GetEmptyTensor({num_rows, hidden_size}, input_type, place); - - switch (input_type) { - case paddle::DataType::BFLOAT16: - MoeCombineKernel( - ffn_out, - expert_scales_float, - permute_indices_per_token, - top_k_indices, - down_proj_bias, - norm_topk_prob, - routed_scaling_factor, - num_rows, - hidden_size, - &output); - break; - case paddle::DataType::FLOAT16: - MoeCombineKernel( - ffn_out, - expert_scales_float, - permute_indices_per_token, - top_k_indices, - down_proj_bias, - norm_topk_prob, - routed_scaling_factor, - num_rows, - hidden_size, - &output); - break; - default: - PD_THROW("Unsupported data type for MoeDispatchKernel"); - } - return {output}; + switch (input_type) { + case paddle::DataType::BFLOAT16: + MoeCombineKernel(ffn_out, + expert_scales_float, + permute_indices_per_token, + top_k_indices, + down_proj_bias, + norm_topk_prob, + routed_scaling_factor, + num_rows, + hidden_size, + &output); + break; + case paddle::DataType::FLOAT16: + MoeCombineKernel(ffn_out, + expert_scales_float, + permute_indices_per_token, + top_k_indices, + down_proj_bias, + norm_topk_prob, + routed_scaling_factor, + num_rows, + hidden_size, + &output); + break; + default: + PD_THROW("Unsupported data type for MoeDispatchKernel"); + } + return {output}; } - -template -__global__ void permute_x_kernel(const T *src_x, - const int64_t *topk_idx, - const float *topk_weights, - const int *token_nums_per_expert, - const float *up_gate_proj_in_scale, - const int moe_topk, - const int num_rows, - const int token_nums_this_rank, - const int hidden_size, - OutT *permute_x, // [token_nums_this_rank, hidden_size] - int *permute_indices_per_token, // [moe_topk, num_rows] - float *dst_weights, // [token_nums_this_rank] - int *dst_indices, - int *cumsum_idx_gpu, - int64_t *token_nums_per_expert_cumsum, - int64_t *expert_idx_per_token, // [num_rows, moe_topk] - float max_bound = 127.0, - float min_bound = -127.0) { - const int src_token_idx = blockIdx.x; - const int tid = threadIdx.x; - constexpr int vec_size = sizeof(int4) / sizeof(T); - __shared__ int write_idx; // cumsum start idx - __shared__ int token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK]; - AlignedVector src_vec; - AlignedVector res_vec; - if (tid == 0) { - int sum_now = 0; - for (int i = 0; i < NUM_EXPERTS_PER_RANK; i++) { - sum_now += token_nums_per_expert[i]; - token_nums_per_expert_cum[i] = sum_now; - if (blockIdx.x == 0) { - token_nums_per_expert_cumsum[i] = sum_now; - } +template +__global__ void permute_x_kernel( + const T* src_x, + const int64_t* topk_idx, + const float* topk_weights, + const int* token_nums_per_expert, + const float* up_gate_proj_in_scale, + const int moe_topk, + const int num_rows, + const int token_nums_this_rank, + const int hidden_size, + OutT* permute_x, // [token_nums_this_rank, hidden_size] + int* permute_indices_per_token, // [moe_topk, num_rows] + float* dst_weights, // [token_nums_this_rank] + int* dst_indices, + int* cumsum_idx_gpu, + int64_t* token_nums_per_expert_cumsum, + int64_t* expert_idx_per_token, // [num_rows, moe_topk] + float* dequant_scale, + float max_bound = 127.0, + float min_bound = -127.0) { + const int src_token_idx = blockIdx.x; + const int tid = threadIdx.x; + constexpr int vec_size = sizeof(int4) / sizeof(T); + __shared__ int write_idx; // cumsum start idx + __shared__ int token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK]; + extern __shared__ char smem_[]; + T* data_smem = reinterpret_cast(smem_); + AlignedVector src_vec; + AlignedVector res_vec; + if (tid == 0) { + int sum_now = 0; + for (int i = 0; i < NUM_EXPERTS_PER_RANK; i++) { + sum_now += token_nums_per_expert[i]; + token_nums_per_expert_cum[i] = sum_now; + if (blockIdx.x == 0) { + token_nums_per_expert_cumsum[i] = sum_now; } } - __syncthreads(); - const int hidden_size_int4 = hidden_size / vec_size; - for (int s_token_idx = src_token_idx; s_token_idx < num_rows; s_token_idx += gridDim.x) { - const int64_t *topk_idx_now = topk_idx + s_token_idx * moe_topk; + } + __syncthreads(); + const int hidden_size_int4 = hidden_size / vec_size; + for (int s_token_idx = src_token_idx; s_token_idx < num_rows; + s_token_idx += gridDim.x) { + const int64_t* topk_idx_now = topk_idx + s_token_idx * moe_topk; #pragma unroll - for (int expert_idx = 0; expert_idx < moe_topk; expert_idx++) { - int expert_now = static_cast(topk_idx_now[expert_idx]); - if (expert_now == -1) continue; - const int dst_chunk_start_idx = expert_now == 0 ? 0 : token_nums_per_expert_cum[expert_now - 1]; - if (tid == 0) { - const int offset_now = atomicAdd(cumsum_idx_gpu + expert_now, 1); - write_idx = offset_now; + for (int expert_idx = 0; expert_idx < moe_topk; expert_idx++) { + int expert_now = static_cast(topk_idx_now[expert_idx]); + if (expert_now == -1) continue; + const int dst_chunk_start_idx = + expert_now == 0 ? 0 : token_nums_per_expert_cum[expert_now - 1]; + if (tid == 0) { + const int offset_now = atomicAdd(cumsum_idx_gpu + expert_now, 1); + write_idx = offset_now; + } + __syncthreads(); + const int token_offset_now = write_idx; + const int dst_token_idx = dst_chunk_start_idx + token_offset_now; + permute_indices_per_token[expert_now * num_rows + s_token_idx] = + dst_token_idx; + dst_weights[dst_token_idx] = + topk_weights[s_token_idx * moe_topk + expert_idx]; + dst_indices[s_token_idx * NUM_EXPERTS_PER_RANK + expert_now] = expert_now; + // cp x + if (dequant_scale) { // dynamic quant + float abs_max = 0.0f; + for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) { + Load(&src_x[s_token_idx * hidden_size + v_id * vec_size], + &src_vec); + Store(src_vec, &data_smem[v_id * vec_size]); +#pragma unroll + for (int i = 0; i < vec_size; i++) { + abs_max = fmaxf(abs_max, fabsf(static_cast(src_vec[i]))); } - __syncthreads(); - const int token_offset_now = write_idx; - const int dst_token_idx = dst_chunk_start_idx + token_offset_now; - permute_indices_per_token[expert_now * num_rows + s_token_idx] = dst_token_idx; - dst_weights[dst_token_idx] = topk_weights[s_token_idx * moe_topk + expert_idx]; - dst_indices[s_token_idx * NUM_EXPERTS_PER_RANK + expert_now] = expert_now; - // cp x - for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) { - Load(&src_x[s_token_idx * hidden_size + v_id * vec_size], &src_vec); - if (up_gate_proj_in_scale) { - for (int i = 0; i < vec_size; i++) { - float quant_value = max_bound * up_gate_proj_in_scale[expert_now] * static_cast(src_vec[i]); - if constexpr (std::is_same::value) { - // w4aint8 - if (RoundType == 0) { - res_vec[i] = static_cast(ClipFunc(rint(quant_value), min_bound, max_bound)); - } else { - res_vec[i] = static_cast(ClipFunc(round(quant_value), min_bound, max_bound)); - } + } + abs_max = phi::BlockAllReduce(abs_max); + float scale = 440.f / abs_max; // use 440 so we do not have to clip + dequant_scale[dst_token_idx] = abs_max; + for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) { + Load(&data_smem[v_id * vec_size], &src_vec); +#pragma unroll + for (int i = 0; i < vec_size; i++) { + float quant_value = scale * static_cast(src_vec[i]); + // dynamic quant only supporet for w4afp8 + res_vec[i] = static_cast(quant_value); + } + Store( + res_vec, + &permute_x[dst_token_idx * hidden_size + v_id * vec_size]); + } + } else { // static quant or not quant + for (int v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) { + Load(&src_x[s_token_idx * hidden_size + v_id * vec_size], + &src_vec); + if (up_gate_proj_in_scale) { +#pragma unroll + for (int i = 0; i < vec_size; i++) { + float quant_value = max_bound * + up_gate_proj_in_scale[expert_now] * + static_cast(src_vec[i]); + if constexpr (std::is_same::value) { + // w4aint8 + if (RoundType == 0) { + res_vec[i] = static_cast( + ClipFunc(rint(quant_value), min_bound, max_bound)); } else { - // w4afp8 - float value = ClipFunc(quant_value, min_bound, max_bound); - res_vec[i] = static_cast(value); + res_vec[i] = static_cast(ClipFunc( + round(quant_value), min_bound, max_bound)); } - } - } else { - for (int i = 0; i < vec_size; i++) { - res_vec[i] = static_cast(src_vec[i]); + } else { + // w4afp8 + float value = + ClipFunc(quant_value, min_bound, max_bound); + res_vec[i] = static_cast(value); } } - Store(res_vec, &permute_x[dst_token_idx * hidden_size + v_id * vec_size]); + } else { +#pragma unroll + for (int i = 0; i < vec_size; i++) { + res_vec[i] = static_cast(src_vec[i]); + } } - expert_idx_per_token[dst_token_idx] = expert_now; + Store( + res_vec, + &permute_x[dst_token_idx * hidden_size + v_id * vec_size]); } + } + expert_idx_per_token[dst_token_idx] = expert_now; } + } } template -void EPMoeDispatchKernel(const paddle::Tensor& input, - const paddle::Tensor& topk_ids, - const paddle::Tensor& topk_weights, - const paddle::Tensor& token_nums_per_expert, - const paddle::optional& up_gate_proj_in_scale, - const std::string& moe_quant_type, - const int moe_topk, - const int num_rows, - const int token_nums_this_rank, - const int hidden_size, - const int num_experts_per_rank, - paddle::Tensor* permute_input, - paddle::Tensor* permute_indices_per_token, - paddle::Tensor* dst_weights, - paddle::Tensor* dst_indices, - paddle::Tensor* cumsum_idx_gpu, - paddle::Tensor* token_nums_per_expert_cumsum, - paddle::Tensor* expert_idx_per_token) { +void EPMoeDispatchKernel( + const paddle::Tensor& input, + const paddle::Tensor& topk_ids, + const paddle::Tensor& topk_weights, + const paddle::Tensor& token_nums_per_expert, + const paddle::optional& up_gate_proj_in_scale, + const std::string& moe_quant_type, + const int moe_topk, + const int num_rows, + const int token_nums_this_rank, + const int hidden_size, + const int num_experts_per_rank, + paddle::Tensor* permute_input, + paddle::Tensor* permute_indices_per_token, + paddle::Tensor* dst_weights, + paddle::Tensor* dst_indices, + paddle::Tensor* cumsum_idx_gpu, + paddle::Tensor* token_nums_per_expert_cumsum, + paddle::Tensor* expert_idx_per_token, + paddle::Tensor* dequant_scale) { using namespace phi; typedef PDTraits traits_; @@ -453,75 +515,93 @@ void EPMoeDispatchKernel(const paddle::Tensor& input, auto place = input.place(); const int gridx = min(132 * 8, num_rows); if (moe_quant_type == "w4a8") { - DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - );) + DISPATCH_NUM_EXPERTS_PER_RANK( + num_experts_per_rank, + NUM_EXPERTS_PER_RANK, + permute_x_kernel + <<>>( + input.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() + : nullptr, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + permute_input->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + expert_idx_per_token->data(), + nullptr, // dequant_scale + 127.0, + -127.0);) } else if (moe_quant_type == "w4afp8") { - DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 448.0f, - -448.0f - );) + const int smem_size = + up_gate_proj_in_scale ? 0 : hidden_size * sizeof(data_t); + DISPATCH_NUM_EXPERTS_PER_RANK( + num_experts_per_rank, + NUM_EXPERTS_PER_RANK, + permute_x_kernel + <<>>( + input.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() + : nullptr, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + permute_input->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + expert_idx_per_token->data(), + up_gate_proj_in_scale + ? nullptr + : dequant_scale + ->data(), // up_gate_proj_in_scale is used for + // static quant, while dequant_scale is + // used for dynamic quant + 448.0f, + -448.0f);) } else { - DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - );) + DISPATCH_NUM_EXPERTS_PER_RANK( + num_experts_per_rank, + NUM_EXPERTS_PER_RANK, + permute_x_kernel + <<>>( + input.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() + : nullptr, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + permute_input->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + expert_idx_per_token->data(), + nullptr, // dequant scale + 127.0, + -127.0);) } } - std::vector EPMoeExpertDispatch( const paddle::Tensor& input, const paddle::Tensor& topk_ids, @@ -546,64 +626,86 @@ std::vector EPMoeExpertDispatch( const int num_experts_per_rank = token_nums_per_expert.size(); auto permute_input = GetEmptyTensor( - {token_nums_this_rank, hidden_size}, - moe_quant_type == "w4a8" ? paddle::DataType::INT8 : moe_quant_type == "w4afp8" ? paddle::DataType::FLOAT8_E4M3FN : input_type, - place); - auto num_experts_per_rank_tensor = GetEmptyTensor( - {num_experts_per_rank}, - paddle::DataType::INT32, - place); - auto expert_idx_per_token = GetEmptyTensor( - {token_nums_this_rank}, paddle::DataType::INT64, place); - cudaMemcpyAsync(num_experts_per_rank_tensor.data(), token_nums_per_expert.data(), num_experts_per_rank * sizeof(int), cudaMemcpyHostToDevice, input.stream()); - // cudaMemcpy(num_experts_per_rank_tensor.data(), token_nums_per_expert.data(), num_experts_per_rank * sizeof(int), cudaMemcpyHostToDevice); - auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place); - auto dst_weights = GetEmptyTensor({token_nums_this_rank}, paddle::DataType::FLOAT32, place); - auto dst_indices = GetEmptyTensor({num_rows, num_experts_per_rank}, paddle::DataType::INT32, place); - auto permute_indices_per_token = paddle::full({num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place); - auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place); + {token_nums_this_rank, hidden_size}, + moe_quant_type == "w4a8" ? paddle::DataType::INT8 + : moe_quant_type == "w4afp8" ? paddle::DataType::FLOAT8_E4M3FN + : input_type, + place); + auto num_experts_per_rank_tensor = + GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT32, place); + auto expert_idx_per_token = + GetEmptyTensor({token_nums_this_rank}, paddle::DataType::INT64, place); + cudaMemcpyAsync(num_experts_per_rank_tensor.data(), + token_nums_per_expert.data(), + num_experts_per_rank * sizeof(int), + cudaMemcpyHostToDevice, + input.stream()); + // cudaMemcpy(num_experts_per_rank_tensor.data(), + // token_nums_per_expert.data(), num_experts_per_rank * sizeof(int), + // cudaMemcpyHostToDevice); + auto token_nums_per_expert_cumsum = + GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place); + auto dst_weights = + GetEmptyTensor({token_nums_this_rank}, paddle::DataType::FLOAT32, place); + auto dst_indices = GetEmptyTensor( + {num_rows, num_experts_per_rank}, paddle::DataType::INT32, place); + auto permute_indices_per_token = paddle::full( + {num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place); + auto cumsum_idx_gpu = + paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place); + int dequant_scale_size = 1; + if (moe_quant_type == "w4afp8" && !up_gate_proj_in_scale) { + dequant_scale_size = moe_topk * num_rows; + } + + auto dequant_scale = + GetEmptyTensor({dequant_scale_size}, paddle::DataType::FLOAT32, place); switch (input_type) { case paddle::DataType::BFLOAT16: - EPMoeDispatchKernel(input, - topk_ids, - topk_weights, - num_experts_per_rank_tensor, - up_gate_proj_in_scale, - moe_quant_type, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - num_experts_per_rank, - &permute_input, - &permute_indices_per_token, - &dst_weights, - &dst_indices, - &cumsum_idx_gpu, - &token_nums_per_expert_cumsum, - &expert_idx_per_token); + EPMoeDispatchKernel( + input, + topk_ids, + topk_weights, + num_experts_per_rank_tensor, + up_gate_proj_in_scale, + moe_quant_type, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + num_experts_per_rank, + &permute_input, + &permute_indices_per_token, + &dst_weights, + &dst_indices, + &cumsum_idx_gpu, + &token_nums_per_expert_cumsum, + &expert_idx_per_token, + &dequant_scale); break; case paddle::DataType::FLOAT16: - EPMoeDispatchKernel(input, - topk_ids, - topk_weights, - num_experts_per_rank_tensor, - up_gate_proj_in_scale, - moe_quant_type, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - num_experts_per_rank, - &permute_input, - &permute_indices_per_token, - &dst_weights, - &dst_indices, - &cumsum_idx_gpu, - &token_nums_per_expert_cumsum, - &expert_idx_per_token); + EPMoeDispatchKernel( + input, + topk_ids, + topk_weights, + num_experts_per_rank_tensor, + up_gate_proj_in_scale, + moe_quant_type, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + num_experts_per_rank, + &permute_input, + &permute_indices_per_token, + &dst_weights, + &dst_indices, + &cumsum_idx_gpu, + &token_nums_per_expert_cumsum, + &expert_idx_per_token, + &dequant_scale); break; default: PD_THROW("Unsupported data type for EPMoeDispatchKernel"); @@ -614,10 +716,10 @@ std::vector EPMoeExpertDispatch( dst_weights, dst_indices, cumsum_idx_gpu, - expert_idx_per_token}; + expert_idx_per_token, + dequant_scale}; } - std::vector> EPMoeExpertDispatchInferShape( const std::vector& input_shape, const std::vector& topk_ids_shape, @@ -632,7 +734,7 @@ std::vector> EPMoeExpertDispatchInferShape( } else { token_rows = input_shape[0]; } - const int expert_num = token_nums_per_expert.size(); // 本地专家个数 + const int expert_num = token_nums_per_expert.size(); // 本地专家个数 const int num_rows = token_rows; const int hidden_size = input_shape[input_shape.size() - 1]; @@ -642,7 +744,8 @@ std::vector> EPMoeExpertDispatchInferShape( {token_nums_this_rank}, {num_rows, expert_num}, {expert_num}, - {token_nums_this_rank}}; // dst_idx per expert + {token_nums_this_rank}, // dst_idx per expert + {token_nums_this_rank}}; } std::vector EPMoeExpertDispatchInferDtype( @@ -658,12 +761,14 @@ std::vector EPMoeExpertDispatchInferDtype( paddle::DataType::FLOAT32, paddle::DataType::INT32, paddle::DataType::INT32, - paddle::DataType::INT64}; + paddle::DataType::INT64, + paddle::DataType::FLOAT32}; } - PD_BUILD_STATIC_OP(ep_moe_expert_dispatch) - .Inputs({"input", "topk_ids", "topk_weights", + .Inputs({"input", + "topk_ids", + "topk_weights", paddle::Optional("up_gate_proj_in_scale")}) .Outputs({"permute_input", "permute_indices_per_token", @@ -671,105 +776,120 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch) "dst_weights", "dst_indices", "cumsum_idx_gpu", - "expert_idx_per_token"}) - .Attrs({ - "token_nums_per_expert: std::vector", - "token_nums_this_rank: int", - "moe_quant_type: std::string" - }) + "expert_idx_per_token", + "dequant_scale"}) + .Attrs({"token_nums_per_expert: std::vector", + "token_nums_this_rank: int", + "moe_quant_type: std::string"}) .SetKernelFn(PD_KERNEL(EPMoeExpertDispatch)) .SetInferShapeFn(PD_INFER_SHAPE(EPMoeExpertDispatchInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(EPMoeExpertDispatchInferDtype)); - template -__global__ void permute_x_fp8_kernel(const T *src_x, - const float *scale, - const int64_t *topk_idx, - const float *topk_weights, - const int *token_nums_per_expert, - const int *token_nums_per_expert_padded, - const int moe_topk, - const int num_rows, - const int token_nums_this_rank, - const int token_nums_this_rank_padded, - const int64_t hidden_size, - T *permute_x, // [token_nums_this_rank, hidden_size] - float *permute_scale, - int *permute_indices_per_token, // [moe_topk, num_rows] - float *dst_weights, // [token_nums_this_rank] - int *dst_indices, - int *cumsum_idx_gpu, - int64_t *token_nums_per_expert_cumsum, - int64_t *token_nums_per_expert_padded_cumsum, - int *m_indices) { // [num_rows, moe_topk] - const int64_t src_token_idx = blockIdx.x; - const int tid = threadIdx.x; - constexpr int vec_size = sizeof(int4) / sizeof(T); - constexpr int scale_vec_size = sizeof(int4) / sizeof(float); - __shared__ int write_idx; // cumsum start idx - __shared__ int token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK]; - if (tid == 0) { - int sum_now = 0; - int sum_now_padded = 0; - for (int i = 0; i < NUM_EXPERTS_PER_RANK; i++) { - sum_now += token_nums_per_expert[i]; - sum_now_padded += token_nums_per_expert_padded[i]; - token_nums_per_expert_cum[i] = sum_now_padded; - if (blockIdx.x == 0) { - token_nums_per_expert_cumsum[i] = sum_now; - token_nums_per_expert_padded_cumsum[i] = sum_now_padded; - } +__global__ void permute_x_fp8_kernel( + const T* src_x, + const float* scale, + const int64_t* topk_idx, + const float* topk_weights, + const int* token_nums_per_expert, + const int* token_nums_per_expert_padded, + const int moe_topk, + const int num_rows, + const int token_nums_this_rank, + const int token_nums_this_rank_padded, + const int64_t hidden_size, + T* permute_x, // [token_nums_this_rank, hidden_size] + float* permute_scale, + int* permute_indices_per_token, // [moe_topk, num_rows] + float* dst_weights, // [token_nums_this_rank] + int* dst_indices, + int* cumsum_idx_gpu, + int64_t* token_nums_per_expert_cumsum, + int64_t* token_nums_per_expert_padded_cumsum, + int* m_indices) { // [num_rows, moe_topk] + const int64_t src_token_idx = blockIdx.x; + const int tid = threadIdx.x; + constexpr int vec_size = sizeof(int4) / sizeof(T); + constexpr int scale_vec_size = sizeof(int4) / sizeof(float); + __shared__ int write_idx; // cumsum start idx + __shared__ int token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK]; + if (tid == 0) { + int sum_now = 0; + int sum_now_padded = 0; + for (int i = 0; i < NUM_EXPERTS_PER_RANK; i++) { + sum_now += token_nums_per_expert[i]; + sum_now_padded += token_nums_per_expert_padded[i]; + token_nums_per_expert_cum[i] = sum_now_padded; + if (blockIdx.x == 0) { + token_nums_per_expert_cumsum[i] = sum_now; + token_nums_per_expert_padded_cumsum[i] = sum_now_padded; } } - __syncthreads(); - const int hidden_size_int4 = hidden_size / vec_size; - const int hidden_size_scale = hidden_size / 128; - const int hidden_size_scale_int4 = hidden_size_scale / scale_vec_size; - const int token_nums_feed_to_ffn = token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK-1]; - // prmt - for (int64_t s_token_idx = src_token_idx; s_token_idx < token_nums_feed_to_ffn; s_token_idx += gridDim.x) { - - // the m_indices[s_token_idx] must be a value `i` in [0, NUM_EXPERTS_PER_RANK) - // here we parallel wo find the `i` we want. - for (int i = threadIdx.x; i < NUM_EXPERTS_PER_RANK; i+= blockDim.x) { - const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1]; - const int end_idx = token_nums_per_expert_cum[i]; - if (s_token_idx >= start_idx && s_token_idx < end_idx) { - if ((s_token_idx - start_idx) < token_nums_per_expert[i]) m_indices[s_token_idx] = i; - break; - } + } + __syncthreads(); + const int hidden_size_int4 = hidden_size / vec_size; + const int hidden_size_scale = hidden_size / 128; + const int hidden_size_scale_int4 = hidden_size_scale / scale_vec_size; + const int token_nums_feed_to_ffn = + token_nums_per_expert_cum[NUM_EXPERTS_PER_RANK - 1]; + // prmt + for (int64_t s_token_idx = src_token_idx; + s_token_idx < token_nums_feed_to_ffn; + s_token_idx += gridDim.x) { + // the m_indices[s_token_idx] must be a value `i` in [0, + // NUM_EXPERTS_PER_RANK) here we parallel wo find the `i` we want. + for (int i = threadIdx.x; i < NUM_EXPERTS_PER_RANK; i += blockDim.x) { + const int start_idx = i == 0 ? 0 : token_nums_per_expert_cum[i - 1]; + const int end_idx = token_nums_per_expert_cum[i]; + if (s_token_idx >= start_idx && s_token_idx < end_idx) { + if ((s_token_idx - start_idx) < token_nums_per_expert[i]) + m_indices[s_token_idx] = i; + break; } + } - if (s_token_idx < num_rows) { - const int64_t *topk_idx_now = topk_idx + s_token_idx * moe_topk; + if (s_token_idx < num_rows) { + const int64_t* topk_idx_now = topk_idx + s_token_idx * moe_topk; #pragma unroll - for (int expert_idx = 0; expert_idx < moe_topk; expert_idx++) { - int expert_now = static_cast(topk_idx_now[expert_idx]); - if (expert_now == -1) continue; - const int dst_chunk_start_idx = expert_now == 0 ? 0 : token_nums_per_expert_cum[expert_now - 1]; - if (tid == 0) { - const int offset_now = atomicAdd(cumsum_idx_gpu + expert_now, 1); - write_idx = offset_now; - } - __syncthreads(); - const int token_offset_now = write_idx; - const int64_t dst_token_idx = dst_chunk_start_idx + token_offset_now; - permute_indices_per_token[expert_now * num_rows + s_token_idx] = dst_token_idx; - dst_weights[dst_token_idx] = topk_weights[s_token_idx * moe_topk + expert_idx]; - // m_indices[dst_token_idx] = expert_now; // not need? - dst_indices[s_token_idx * NUM_EXPERTS_PER_RANK + expert_now] = expert_now; - // cp x - for (int64_t v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) { - *(reinterpret_cast(permute_x + dst_token_idx * hidden_size) + v_id) = *(reinterpret_cast(src_x + s_token_idx * hidden_size) + v_id); - } - // cp scale - for (int v_id = tid; v_id < hidden_size_scale_int4; v_id += blockDim.x) { - *(reinterpret_cast(permute_scale + dst_token_idx * hidden_size_scale) + v_id) = *(reinterpret_cast(scale + s_token_idx * hidden_size_scale) + v_id); - } + for (int expert_idx = 0; expert_idx < moe_topk; expert_idx++) { + int expert_now = static_cast(topk_idx_now[expert_idx]); + if (expert_now == -1) continue; + const int dst_chunk_start_idx = + expert_now == 0 ? 0 : token_nums_per_expert_cum[expert_now - 1]; + if (tid == 0) { + const int offset_now = atomicAdd(cumsum_idx_gpu + expert_now, 1); + write_idx = offset_now; + } + __syncthreads(); + const int token_offset_now = write_idx; + const int64_t dst_token_idx = dst_chunk_start_idx + token_offset_now; + permute_indices_per_token[expert_now * num_rows + s_token_idx] = + dst_token_idx; + dst_weights[dst_token_idx] = + topk_weights[s_token_idx * moe_topk + expert_idx]; + // m_indices[dst_token_idx] = expert_now; // not need? + dst_indices[s_token_idx * NUM_EXPERTS_PER_RANK + expert_now] = + expert_now; + // cp x + for (int64_t v_id = tid; v_id < hidden_size_int4; v_id += blockDim.x) { + *(reinterpret_cast(permute_x + dst_token_idx * hidden_size) + + v_id) = *(reinterpret_cast(src_x + + s_token_idx * hidden_size) + + v_id); + } + // cp scale + for (int v_id = tid; v_id < hidden_size_scale_int4; + v_id += blockDim.x) { + *(reinterpret_cast(permute_scale + + dst_token_idx * hidden_size_scale) + + v_id) = + *(reinterpret_cast(scale + + s_token_idx * hidden_size_scale) + + v_id); } } } + } } void EPMoeDispatchFP8Kernel(const paddle::Tensor& input, @@ -797,33 +917,33 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input, auto place = input.place(); // const int gridx = min(132 * 8, num_rows); const int gridx = 132 * 8; - DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, - permute_x_fp8_kernel<<>>( - input.data(), - scale.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - token_nums_per_expert_padded.data(), - moe_topk, - num_rows, - token_nums_this_rank, - token_nums_this_rank_padded, - hidden_size, - permute_input->data(), - permute_scale->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - token_nums_per_expert_padded_cumsum->data(), - m_indices->data() - );) - + DISPATCH_NUM_EXPERTS_PER_RANK( + num_experts_per_rank, + NUM_EXPERTS_PER_RANK, + permute_x_fp8_kernel + <<>>( + input.data(), + scale.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + token_nums_per_expert_padded.data(), + moe_topk, + num_rows, + token_nums_this_rank, + token_nums_this_rank_padded, + hidden_size, + permute_input->data(), + permute_scale->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + token_nums_per_expert_padded_cumsum->data(), + m_indices->data());) } - std::vector EPMoeExpertDispatchFP8( const paddle::Tensor& input, const paddle::Tensor& scale, @@ -848,46 +968,53 @@ std::vector EPMoeExpertDispatchFP8( const int hidden_size = input.dims()[input_dims.size() - 1]; const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0]; - int32_t token_nums_feed_to_ffn = use_in_ep ? token_nums_this_rank_padded : token_rows * moe_topk + num_experts_per_rank * (128-1); + int32_t token_nums_feed_to_ffn = + use_in_ep ? token_nums_this_rank_padded + : token_rows * moe_topk + num_experts_per_rank * (128 - 1); - auto permute_input = GetEmptyTensor( - {token_nums_feed_to_ffn, hidden_size}, - input_type, - place); - auto permute_scale = GetEmptyTensor( - {token_nums_feed_to_ffn, hidden_size / 128}, - paddle::DataType::FLOAT32, - place); + auto permute_input = + GetEmptyTensor({token_nums_feed_to_ffn, hidden_size}, input_type, place); + auto permute_scale = + GetEmptyTensor({token_nums_feed_to_ffn, hidden_size / 128}, + paddle::DataType::FLOAT32, + place); - auto m_indices = paddle::full({token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place); - auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place); - auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place); - auto dst_weights = GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::FLOAT32, place); - auto dst_indices = GetEmptyTensor({num_rows, num_experts_per_rank}, paddle::DataType::INT32, place); - auto permute_indices_per_token = paddle::full({num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place); - auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place); + auto m_indices = paddle::full( + {token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place); + auto token_nums_per_expert_cumsum = + GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place); + auto token_nums_per_expert_padded_cumsum = + GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place); + auto dst_weights = GetEmptyTensor( + {token_nums_feed_to_ffn}, paddle::DataType::FLOAT32, place); + auto dst_indices = GetEmptyTensor( + {num_rows, num_experts_per_rank}, paddle::DataType::INT32, place); + auto permute_indices_per_token = paddle::full( + {num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place); + auto cumsum_idx_gpu = + paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place); EPMoeDispatchFP8Kernel(input, - scale, - topk_ids, - topk_weights, - num_experts_per_rank_tensor, - num_experts_per_rank_padded_tensor, - moe_topk, - num_rows, - -1, - -1, - hidden_size, - num_experts_per_rank, - &permute_input, - &permute_scale, - &permute_indices_per_token, - &dst_weights, - &dst_indices, - &cumsum_idx_gpu, - &token_nums_per_expert_cumsum, - &token_nums_per_expert_padded_cumsum, - &m_indices); + scale, + topk_ids, + topk_weights, + num_experts_per_rank_tensor, + num_experts_per_rank_padded_tensor, + moe_topk, + num_rows, + -1, + -1, + hidden_size, + num_experts_per_rank, + &permute_input, + &permute_scale, + &permute_indices_per_token, + &dst_weights, + &dst_indices, + &cumsum_idx_gpu, + &token_nums_per_expert_cumsum, + &token_nums_per_expert_padded_cumsum, + &m_indices); return {permute_input, permute_scale, permute_indices_per_token, @@ -900,7 +1027,12 @@ std::vector EPMoeExpertDispatchFP8( } PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8) - .Inputs({"input", "scale", "topk_ids", "topk_weights", "num_experts_per_rank_tensor", "num_experts_per_rank_padded_tensor"}) + .Inputs({"input", + "scale", + "topk_ids", + "topk_weights", + "num_experts_per_rank_tensor", + "num_experts_per_rank_padded_tensor"}) .Outputs({"permute_input", "permute_scale", "permute_indices_per_token", diff --git a/custom_ops/gpu_ops/moe/fused_moe_helper.h b/custom_ops/gpu_ops/moe/fused_moe_helper.h index 703a7c11f..817e9d92b 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_helper.h +++ b/custom_ops/gpu_ops/moe/fused_moe_helper.h @@ -25,7 +25,8 @@ template __global__ void moe_token_type_ids_kernel(T *gating_output, const int *moe_token_type_ids_out, const int num_rows, - const int num_experts, const int k) { + const int num_experts, + const int k) { const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x; if (moe_token_index >= num_rows) { @@ -44,7 +45,8 @@ template void moe_token_type_ids_kernelLauncher(T *gating_output, const int *moe_token_type_ids_out, const int num_rows, - const int num_experts, const int k, + const int num_experts, + const int k, cudaStream_t stream) { const int blocks = num_rows * k / 512 + 1; const int threads = 512; @@ -52,26 +54,35 @@ void moe_token_type_ids_kernelLauncher(T *gating_output, gating_output, moe_token_type_ids_out, num_rows, num_experts, k); } -template class MoeHelper { -public: - using Fp16Traits = cutlass::WintQuantTraits; - using Int8Traits = cutlass::WintQuantTraits; - using Int4Traits = cutlass::WintQuantTraits; +template +class MoeHelper { + public: + using Fp16Traits = + cutlass::WintQuantTraits; + using Int8Traits = + cutlass::WintQuantTraits; + using Int4Traits = + cutlass::WintQuantTraits; - MoeHelper( - const std::string gemm_method, - MoeGemmRunner *fp16_moe_gemm_runner, - MoeGemmRunner *int8_moe_gemm_runner, - MoeGemmRunner *int4_moe_gemm_runner, - int layernum = 0) - : gemm_method_(gemm_method), fp16_moe_gemm_runner_(fp16_moe_gemm_runner), + MoeHelper(const std::string gemm_method, + MoeGemmRunner *fp16_moe_gemm_runner, + MoeGemmRunner *int8_moe_gemm_runner, + MoeGemmRunner *int4_moe_gemm_runner, + int layernum = 0) + : gemm_method_(gemm_method), + fp16_moe_gemm_runner_(fp16_moe_gemm_runner), int8_moe_gemm_runner_(int8_moe_gemm_runner), - int4_moe_gemm_runner_(int4_moe_gemm_runner), layernum_(layernum) {} + int4_moe_gemm_runner_(int4_moe_gemm_runner), + layernum_(layernum) {} // -------- getWorkspaceSize -------- // template - size_t getWorkspaceSize(const int64_t num_rows, const int64_t hidden_size, - const int64_t inter_size, const int64_t num_experts, + size_t getWorkspaceSize(const int64_t num_rows, + const int64_t hidden_size, + const int64_t inter_size, + const int64_t num_experts, const int64_t k) { const size_t buf_size = AlignTo16(k * num_rows * hidden_size); const size_t interbuf_size = AlignTo16(k * num_rows * inter_size); @@ -82,10 +93,10 @@ public: // FfnLayer forward. size_t total_ws_bytes = 5 * num_moe_inputs * - sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ - total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data + sizeof(int); // source_rows_, permuted_rows_, permuted_experts_ + total_ws_bytes += buf_size * sizeof(KeyT); // permuted_data total_ws_bytes += - padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ + padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT); const size_t sorter_ws_size_bytes = @@ -100,8 +111,8 @@ public: } total_ws_bytes += - bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub - // sorting workspace + bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub + // sorting workspace int64_t num_softmax_outs = 0; const bool is_pow_2 = @@ -115,20 +126,27 @@ public: return total_ws_bytes; } - void - ComputeFFN(const paddle::Tensor *input, const paddle::Tensor *gate_weight, - const paddle::Tensor *up_gate_proj_weight, - const paddle::Tensor *up_gate_proj_scale, const paddle::Tensor *up_gate_proj_bias, - const paddle::Tensor *down_proj_weight, - const paddle::Tensor *down_proj_scale, const paddle::Tensor *down_proj_bias, - const paddle::Tensor *moe_token_type_ids, const int moe_topk, - const bool group_moe, const bool norm_topk_prob, - const float routed_scaling_factor, const std::string moe_type, - paddle::Tensor *output) { + void ComputeFFN(const paddle::Tensor *input, + const paddle::Tensor *gate_weight, + const paddle::Tensor *up_gate_proj_weight, + const paddle::Tensor *up_gate_proj_scale, + const paddle::Tensor *up_gate_proj_bias, + const paddle::Tensor *down_proj_weight, + const paddle::Tensor *down_proj_scale, + const paddle::Tensor *down_proj_bias, + const paddle::Tensor *moe_token_type_ids, + const int moe_topk, + const bool group_moe, + const bool norm_topk_prob, + const float routed_scaling_factor, + const std::string moe_type, + paddle::Tensor *output) { auto *input_activations = input->data(); auto *gating_weights = gate_weight->data(); - const T *fc1_expert_biases = up_gate_proj_bias ? up_gate_proj_bias->data() : nullptr; - const T *fc2_expert_biases = down_proj_bias ? down_proj_bias->data() : nullptr; + const T *fc1_expert_biases = + up_gate_proj_bias ? up_gate_proj_bias->data() : nullptr; + const T *fc2_expert_biases = + down_proj_bias ? down_proj_bias->data() : nullptr; auto *output_ = output->data(); auto stream = input->stream(); @@ -148,7 +166,8 @@ public: const int64_t hidden_size = up_gate_proj_dims[1]; int64_t inter_dim = 0; if (moe_type == "qkv") { - inter_dim = up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4]; + inter_dim = + up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4]; } else { inter_dim = up_gate_proj_dims[2]; } @@ -232,44 +251,79 @@ public: if (moe_token_type_ids) { auto *moe_token_type_ids_out = moe_token_type_ids->data(); moe_token_type_ids_kernelLauncher(gating_output, - moe_token_type_ids_out, num_rows, - num_experts, k, stream); + moe_token_type_ids_out, + num_rows, + num_experts, + k, + stream); } - topk_gating_softmax_kernelLauncher( - gating_output, nullptr, expert_scales_float, softmax_out_, - expert_for_source_row, source_rows_, softmax_max_prob, num_rows, - num_experts, k, group_moe, stream); + topk_gating_softmax_kernelLauncher(gating_output, + nullptr, + expert_scales_float, + softmax_out_, + expert_for_source_row, + source_rows_, + softmax_max_prob, + num_rows, + num_experts, + k, + group_moe, + stream); const int64_t sorter_ws_size_bytes = AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows))); - sorter_.run(fc1_result_, sorter_ws_size_bytes, expert_for_source_row, - permuted_experts_, source_rows_, permuted_rows_, k * num_rows, - false, stream); + sorter_.run(fc1_result_, + sorter_ws_size_bytes, + expert_for_source_row, + permuted_experts_, + source_rows_, + permuted_rows_, + k * num_rows, + false, + stream); initialize_moe_routing_kernelLauncher( - input_activations, permuted_data_, permuted_rows_, nullptr, nullptr, - expanded_source_row_to_expanded_dest_row, num_rows, num_rows, - hidden_size, k, stream); + input_activations, + permuted_data_, + permuted_rows_, + nullptr, + nullptr, + expanded_source_row_to_expanded_dest_row, + nullptr, + num_rows, + num_rows, + hidden_size, + k, + stream); const int64_t expanded_active_expert_rows = k * num_rows; compute_total_rows_before_expert(permuted_experts_, - expanded_active_expert_rows, num_experts, - total_rows_before_expert_, stream); + expanded_active_expert_rows, + num_experts, + total_rows_before_expert_, + stream); if (gemm_method_ == "weight_only_int8") { typename Int8Traits::Arguments up_gate_proj_quant_args; int8_moe_gemm_runner_->moe_gemm_bias_act( reinterpret_cast(permuted_data_), - reinterpret_cast(up_gate_proj_weight->data()), + reinterpret_cast( + up_gate_proj_weight->data()), reinterpret_cast(up_gate_proj_scale->data()), reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, inter_size, hidden_size, num_experts, - up_gate_proj_quant_args, "none", stream); + reinterpret_cast(fc1_out), + total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, + inter_size, + hidden_size, + num_experts, + up_gate_proj_quant_args, + "none", + stream); } else if (gemm_method_ == "weight_only_int4") { typename Int4Traits::Arguments up_gate_proj_quant_args; int4_moe_gemm_runner_->moe_gemm_bias_act( @@ -278,20 +332,33 @@ public: up_gate_proj_weight->data()), reinterpret_cast(up_gate_proj_scale->data()), reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, inter_size, hidden_size, num_experts, - up_gate_proj_quant_args, "none", stream); + reinterpret_cast(fc1_out), + total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, + inter_size, + hidden_size, + num_experts, + up_gate_proj_quant_args, + "none", + stream); } else { typename Fp16Traits::Arguments up_gate_proj_quant_args; fp16_moe_gemm_runner_->moe_gemm_bias_act( reinterpret_cast(permuted_data_), - reinterpret_cast(up_gate_proj_weight->data()), nullptr, + reinterpret_cast(up_gate_proj_weight->data()), + nullptr, reinterpret_cast(fc1_expert_biases), - reinterpret_cast(fc1_out), total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, inter_size, hidden_size, num_experts, - up_gate_proj_quant_args, "none", stream); + reinterpret_cast(fc1_out), + total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, + inter_size, + hidden_size, + num_experts, + up_gate_proj_quant_args, + "none", + stream); } if (moe_type == "ffn") { @@ -309,10 +376,15 @@ public: reinterpret_cast(act_out), reinterpret_cast(down_proj_weight->data()), reinterpret_cast(down_proj_scale->data()), - reinterpret_cast(fc2_result), total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, hidden_size, inter_size / 2, - num_experts, down_proj_quant_args, stream); + reinterpret_cast(fc2_result), + total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, + hidden_size, + inter_size / 2, + num_experts, + down_proj_quant_args, + stream); } else if (gemm_method_ == "weight_only_int4") { typename Int4Traits::Arguments down_proj_quant_args; int4_moe_gemm_runner_->moe_gemm( @@ -320,40 +392,66 @@ public: reinterpret_cast( down_proj_weight->data()), reinterpret_cast(down_proj_scale->data()), - reinterpret_cast(fc2_result), total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, hidden_size, inter_size / 2, - num_experts, down_proj_quant_args, stream); + reinterpret_cast(fc2_result), + total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, + hidden_size, + inter_size / 2, + num_experts, + down_proj_quant_args, + stream); } else { typename Fp16Traits::Arguments down_proj_quant_args; fp16_moe_gemm_runner_->moe_gemm( reinterpret_cast(act_out), - reinterpret_cast(down_proj_weight->data()), nullptr, - reinterpret_cast(fc2_result), total_rows_before_expert_, - -1, // useless - expanded_active_expert_rows, hidden_size, inter_size / 2, - num_experts, down_proj_quant_args, stream); + reinterpret_cast(down_proj_weight->data()), + nullptr, + reinterpret_cast(fc2_result), + total_rows_before_expert_, + -1, // useless + expanded_active_expert_rows, + hidden_size, + inter_size / 2, + num_experts, + down_proj_quant_args, + stream); } finalize_moe_routing_kernelLauncher( - fc2_result, output_, fc2_expert_biases, + fc2_result, + output_, + fc2_expert_biases, reinterpret_cast(expert_scales_float), - expanded_source_row_to_expanded_dest_row, expert_for_source_row, - num_rows, hidden_size, k, static_cast(1), norm_topk_prob, - routed_scaling_factor, stream); + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + num_rows, + hidden_size, + k, + static_cast(1), + norm_topk_prob, + routed_scaling_factor, + stream); } else { finalize_moe_routing_kernelLauncher( // fc2_result, - fc1_out, output_, - fc1_expert_biases, // fc2_expert_biases, + fc1_out, + output_, + fc1_expert_biases, // fc2_expert_biases, reinterpret_cast(expert_scales_float), - expanded_source_row_to_expanded_dest_row, expert_for_source_row, - num_rows, inter_size, k, static_cast(0), norm_topk_prob, - routed_scaling_factor, stream); + expanded_source_row_to_expanded_dest_row, + expert_for_source_row, + num_rows, + inter_size, + k, + static_cast(0), + norm_topk_prob, + routed_scaling_factor, + stream); } } -private: + private: std::string gemm_method_; MoeGemmRunner *fp16_moe_gemm_runner_; MoeGemmRunner *int8_moe_gemm_runner_; @@ -362,4 +460,4 @@ private: CubKeyValueSorter sorter_; }; -} // namespace phi +} // namespace phi diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h index eeaecb716..def1aab93 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_op.h +++ b/custom_ops/gpu_ops/moe/fused_moe_op.h @@ -19,9 +19,9 @@ #include #include -#include "moe/fused_moe_imp_op.h" -#include "moe/fused_moe_helper.h" #include "cutlass/numeric_conversion.h" +#include "moe/fused_moe_helper.h" +#include "moe/fused_moe_imp_op.h" // Ignore CUTLASS warnings about type punning #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -37,8 +37,8 @@ namespace phi { struct GpuLaunchConfig { - dim3 block_per_grid; - dim3 thread_per_block; + dim3 block_per_grid; + dim3 thread_per_block; }; inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) { @@ -59,55 +59,59 @@ inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) { constexpr static int FINALIZE_THREADS_PER_BLOCK = 256; template -__host__ __device__ constexpr static U arrayConvert(T const& input) -{ - using Type = typename U::Element; - static_assert(T::kElements == U::kElements); - U u; +__host__ __device__ constexpr static U arrayConvert(T const& input) { + using Type = typename U::Element; + static_assert(T::kElements == U::kElements); + U u; #pragma unroll - for (int i = 0; i < U::kElements; i++) - { - u[i] = static_cast(input[i]); - } - return u; + for (int i = 0; i < U::kElements; i++) { + u[i] = static_cast(input[i]); + } + return u; } struct uint8 { - uint4 u; - uint4 v; + uint4 u; + uint4 v; }; -template struct BytesToType {}; +template +struct BytesToType {}; -template<> +template <> struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); + using Type = uint8; + static_assert(sizeof(Type) == 32); }; -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); +template <> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); }; -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); +template <> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); }; -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); +template <> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); }; -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); +template <> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); }; -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); +template <> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); }; template