From a36d60aa184c62bafa3558ec3b0729ffdc088cda Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com>
Date: Wed, 3 Dec 2025 21:17:37 +0800
Subject: [PATCH] [FIX BUG] fix bug in TP in permute_x_fp8_kernel (#5350)

* commit

* commit

* commit

* commit

* commit

* commit
---
 custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
index e3180171b..79317afab 100644
--- a/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
+++ b/custom_ops/gpu_ops/moe/ep_moe_expert_dispatch.cu
@@ -989,8 +989,20 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
                       paddle::DataType::FLOAT32,
                       place);
 
-  auto m_indices =
-      GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::INT32, place);
+  paddle::Tensor m_indices;
+  if (use_in_ep) {
+    m_indices = GetEmptyTensor(
+        {token_nums_feed_to_ffn}, paddle::DataType::INT32, place);
+  } else {
+    // Note(ZKK)
+    // In TP, we must init m_indices with -1,
+    // because we allocate too much space.
+    // token_rows * moe_topk + num_experts_per_rank * (128 - 1)
+    // Later will optimize this.
+    m_indices = paddle::full(
+        {token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
+  }
+
   auto token_nums_per_expert_cumsum =
       GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
   auto token_nums_per_expert_padded_cumsum =
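
Reviewer note (not part of the patch): a minimal host-side C++ sketch of why the over-allocated m_indices buffer in the TP path must be pre-filled with -1 rather than left uninitialized. The helper name build_m_indices_tp, the kBlockRows = 128 alignment (inferred from the "128 - 1" term in the patch comment), and the loop structure are illustrative assumptions, not the actual kernel code touched by this patch.

    // Hypothetical sketch only: mirrors the sizing logic described in the
    // Note(ZKK) comment, under the assumption that each expert's segment is
    // padded up to a multiple of 128 rows for the grouped GEMM.
    #include <cstdint>
    #include <vector>

    constexpr int kBlockRows = 128;  // assumed row alignment per expert segment

    std::vector<int32_t> build_m_indices_tp(const std::vector<int>& tokens_per_expert,
                                            int token_rows, int moe_topk) {
      const int num_experts = static_cast<int>(tokens_per_expert.size());
      // Worst-case capacity from the patch comment:
      // token_rows * moe_topk real slots plus (kBlockRows - 1) padding per expert.
      const int capacity = token_rows * moe_topk + num_experts * (kBlockRows - 1);

      // Fill with -1 first: padded slots that no token is dispatched to must be
      // distinguishable from real expert ids, which is what paddle::full(..., -1)
      // achieves in the patch. A plain empty allocation would leave garbage here.
      std::vector<int32_t> m_indices(capacity, -1);

      int offset = 0;
      for (int e = 0; e < num_experts; ++e) {
        // Rows actually occupied by this expert get its expert id.
        for (int t = 0; t < tokens_per_expert[e]; ++t) m_indices[offset + t] = e;
        // Advance by the padded segment length; the gap keeps its -1 sentinel so
        // downstream consumers can skip padded rows.
        const int padded =
            (tokens_per_expert[e] + kBlockRows - 1) / kBlockRows * kBlockRows;
        offset += padded;
      }
      return m_indices;
    }

In the EP path the buffer is sized exactly, so GetEmptyTensor remains sufficient there; only the TP path needs the sentinel fill.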