mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-27 04:46:16 +08:00
moe preprocess op support 160 experts and fused_moe triton kernel name add K (#3121)
This commit is contained in:
@@ -25,6 +25,46 @@
|
||||
|
||||
#include "helper.h"
|
||||
#include <cooperative_groups.h>
|
||||
|
||||
#define DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \
|
||||
switch (num_experts_per_rank) { \
|
||||
case 8: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 8; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 9: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 9; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 16: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 16; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 64: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 128: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 128; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
case 160: { \
|
||||
constexpr size_t NUM_EXPERTS_PER_RANK = 160; \
|
||||
__VA_ARGS__ \
|
||||
break; \
|
||||
} \
|
||||
default: { \
|
||||
std::ostringstream err_msg; \
|
||||
err_msg << "Unsupported num_experts_per_rank: " << num_experts_per_rank; \
|
||||
throw std::invalid_argument(err_msg.str()); \
|
||||
} \
|
||||
}
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template<typename T>
|
||||
@@ -743,8 +783,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
|
||||
auto place = input.place();
|
||||
// const int gridx = min(132 * 8, num_rows);
|
||||
const int gridx = 132 * 8;
|
||||
if (num_experts_per_rank == 8) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 8><<<gridx, 512, 0, stream>>>(
|
||||
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
@@ -765,102 +805,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else if (num_experts_per_rank == 9) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 9><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
token_nums_per_expert_padded.data<int>(),
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
hidden_size,
|
||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
||||
permute_scale->data<float>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else if (num_experts_per_rank == 16) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 16><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
token_nums_per_expert_padded.data<int>(),
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
hidden_size,
|
||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
||||
permute_scale->data<float>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else if (num_experts_per_rank == 64) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 64><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
token_nums_per_expert_padded.data<int>(),
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
hidden_size,
|
||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
||||
permute_scale->data<float>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else if (num_experts_per_rank == 128) {
|
||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 128><<<gridx, 512, 0, stream>>>(
|
||||
input.data<phi::dtype::float8_e4m3fn>(),
|
||||
scale.data<float>(),
|
||||
topk_ids.data<int64_t>(),
|
||||
topk_weights.data<float>(),
|
||||
token_nums_per_expert.data<int>(),
|
||||
token_nums_per_expert_padded.data<int>(),
|
||||
moe_topk,
|
||||
num_rows,
|
||||
token_nums_this_rank,
|
||||
token_nums_this_rank_padded,
|
||||
hidden_size,
|
||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
||||
permute_scale->data<float>(),
|
||||
permute_indices_per_token->data<int>(),
|
||||
dst_weights->data<float>(),
|
||||
dst_indices->data<int>(),
|
||||
cumsum_idx_gpu->data<int>(),
|
||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||
m_indices->data<int>()
|
||||
);
|
||||
} else {
|
||||
PD_THROW("Not dispatching this num_experts_per_rank(", num_experts_per_rank, ") for EPMoeDispatchFP8Kernel");
|
||||
}
|
||||
);)
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@@ -168,6 +168,8 @@ std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& to
|
||||
run_align_kernel(64);
|
||||
} else if (num_experts == 128) {
|
||||
run_align_kernel(128);
|
||||
} else if (num_experts == 160) {
|
||||
run_align_kernel(160);
|
||||
} else {
|
||||
PD_THROW("Not support num_experts: %d", num_experts);
|
||||
}
|
||||
|
@@ -134,7 +134,7 @@ class KernelInterface:
|
||||
*args: positional arguments
|
||||
**kwargs: keyword arguments
|
||||
"""
|
||||
op_name = "haha" + str(kwargs["N"])
|
||||
op_name = f'haha_N{str(kwargs["N"])}_K{str(kwargs["K"])}'
|
||||
if op_name in self.func_map.keys():
|
||||
return self.func_map[op_name](*args)
|
||||
|
||||
|
Reference in New Issue
Block a user