MoE preprocess op: support 160 experts; add K to the fused_moe Triton kernel name (#3121)

Author: chen (committed by GitHub)
Date:   2025-08-01 10:46:20 +08:00
Parent: 1d93565082
Commit: a2f5cc54f8
3 changed files with 47 additions and 99 deletions


@@ -25,6 +25,46 @@
#include "helper.h" #include "helper.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
#define DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \
switch (num_experts_per_rank) { \
case 8: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 8; \
__VA_ARGS__ \
break; \
} \
case 9: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 9; \
__VA_ARGS__ \
break; \
} \
case 16: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 16; \
__VA_ARGS__ \
break; \
} \
case 64: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
__VA_ARGS__ \
break; \
} \
case 128: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 128; \
__VA_ARGS__ \
break; \
} \
case 160: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 160; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported num_experts_per_rank: " << num_experts_per_rank; \
throw std::invalid_argument(err_msg.str()); \
} \
}
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template<typename T> template<typename T>
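The new DISPATCH_NUM_EXPERTS_PER_RANK macro binds the runtime num_experts_per_rank to a constexpr name supplied by the caller, so the code passed in through __VA_ARGS__ can use that name as a template argument. Below is a minimal standalone sketch of the same technique; DISPATCH_EXPERTS, process_tokens, and K_EXPERTS are made-up names for illustration and are not part of this commit.

#include <cstdio>
#include <sstream>
#include <stdexcept>

// Hypothetical kernel-like function templated on the per-rank expert count,
// standing in for permute_x_fp8_kernel in this sketch.
template <size_t NumExpertsPerRank>
void process_tokens(int num_tokens) {
  std::printf("processing %d tokens with %zu experts per rank\n",
              num_tokens, NumExpertsPerRank);
}

// Same shape as DISPATCH_NUM_EXPERTS_PER_RANK: a switch over the supported
// runtime values, each case binding the value to a constexpr name that the
// body (__VA_ARGS__) can use as a template argument.
#define DISPATCH_EXPERTS(runtime_value, CONST_NAME, ...)                \
  switch (runtime_value) {                                              \
    case 8: { constexpr size_t CONST_NAME = 8; __VA_ARGS__ break; }     \
    case 160: { constexpr size_t CONST_NAME = 160; __VA_ARGS__ break; } \
    default: {                                                          \
      std::ostringstream err;                                           \
      err << "Unsupported value: " << runtime_value;                    \
      throw std::invalid_argument(err.str());                           \
    }                                                                   \
  }

int main() {
  int num_experts_per_rank = 160;  // known only at runtime
  DISPATCH_EXPERTS(num_experts_per_rank, K_EXPERTS,
                   process_tokens<K_EXPERTS>(1024);)
  return 0;
}

Each supported value needs its own case because a template argument must be a compile-time constant; values outside the list fall through to the throwing default, just as in the macro above.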
@@ -743,8 +783,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
   auto place = input.place();
   // const int gridx = min(132 * 8, num_rows);
   const int gridx = 132 * 8;
-  if (num_experts_per_rank == 8) {
-    permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 8><<<gridx, 512, 0, stream>>>(
+  DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
+    permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
       input.data<phi::dtype::float8_e4m3fn>(),
       scale.data<float>(),
       topk_ids.data<int64_t>(),
@@ -765,102 +805,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
       token_nums_per_expert_cumsum->data<int64_t>(),
       token_nums_per_expert_padded_cumsum->data<int64_t>(),
       m_indices->data<int>()
-    );
-  } else if (num_experts_per_rank == 9) {
-    permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 9><<<gridx, 512, 0, stream>>>(
-      input.data<phi::dtype::float8_e4m3fn>(),
-      scale.data<float>(),
-      topk_ids.data<int64_t>(),
-      topk_weights.data<float>(),
-      token_nums_per_expert.data<int>(),
-      token_nums_per_expert_padded.data<int>(),
-      moe_topk,
-      num_rows,
-      token_nums_this_rank,
-      token_nums_this_rank_padded,
-      hidden_size,
-      permute_input->data<phi::dtype::float8_e4m3fn>(),
-      permute_scale->data<float>(),
-      permute_indices_per_token->data<int>(),
-      dst_weights->data<float>(),
-      dst_indices->data<int>(),
-      cumsum_idx_gpu->data<int>(),
-      token_nums_per_expert_cumsum->data<int64_t>(),
-      token_nums_per_expert_padded_cumsum->data<int64_t>(),
-      m_indices->data<int>()
-    );
-  } else if (num_experts_per_rank == 16) {
-    permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 16><<<gridx, 512, 0, stream>>>(
-      input.data<phi::dtype::float8_e4m3fn>(),
-      scale.data<float>(),
-      topk_ids.data<int64_t>(),
-      topk_weights.data<float>(),
-      token_nums_per_expert.data<int>(),
-      token_nums_per_expert_padded.data<int>(),
-      moe_topk,
-      num_rows,
-      token_nums_this_rank,
-      token_nums_this_rank_padded,
-      hidden_size,
-      permute_input->data<phi::dtype::float8_e4m3fn>(),
-      permute_scale->data<float>(),
-      permute_indices_per_token->data<int>(),
-      dst_weights->data<float>(),
-      dst_indices->data<int>(),
-      cumsum_idx_gpu->data<int>(),
-      token_nums_per_expert_cumsum->data<int64_t>(),
-      token_nums_per_expert_padded_cumsum->data<int64_t>(),
-      m_indices->data<int>()
-    );
-  } else if (num_experts_per_rank == 64) {
-    permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 64><<<gridx, 512, 0, stream>>>(
-      input.data<phi::dtype::float8_e4m3fn>(),
-      scale.data<float>(),
-      topk_ids.data<int64_t>(),
-      topk_weights.data<float>(),
-      token_nums_per_expert.data<int>(),
-      token_nums_per_expert_padded.data<int>(),
-      moe_topk,
-      num_rows,
-      token_nums_this_rank,
-      token_nums_this_rank_padded,
-      hidden_size,
-      permute_input->data<phi::dtype::float8_e4m3fn>(),
-      permute_scale->data<float>(),
-      permute_indices_per_token->data<int>(),
-      dst_weights->data<float>(),
-      dst_indices->data<int>(),
-      cumsum_idx_gpu->data<int>(),
-      token_nums_per_expert_cumsum->data<int64_t>(),
-      token_nums_per_expert_padded_cumsum->data<int64_t>(),
-      m_indices->data<int>()
-    );
-  } else if (num_experts_per_rank == 128) {
-    permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 128><<<gridx, 512, 0, stream>>>(
-      input.data<phi::dtype::float8_e4m3fn>(),
-      scale.data<float>(),
-      topk_ids.data<int64_t>(),
-      topk_weights.data<float>(),
-      token_nums_per_expert.data<int>(),
-      token_nums_per_expert_padded.data<int>(),
-      moe_topk,
-      num_rows,
-      token_nums_this_rank,
-      token_nums_this_rank_padded,
-      hidden_size,
-      permute_input->data<phi::dtype::float8_e4m3fn>(),
-      permute_scale->data<float>(),
-      permute_indices_per_token->data<int>(),
-      dst_weights->data<float>(),
-      dst_indices->data<int>(),
-      cumsum_idx_gpu->data<int>(),
-      token_nums_per_expert_cumsum->data<int64_t>(),
-      token_nums_per_expert_padded_cumsum->data<int64_t>(),
-      m_indices->data<int>()
-    );
-  } else {
-    PD_THROW("Not dispatching this num_experts_per_rank(", num_experts_per_rank, ") for EPMoeDispatchFP8Kernel");
-  }
+    );)
 }
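The `);)` on the new call site is easy to misread: the kernel-launch statement passed through __VA_ARGS__ ends with its own `);`, and one more `)` is still needed to close the DISPATCH_NUM_EXPERTS_PER_RANK(...) invocation itself. A tiny standalone illustration of the same closer, using an invented RUN_TWICE macro rather than anything from this file:

#include <cstdio>

// RUN_TWICE pastes its argument list into the block twice; the argument here
// is a complete statement, semicolon included.
#define RUN_TWICE(...) { __VA_ARGS__ __VA_ARGS__ }

int main() {
  RUN_TWICE(std::printf("hello\n");)
  // expands to: { std::printf("hello\n"); std::printf("hello\n"); }
  // The `;` ends the printf statement; the trailing `)` ends RUN_TWICE(...).
  return 0;
}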


@@ -168,6 +168,8 @@ std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& to
     run_align_kernel(64);
   } else if (num_experts == 128) {
     run_align_kernel(128);
+  } else if (num_experts == 160) {
+    run_align_kernel(160);
   } else {
     PD_THROW("Not support num_experts: %d", num_experts);
   }


@@ -134,7 +134,7 @@ class KernelInterface:
         *args: positional arguments
         **kwargs: keyword arguments
         """
-        op_name = "haha" + str(kwargs["N"])
+        op_name = f'haha_N{str(kwargs["N"])}_K{str(kwargs["K"])}'
         if op_name in self.func_map.keys():
             return self.func_map[op_name](*args)
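For context on the last hunk: op_name is the key used to look up an already-built kernel in self.func_map, and it previously encoded only N, so two launches that differ only in K would presumably collide on the same cached entry. The sketch below shows that pitfall in a standalone C++ toy cache; every name in it (compile_kernel, get_kernel_nk, the shape numbers, and so on) is invented for illustration and does not come from the repository.

#include <cstdio>
#include <map>
#include <string>
#include <utility>

// Stand-in for an expensive JIT compile; returns a fake kernel handle.
int compile_kernel(int n, int k) {
  std::printf("compiling kernel for N=%d K=%d\n", n, k);
  return n * 1000 + k;
}

// Cache keyed on N only: a kernel built for one K is silently reused for a
// different K that happens to share the same N.
int get_kernel_n_only(std::map<std::string, int>& cache, int n, int k) {
  std::string key = "kernel_N" + std::to_string(n);
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;  // may have been built for another K
  return cache[key] = compile_kernel(n, k);
}

// Cache keyed on (N, K), mirroring the op_name change in the diff above.
int get_kernel_nk(std::map<std::pair<int, int>, int>& cache, int n, int k) {
  auto key = std::make_pair(n, k);
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;
  return cache[key] = compile_kernel(n, k);
}

int main() {
  std::map<std::string, int> by_n;
  std::map<std::pair<int, int>, int> by_nk;

  get_kernel_n_only(by_n, 4096, 7168);
  get_kernel_n_only(by_n, 4096, 2048);  // no compile: wrong kernel reused
  get_kernel_nk(by_nk, 4096, 7168);
  get_kernel_nk(by_nk, 4096, 2048);     // compiles a second, correct kernel
  return 0;
}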