mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-09-28 05:12:24 +08:00
moe preprocess op support 160 experts and fused_moe triton kernel name add K (#3121)
This commit is contained in:
@@ -25,6 +25,46 @@
|
|||||||
|
|
||||||
#include "helper.h"
|
#include "helper.h"
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
|
|
||||||
|
#define DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK, ...) \
|
||||||
|
switch (num_experts_per_rank) { \
|
||||||
|
case 8: { \
|
||||||
|
constexpr size_t NUM_EXPERTS_PER_RANK = 8; \
|
||||||
|
__VA_ARGS__ \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case 9: { \
|
||||||
|
constexpr size_t NUM_EXPERTS_PER_RANK = 9; \
|
||||||
|
__VA_ARGS__ \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case 16: { \
|
||||||
|
constexpr size_t NUM_EXPERTS_PER_RANK = 16; \
|
||||||
|
__VA_ARGS__ \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case 64: { \
|
||||||
|
constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
|
||||||
|
__VA_ARGS__ \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case 128: { \
|
||||||
|
constexpr size_t NUM_EXPERTS_PER_RANK = 128; \
|
||||||
|
__VA_ARGS__ \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case 160: { \
|
||||||
|
constexpr size_t NUM_EXPERTS_PER_RANK = 160; \
|
||||||
|
__VA_ARGS__ \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: { \
|
||||||
|
std::ostringstream err_msg; \
|
||||||
|
err_msg << "Unsupported num_experts_per_rank: " << num_experts_per_rank; \
|
||||||
|
throw std::invalid_argument(err_msg.str()); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
namespace cg = cooperative_groups;
|
namespace cg = cooperative_groups;
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@@ -743,8 +783,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
|
|||||||
auto place = input.place();
|
auto place = input.place();
|
||||||
// const int gridx = min(132 * 8, num_rows);
|
// const int gridx = min(132 * 8, num_rows);
|
||||||
const int gridx = 132 * 8;
|
const int gridx = 132 * 8;
|
||||||
if (num_experts_per_rank == 8) {
|
DISPATCH_NUM_EXPERTS_PER_RANK(num_experts_per_rank, NUM_EXPERTS_PER_RANK,
|
||||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 8><<<gridx, 512, 0, stream>>>(
|
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, NUM_EXPERTS_PER_RANK><<<gridx, 512, 0, stream>>>(
|
||||||
input.data<phi::dtype::float8_e4m3fn>(),
|
input.data<phi::dtype::float8_e4m3fn>(),
|
||||||
scale.data<float>(),
|
scale.data<float>(),
|
||||||
topk_ids.data<int64_t>(),
|
topk_ids.data<int64_t>(),
|
||||||
@@ -765,102 +805,8 @@ void EPMoeDispatchFP8Kernel(const paddle::Tensor& input,
|
|||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
token_nums_per_expert_cumsum->data<int64_t>(),
|
||||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
||||||
m_indices->data<int>()
|
m_indices->data<int>()
|
||||||
);
|
);)
|
||||||
} else if (num_experts_per_rank == 9) {
|
|
||||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 9><<<gridx, 512, 0, stream>>>(
|
|
||||||
input.data<phi::dtype::float8_e4m3fn>(),
|
|
||||||
scale.data<float>(),
|
|
||||||
topk_ids.data<int64_t>(),
|
|
||||||
topk_weights.data<float>(),
|
|
||||||
token_nums_per_expert.data<int>(),
|
|
||||||
token_nums_per_expert_padded.data<int>(),
|
|
||||||
moe_topk,
|
|
||||||
num_rows,
|
|
||||||
token_nums_this_rank,
|
|
||||||
token_nums_this_rank_padded,
|
|
||||||
hidden_size,
|
|
||||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
|
||||||
permute_scale->data<float>(),
|
|
||||||
permute_indices_per_token->data<int>(),
|
|
||||||
dst_weights->data<float>(),
|
|
||||||
dst_indices->data<int>(),
|
|
||||||
cumsum_idx_gpu->data<int>(),
|
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
|
||||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
|
||||||
m_indices->data<int>()
|
|
||||||
);
|
|
||||||
} else if (num_experts_per_rank == 16) {
|
|
||||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 16><<<gridx, 512, 0, stream>>>(
|
|
||||||
input.data<phi::dtype::float8_e4m3fn>(),
|
|
||||||
scale.data<float>(),
|
|
||||||
topk_ids.data<int64_t>(),
|
|
||||||
topk_weights.data<float>(),
|
|
||||||
token_nums_per_expert.data<int>(),
|
|
||||||
token_nums_per_expert_padded.data<int>(),
|
|
||||||
moe_topk,
|
|
||||||
num_rows,
|
|
||||||
token_nums_this_rank,
|
|
||||||
token_nums_this_rank_padded,
|
|
||||||
hidden_size,
|
|
||||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
|
||||||
permute_scale->data<float>(),
|
|
||||||
permute_indices_per_token->data<int>(),
|
|
||||||
dst_weights->data<float>(),
|
|
||||||
dst_indices->data<int>(),
|
|
||||||
cumsum_idx_gpu->data<int>(),
|
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
|
||||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
|
||||||
m_indices->data<int>()
|
|
||||||
);
|
|
||||||
} else if (num_experts_per_rank == 64) {
|
|
||||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 64><<<gridx, 512, 0, stream>>>(
|
|
||||||
input.data<phi::dtype::float8_e4m3fn>(),
|
|
||||||
scale.data<float>(),
|
|
||||||
topk_ids.data<int64_t>(),
|
|
||||||
topk_weights.data<float>(),
|
|
||||||
token_nums_per_expert.data<int>(),
|
|
||||||
token_nums_per_expert_padded.data<int>(),
|
|
||||||
moe_topk,
|
|
||||||
num_rows,
|
|
||||||
token_nums_this_rank,
|
|
||||||
token_nums_this_rank_padded,
|
|
||||||
hidden_size,
|
|
||||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
|
||||||
permute_scale->data<float>(),
|
|
||||||
permute_indices_per_token->data<int>(),
|
|
||||||
dst_weights->data<float>(),
|
|
||||||
dst_indices->data<int>(),
|
|
||||||
cumsum_idx_gpu->data<int>(),
|
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
|
||||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
|
||||||
m_indices->data<int>()
|
|
||||||
);
|
|
||||||
} else if (num_experts_per_rank == 128) {
|
|
||||||
permute_x_fp8_kernel<phi::dtype::float8_e4m3fn, 128><<<gridx, 512, 0, stream>>>(
|
|
||||||
input.data<phi::dtype::float8_e4m3fn>(),
|
|
||||||
scale.data<float>(),
|
|
||||||
topk_ids.data<int64_t>(),
|
|
||||||
topk_weights.data<float>(),
|
|
||||||
token_nums_per_expert.data<int>(),
|
|
||||||
token_nums_per_expert_padded.data<int>(),
|
|
||||||
moe_topk,
|
|
||||||
num_rows,
|
|
||||||
token_nums_this_rank,
|
|
||||||
token_nums_this_rank_padded,
|
|
||||||
hidden_size,
|
|
||||||
permute_input->data<phi::dtype::float8_e4m3fn>(),
|
|
||||||
permute_scale->data<float>(),
|
|
||||||
permute_indices_per_token->data<int>(),
|
|
||||||
dst_weights->data<float>(),
|
|
||||||
dst_indices->data<int>(),
|
|
||||||
cumsum_idx_gpu->data<int>(),
|
|
||||||
token_nums_per_expert_cumsum->data<int64_t>(),
|
|
||||||
token_nums_per_expert_padded_cumsum->data<int64_t>(),
|
|
||||||
m_indices->data<int>()
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
PD_THROW("Not dispatching this num_experts_per_rank(", num_experts_per_rank, ") for EPMoeDispatchFP8Kernel");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -168,6 +168,8 @@ std::vector<paddle::Tensor> tritonmoe_preprocess_kernel(const paddle::Tensor& to
|
|||||||
run_align_kernel(64);
|
run_align_kernel(64);
|
||||||
} else if (num_experts == 128) {
|
} else if (num_experts == 128) {
|
||||||
run_align_kernel(128);
|
run_align_kernel(128);
|
||||||
|
} else if (num_experts == 160) {
|
||||||
|
run_align_kernel(160);
|
||||||
} else {
|
} else {
|
||||||
PD_THROW("Not support num_experts: %d", num_experts);
|
PD_THROW("Not support num_experts: %d", num_experts);
|
||||||
}
|
}
|
||||||
|
@@ -134,7 +134,7 @@ class KernelInterface:
|
|||||||
*args: positional arguments
|
*args: positional arguments
|
||||||
**kwargs: keyword arguments
|
**kwargs: keyword arguments
|
||||||
"""
|
"""
|
||||||
op_name = "haha" + str(kwargs["N"])
|
op_name = f'haha_N{str(kwargs["N"])}_K{str(kwargs["K"])}'
|
||||||
if op_name in self.func_map.keys():
|
if op_name in self.func_map.keys():
|
||||||
return self.func_map[op_name](*args)
|
return self.func_map[op_name](*args)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user