mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Metax] adapt cutlass moe for ernie-vl (#4685)
This commit is contained in:
@@ -101,6 +101,10 @@ std::vector<paddle::Tensor> FusedExpertMoe(
|
||||
const auto input_type = input.dtype();
|
||||
auto output = paddle::empty_like(input);
|
||||
|
||||
if (output.dims()[0] == 0) {
|
||||
return {output};
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
FusedMoeKernel<paddle::DataType::BFLOAT16,
|
||||
|
||||
@@ -178,6 +178,14 @@ std::vector<paddle::Tensor> MoeExpertDispatch(
|
||||
auto permute_indices_per_token =
|
||||
GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
|
||||
|
||||
if (token_rows == 0) {
|
||||
return {permute_input,
|
||||
tokens_expert_prefix_sum,
|
||||
permute_indices_per_token,
|
||||
top_k_weight,
|
||||
top_k_indices};
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
MoeDispatchKernel<paddle::DataType::BFLOAT16>(input,
|
||||
|
||||
@@ -114,6 +114,10 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
const auto input_type = permute_input.dtype();
|
||||
auto ffn_out = paddle::empty_like(permute_input);
|
||||
|
||||
if (permute_input.numel() == 0) {
|
||||
return {ffn_out};
|
||||
}
|
||||
|
||||
switch (input_type) {
|
||||
case paddle::DataType::BFLOAT16:
|
||||
McMoeFFNKernel<paddle::DataType::BFLOAT16,
|
||||
|
||||
Reference in New Issue
Block a user