diff --git a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
index 60ae7d1fc..46cc60bef 100644
--- a/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
+++ b/custom_ops/gpu_ops/moe/ep_moe_prefill_func.cu
@@ -25,6 +25,46 @@
 #include "helper.h"
 #include <cooperative_groups.h>
+
+#define DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \
+  switch (num_experts_per_rank) { \
+    case 8: { \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 8; \
+      __VA_ARGS__ \
+      break; \
+    } \
+    case 9: { \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 9; \
+      __VA_ARGS__ \
+      break; \
+    } \
+    case 16: { \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 16; \
+      __VA_ARGS__ \
+      break; \
+    } \
+    case 64: { \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
+      __VA_ARGS__ \
+      break; \
+    } \
+    case 128: { \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 128; \
+      __VA_ARGS__ \
+      break; \
+    } \
+    case 160: { \
+      constexpr size_t NUM_EXPERTS_PER_RANK = 160; \
+      __VA_ARGS__ \
+      break; \
+    } \
+    default: { \
+      std::ostringstream err_msg; \
+      err_msg << "Unsupported num_experts_per_rank: " << num_experts_per_rank; \
+      throw std::invalid_argument(err_msg.str()); \
+    } \
+  }
+
 namespace cg = cooperative_groups;
 
 template <typename T>
@@ -743,8 +783,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
   auto place = input.place();
   // const int gridx = min(132 * 8, num_rows);
   const int gridx = 132 * 8;
-  if (num_experts_per_rank == 8) {
-    permute_x_fp8_kernel<8><<<gridx, 512, 0, stream>>>(
+  DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
+    permute_x_fp8_kernel<NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
         input.data(),
         scale.data(),
         topk_ids.data(),
@@ -765,102 +805,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
         token_nums_per_expert_cumsum->data(),
         token_nums_per_expert_padded_cumsum->data(),
         m_indices->data()
-    );
-  } else if (num_experts_per_rank == 9) {
-    permute_x_fp8_kernel<9><<<gridx, 512, 0, stream>>>(
-        input.data(),
-        scale.data(),
-        topk_ids.data(),
-        topk_weights.data(),
-        token_nums_per_expert.data(),
-        token_nums_per_expert_padded.data(),
-        moe_topk,
-        num_rows,
-        token_nums_this_rank,
-        token_nums_this_rank_padded,
-        hidden_size,
-        permute_input->data(),
-        permute_scale->data(),
-        permute_indices_per_token->data(),
-        dst_weights->data(),
-        dst_indices->data(),
-        cumsum_idx_gpu->data(),
-        token_nums_per_expert_cumsum->data(),
-        token_nums_per_expert_padded_cumsum->data(),
-        m_indices->data()
-    );
-  } else if (num_experts_per_rank == 16) {
-    permute_x_fp8_kernel<16><<<gridx, 512, 0, stream>>>(
-        input.data(),
-        scale.data(),
-        topk_ids.data(),
-        topk_weights.data(),
-        token_nums_per_expert.data(),
-        token_nums_per_expert_padded.data(),
-        moe_topk,
-        num_rows,
-        token_nums_this_rank,
-        token_nums_this_rank_padded,
-        hidden_size,
-        permute_input->data(),
-        permute_scale->data(),
-        permute_indices_per_token->data(),
-        dst_weights->data(),
-        dst_indices->data(),
-        cumsum_idx_gpu->data(),
-        token_nums_per_expert_cumsum->data(),
-        token_nums_per_expert_padded_cumsum->data(),
-        m_indices->data()
-    );
-  } else if (num_experts_per_rank == 64) {
-    permute_x_fp8_kernel<64><<<gridx, 512, 0, stream>>>(
-        input.data(),
-        scale.data(),
-        topk_ids.data(),
-        topk_weights.data(),
-        token_nums_per_expert.data(),
-        token_nums_per_expert_padded.data(),
-        moe_topk,
-        num_rows,
-        token_nums_this_rank,
-        token_nums_this_rank_padded,
-        hidden_size,
-        permute_input->data(),
-        permute_scale->data(),
-        permute_indices_per_token->data(),
-        dst_weights->data(),
-        dst_indices->data(),
-        cumsum_idx_gpu->data(),
-        token_nums_per_expert_cumsum->data(),
-        token_nums_per_expert_padded_cumsum->data(),
-        m_indices->data()
-    );
-  } else if (num_experts_per_rank == 128) {
-    permute_x_fp8_kernel<128><<<gridx, 512, 0, stream>>>(
-        input.data(),
-        scale.data(),
-        topk_ids.data(),
-        topk_weights.data(),
-        token_nums_per_expert.data(),
-        token_nums_per_expert_padded.data(),
-        moe_topk,
-        num_rows,
-        token_nums_this_rank,
-        token_nums_this_rank_padded,
-        hidden_size,
-        permute_input->data(),
-        permute_scale->data(),
-        permute_indices_per_token->data(),
-        dst_weights->data(),
-        dst_indices->data(),
-        cumsum_idx_gpu->data(),
-        token_nums_per_expert_cumsum->data(),
-        token_nums_per_expert_padded_cumsum->data(),
-        m_indices->data()
-    );
-  } else {
-    PD_THROW("Not dispatching this num_experts_per_rank(", num_experts_per_rank, ") for EPMoeDispatchFP8Kernel");
-  }
+    );)
+  }
diff --git a/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu b/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu
index ee27f566c..f9eb4c9ce 100644
--- a/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu
+++ b/custom_ops/gpu_ops/moe/tritonmoe_preprocess.cu
@@ -168,6 +168,8 @@ std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& to
         run_align_kernel(64);
     } else if (num_experts == 128) {
         run_align_kernel(128);
+    } else if (num_experts == 160) {
+        run_align_kernel(160);
     } else {
         PD_THROW("Not support num_experts: %d", num_experts);
     }
diff --git a/fastdeploy/model_executor/ops/triton_ops/triton_utils_v2.py b/fastdeploy/model_executor/ops/triton_ops/triton_utils_v2.py
index b8268ce88..98589a4c3 100644
--- a/fastdeploy/model_executor/ops/triton_ops/triton_utils_v2.py
+++ b/fastdeploy/model_executor/ops/triton_ops/triton_utils_v2.py
@@ -134,7 +134,7 @@ class KernelInterface:
         *args: positional arguments
         **kwargs: keyword arguments
         """
-        op_name = "haha" + str(kwargs["N"])
+        op_name = f'haha_N{str(kwargs["N"])}_K{str(kwargs["K"])}'
         if op_name in self.func_map.keys():
             return self.func_map[op_name](*args)
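
Note on the DISPATCH_NUM_EXPERTS_PER_RANK refactor: the macro bridges a runtime value (num_experts_per_rank) to a compile-time template argument, which C++ can only do by enumerating the supported values in a switch. The variadic "..."/__VA_ARGS__ pair is what lets the kernel launch pass through the macro even though its <<<grid, block, smem, stream>>> configuration contains commas, and the odd-looking ");)" at the new call site closes the kernel call and then the macro invocation. The sketch below is a minimal, self-contained illustration of the same pattern; DISPATCH_EXPERTS, toy_kernel, and the <<<1, 32>>> launch are made-up stand-ins, not code from this PR.

#include <cuda_runtime.h>
#include <sstream>
#include <stdexcept>

// Same idiom as DISPATCH_NUM_EXPERTS_PER_RANK, reduced to two cases.
// CONST_N names the constexpr that each case declares; __VA_ARGS__
// re-joins the kernel launch, internal commas and all.
#define DISPATCH_EXPERTS(runtime_n, CONST_N, ...)                \
  switch (runtime_n) {                                           \
    case 8: {                                                    \
      constexpr int CONST_N = 8;                                 \
      __VA_ARGS__                                                \
      break;                                                     \
    }                                                            \
    case 160: {                                                  \
      constexpr int CONST_N = 160;                               \
      __VA_ARGS__                                                \
      break;                                                     \
    }                                                            \
    default: {                                                   \
      std::ostringstream err;                                    \
      err << "Unsupported num_experts_per_rank: " << runtime_n;  \
      throw std::invalid_argument(err.str());                    \
    }                                                            \
  }

// A toy kernel that genuinely needs the count at compile time:
// the local array must have a constexpr size.
template <int NUM_EXPERTS>
__global__ void toy_kernel(int* out) {
  int local[NUM_EXPERTS];
  for (int i = 0; i < NUM_EXPERTS; ++i) local[i] = i;
  out[threadIdx.x] = local[NUM_EXPERTS - 1];
}

void launch(int* out, int num_experts_per_rank, cudaStream_t stream) {
  // ");)" = end of kernel call, then end of macro invocation,
  // mirroring the new call site in EPMoeDispatchFP8Kernel.
  DISPATCH_EXPERTS(num_experts_per_rank, NUM_EXPERTS,
      toy_kernel<NUM_EXPERTS><<<1, 32, 0, stream>>>(out);)
}

int main() {
  int* d_out = nullptr;
  cudaMalloc(&d_out, 32 * sizeof(int));
  launch(d_out, 160, /*stream=*/0);
  cudaDeviceSynchronize();
  cudaFree(d_out);
  return 0;
}

Each case instantiates a separate kernel specialization, so binary size and compile time grow with the list; that is presumably why the switch enumerates only the expert counts actually deployed (8, 9, 16, 64, 128, and now 160), and why tritonmoe_preprocess.cu gains the matching num_experts == 160 case in its own static dispatch.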
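
Note on the triton_utils_v2.py hunk: this is a correctness fix, not a rename. The compiled-kernel cache self.func_map was keyed on N alone, so two shapes sharing N but differing in K would hit the same cache entry and silently reuse a kernel built for the wrong K. The sketch below reproduces that failure mode in C++ (to keep one language with the sketch above); get_or_build and Kernel are hypothetical names, not FastDeploy API.

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

using Kernel = std::function<void()>;
static std::unordered_map<std::string, Kernel> func_map;  // stand-in for self.func_map

Kernel get_or_build(int N, int K) {
  // Old key: N only -- (N=256, K=64) and (N=256, K=128) collide:
  //   std::string op_name = "haha" + std::to_string(N);
  // New key mirrors f'haha_N{...}_K{...}': both dimensions participate.
  std::string op_name = "haha_N" + std::to_string(N) + "_K" + std::to_string(K);
  auto it = func_map.find(op_name);
  if (it != func_map.end()) return it->second;  // cache hit
  Kernel k = [N, K] {
    std::cout << "kernel specialized for N=" << N << " K=" << K << "\n";
  };
  func_map.emplace(op_name, k);  // cache the freshly built kernel
  return k;
}

int main() {
  get_or_build(256, 64)();   // builds the (N=256, K=64) kernel
  get_or_build(256, 128)();  // distinct entry; the old N-only key would return the K=64 kernel
  return 0;
}

If the kernel is ever specialized on more meta-parameters than N and K, those should be folded into the key as well; the redundant str() inside the f-string is harmless but could be dropped.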