[XPU] merge apply_tp, ops support token_num = 0 (#4507)

This commit is contained in:
zhupengyang
2025-10-23 19:09:58 +08:00
committed by GitHub
parent 4ffe41a747
commit 3a43dbf82d
4 changed files with 191 additions and 254 deletions

View File

@@ -40,20 +40,21 @@ std::vector<paddle::Tensor> MoeEPCombineKernel(
auto combined_out = paddle::empty(
{recv_token_num, hidden_dim}, ffn_out.dtype(), ffn_out.place());
const float* dequant_score = nullptr;
int ret = infer_ops::moe_ep_ffn_post_fusion(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_T*>(ffn_out.data<T>()),
moe_index.data<int32_t>(),
reinterpret_cast<const XPU_T*>(weights.data<T>()),
dequant_score,
reinterpret_cast<XPU_T*>(combined_out.mutable_data<T>()),
recv_token_num,
hidden_dim,
topk,
expand_token_num);
PD_CHECK(ret == 0);
if (recv_token_num > 0) {
int ret = infer_ops::moe_ep_ffn_post_fusion(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_T*>(ffn_out.data<T>()),
moe_index.data<int32_t>(),
reinterpret_cast<const XPU_T*>(weights.data<T>()),
dequant_score,
reinterpret_cast<XPU_T*>(combined_out.mutable_data<T>()),
recv_token_num,
hidden_dim,
topk,
expand_token_num);
PD_CHECK(ret == 0);
}
return {combined_out};
}

View File

@@ -60,44 +60,48 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
if (std::is_same<TY, int8_t>::value) {
permute_input =
paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe<XPU_TX, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
topk_ids.data<int>(),
input_scales.get_ptr()->data<float>(),
nullptr,
reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
const_cast<int*>(permute_indices_per_token.data<int>()),
const_cast<int*>(expert_m.data<int>()),
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
expand_input_scales.data<float>(),
m,
n,
expert_num,
topk,
block_num,
token_nums_this_rank);
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
if (token_nums_this_rank > 0) {
auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe<XPU_TX, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
topk_ids.data<int>(),
input_scales.get_ptr()->data<float>(),
nullptr,
reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
const_cast<int*>(permute_indices_per_token.data<int>()),
const_cast<int*>(expert_m.data<int>()),
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
expand_input_scales.data<float>(),
m,
n,
expert_num,
topk,
block_num,
token_nums_this_rank);
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
}
} else {
permute_input = paddle::empty({token_nums_this_rank, n}, input_type, place);
auto ret = infer_ops::moe_ep_ffn_pre_sorted<XPU_TX, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
topk_ids.data<int>(),
nullptr,
reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
const_cast<int*>(permute_indices_per_token.data<int>()),
const_cast<int*>(expert_m.data<int>()),
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
m,
n,
expert_num,
topk,
block_num,
ep_size,
ep_rank,
token_nums_this_rank);
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
if (token_nums_this_rank > 0) {
auto ret = infer_ops::moe_ep_ffn_pre_sorted<XPU_TX, int>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
topk_ids.data<int>(),
nullptr,
reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
const_cast<int*>(permute_indices_per_token.data<int>()),
const_cast<int*>(expert_m.data<int>()),
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
m,
n,
expert_num,
topk,
block_num,
ep_size,
ep_rank,
token_nums_this_rank);
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
}
}
return {permute_input,
permute_indices_per_token,

View File

@@ -441,6 +441,12 @@ std::vector<paddle::Tensor> MoeExpertFFN(
const std::string& quant_method,
const int hadamard_blocksize,
const int valid_token_num) {
if (ffn_in.numel() == 0) {
paddle::Tensor ffn2_out =
paddle::empty_like(ffn_in, paddle::DataType::BFLOAT16);
return {ffn2_out};
}
const auto x_type = ffn_in.dtype();
const auto w_type = ffn1_weight.dtype();