[Metax] Adapt CUTLASS MoE for ERNIE-VL (#4685)

This commit is contained in:
Neil Zhu
2025-11-03 17:44:27 +08:00
committed by GitHub
parent 69c2f3cda1
commit c95d0740ec
6 changed files with 174 additions and 101 deletions

View File

@@ -101,6 +101,10 @@ std::vector<paddle::Tensor> FusedExpertMoe(
const auto input_type = input.dtype();
auto output = paddle::empty_like(input);
if (output.dims()[0] == 0) {
return {output};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
FusedMoeKernel<paddle::DataType::BFLOAT16,

View File

@@ -178,6 +178,14 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
auto permute_indices_per_token =
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
if (token_rows == 0) {
return {permute_input,
tokens_expert_prefix_sum,
permute_indices_per_token,
top_k_weight,
top_k_indices};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,

View File

@@ -114,6 +114,10 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const auto input_type = permute_input.dtype();
auto ffn_out = paddle::empty_like(permute_input);
if (permute_input.numel() == 0) {
return {ffn_out};
}
switch (input_type) {
case paddle::DataType::BFLOAT16:
McMoeFFNKernel<paddle::DataType::BFLOAT16,