[FIX BUG] fix bug in TP in permute_x_fp8_kernel (#5350)

* commit

* commit

* commit

* commit

* commit

* commit
This commit is contained in:
周周周
2025-12-03 21:17:37 +08:00
committed by GitHub
parent 5f8d4aedea
commit a36d60aa18

View File

@@ -989,8 +989,20 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
paddle::DataType::FLOAT32,
place);
auto m_indices =
GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::INT32, place);
paddle::Tensor m_indices;
if (use_in_ep) {
m_indices = GetEmptyTensor(
{token_nums_feed_to_ffn}, paddle::DataType::INT32, place);
} else {
// Note(ZKK)
// In TP, we must init m_indices with -1,
// because we allocate too much space.
// token_rows * moe_topk + num_experts_per_rank * (128 - 1)
// Later will optimize this.
m_indices = paddle::full(
{token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
}
auto token_nums_per_expert_cumsum =
GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
auto token_nums_per_expert_padded_cumsum =