mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -211,12 +211,14 @@ std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
|
||||
const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
|
||||
const int num_rows = token_rows;
|
||||
const int hidden_size = input_shape[input_shape.size() - 1];
|
||||
const int permuted_rows = num_rows == -1 ? -1 : moe_topk * num_rows;
|
||||
|
||||
return {{moe_topk * num_rows, hidden_size},
|
||||
return {{permuted_rows, hidden_size},
|
||||
{expert_num},
|
||||
{moe_topk, num_rows},
|
||||
{num_rows, moe_topk},
|
||||
{num_rows, moe_topk}};
|
||||
{num_rows, moe_topk},
|
||||
{permuted_rows}};
|
||||
}
|
||||
|
||||
std::vector<paddle::DataType>
|
||||
@@ -225,7 +227,7 @@ MoeExpertDispatchInferDtype(const paddle::DataType &input_dtype,
|
||||
const paddle::optional<paddle::DataType> &bias_type,
|
||||
const int moe_topk) {
|
||||
return {input_dtype, paddle::DataType::INT64, paddle::DataType::INT32,
|
||||
paddle::DataType::FLOAT32, paddle::DataType::INT32};
|
||||
paddle::DataType::FLOAT32, paddle::DataType::INT32, paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -281,7 +283,8 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch)
|
||||
paddle::Optional("gating_correction_bias"),
|
||||
paddle::Optional("w4a8_in_scale")})
|
||||
.Outputs({"permute_input", "tokens_expert_prefix_sum",
|
||||
"permute_indices_per_token", "topk_weight", "topk_idx"})
|
||||
"permute_indices_per_token", "topk_weight", "topk_idx",
|
||||
"expert_idx_per_token"})
|
||||
.Attrs({"moe_topk:int", "group_moe:bool", "topk_only_mode:bool"})
|
||||
.SetKernelFn(PD_KERNEL(MoeExpertDispatch))
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
|
||||
|
||||
Reference in New Issue
Block a user