From 442543cd6bd6de539d0849f13071a451ff7585d6 Mon Sep 17 00:00:00 2001 From: Sunny-bot1 <68891411+Sunny-bot1@users.noreply.github.com> Date: Tue, 16 Sep 2025 11:05:33 +0800 Subject: [PATCH] fix ep wint8 (#4102) --- .../gpu_ops/moe/ep_moe_expert_dispatch.cu | 192 ++++++------------ 1 file changed, 63 insertions(+), 129 deletions(-) diff --git a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu index 1c3a45e50..fe01400f0 100644 --- a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu +++ b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu @@ -448,137 +448,71 @@ void EPMoeDispatchKernel(const paddle::Tensor& input, auto place = input.place(); const int gridx = min(132 * 8, num_rows); if (moe_quant_type == "w4a8") { - if (num_experts_per_rank == 8) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - ); - } else if (num_experts_per_rank == 16) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - ); - } + DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, + permute_x_kernel<<>>( + input.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + permute_input->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + expert_idx_per_token->data(), + 127.0, + -127.0 + );) } else if (moe_quant_type == "w4afp8") { - if (num_experts_per_rank == 8) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 448.0f, - -448.0f - ); - } else if (num_experts_per_rank == 16) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 448.0f, - -448.0f - ); - } + DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, + permute_x_kernel<<>>( + input.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + permute_input->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + expert_idx_per_token->data(), + 448.0f, + -448.0f + );) } else { - if (num_experts_per_rank == 8) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - ); - } else if (num_experts_per_rank == 16) { - permute_x_kernel<<>>( - input.data(), - topk_ids.data(), - topk_weights.data(), - token_nums_per_expert.data(), - up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, - moe_topk, - num_rows, - token_nums_this_rank, - hidden_size, - permute_input->data(), - permute_indices_per_token->data(), - dst_weights->data(), - dst_indices->data(), - cumsum_idx_gpu->data(), - token_nums_per_expert_cumsum->data(), - expert_idx_per_token->data(), - 127.0, - -127.0 - ); - } + DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, + permute_x_kernel<<>>( + input.data(), + topk_ids.data(), + topk_weights.data(), + token_nums_per_expert.data(), + up_gate_proj_in_scale ? up_gate_proj_in_scale.get().data() : nullptr, + moe_topk, + num_rows, + token_nums_this_rank, + hidden_size, + permute_input->data(), + permute_indices_per_token->data(), + dst_weights->data(), + dst_indices->data(), + cumsum_idx_gpu->data(), + token_nums_per_expert_cumsum->data(), + expert_idx_per_token->data(), + 127.0, + -127.0 + );) } }