mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[XPU] merge apply_tp, ops support token_num = 0 (#4507)
This commit is contained in:
@@ -40,20 +40,21 @@ std::vector<paddle::Tensor> MoeEPCombineKernel(
|
||||
|
||||
auto combined_out = paddle::empty(
|
||||
{recv_token_num, hidden_dim}, ffn_out.dtype(), ffn_out.place());
|
||||
|
||||
const float* dequant_score = nullptr;
|
||||
int ret = infer_ops::moe_ep_ffn_post_fusion(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_T*>(ffn_out.data<T>()),
|
||||
moe_index.data<int32_t>(),
|
||||
reinterpret_cast<const XPU_T*>(weights.data<T>()),
|
||||
dequant_score,
|
||||
reinterpret_cast<XPU_T*>(combined_out.mutable_data<T>()),
|
||||
recv_token_num,
|
||||
hidden_dim,
|
||||
topk,
|
||||
expand_token_num);
|
||||
PD_CHECK(ret == 0);
|
||||
if (recv_token_num > 0) {
|
||||
int ret = infer_ops::moe_ep_ffn_post_fusion(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_T*>(ffn_out.data<T>()),
|
||||
moe_index.data<int32_t>(),
|
||||
reinterpret_cast<const XPU_T*>(weights.data<T>()),
|
||||
dequant_score,
|
||||
reinterpret_cast<XPU_T*>(combined_out.mutable_data<T>()),
|
||||
recv_token_num,
|
||||
hidden_dim,
|
||||
topk,
|
||||
expand_token_num);
|
||||
PD_CHECK(ret == 0);
|
||||
}
|
||||
|
||||
return {combined_out};
|
||||
}
|
||||
|
||||
@@ -60,44 +60,48 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
|
||||
if (std::is_same<TY, int8_t>::value) {
|
||||
permute_input =
|
||||
paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
|
||||
auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe<XPU_TX, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
|
||||
topk_ids.data<int>(),
|
||||
input_scales.get_ptr()->data<float>(),
|
||||
nullptr,
|
||||
reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
|
||||
const_cast<int*>(permute_indices_per_token.data<int>()),
|
||||
const_cast<int*>(expert_m.data<int>()),
|
||||
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
|
||||
expand_input_scales.data<float>(),
|
||||
m,
|
||||
n,
|
||||
expert_num,
|
||||
topk,
|
||||
block_num,
|
||||
token_nums_this_rank);
|
||||
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
|
||||
if (token_nums_this_rank > 0) {
|
||||
auto ret = infer_ops::moe_ffn_pre_sorted_quant_pe<XPU_TX, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
|
||||
topk_ids.data<int>(),
|
||||
input_scales.get_ptr()->data<float>(),
|
||||
nullptr,
|
||||
reinterpret_cast<int8_t*>(permute_input.data<int8_t>()),
|
||||
const_cast<int*>(permute_indices_per_token.data<int>()),
|
||||
const_cast<int*>(expert_m.data<int>()),
|
||||
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
|
||||
expand_input_scales.data<float>(),
|
||||
m,
|
||||
n,
|
||||
expert_num,
|
||||
topk,
|
||||
block_num,
|
||||
token_nums_this_rank);
|
||||
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
|
||||
}
|
||||
} else {
|
||||
permute_input = paddle::empty({token_nums_this_rank, n}, input_type, place);
|
||||
auto ret = infer_ops::moe_ep_ffn_pre_sorted<XPU_TX, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
|
||||
topk_ids.data<int>(),
|
||||
nullptr,
|
||||
reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
|
||||
const_cast<int*>(permute_indices_per_token.data<int>()),
|
||||
const_cast<int*>(expert_m.data<int>()),
|
||||
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
|
||||
m,
|
||||
n,
|
||||
expert_num,
|
||||
topk,
|
||||
block_num,
|
||||
ep_size,
|
||||
ep_rank,
|
||||
token_nums_this_rank);
|
||||
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
|
||||
if (token_nums_this_rank > 0) {
|
||||
auto ret = infer_ops::moe_ep_ffn_pre_sorted<XPU_TX, int>(
|
||||
xpu_ctx->x_context(),
|
||||
reinterpret_cast<const XPU_TX*>(input.data<TX>()),
|
||||
topk_ids.data<int>(),
|
||||
nullptr,
|
||||
reinterpret_cast<XPU_TX*>(permute_input.data<TX>()),
|
||||
const_cast<int*>(permute_indices_per_token.data<int>()),
|
||||
const_cast<int*>(expert_m.data<int>()),
|
||||
const_cast<int*>(recv_num_tokens_per_expert_list_cumsum.data<int>()),
|
||||
m,
|
||||
n,
|
||||
expert_num,
|
||||
topk,
|
||||
block_num,
|
||||
ep_size,
|
||||
ep_rank,
|
||||
token_nums_this_rank);
|
||||
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
|
||||
}
|
||||
}
|
||||
return {permute_input,
|
||||
permute_indices_per_token,
|
||||
|
||||
@@ -441,6 +441,12 @@ std::vector<paddle::Tensor> MoeExpertFFN(
|
||||
const std::string& quant_method,
|
||||
const int hadamard_blocksize,
|
||||
const int valid_token_num) {
|
||||
if (ffn_in.numel() == 0) {
|
||||
paddle::Tensor ffn2_out =
|
||||
paddle::empty_like(ffn_in, paddle::DataType::BFLOAT16);
|
||||
return {ffn2_out};
|
||||
}
|
||||
|
||||
const auto x_type = ffn_in.dtype();
|
||||
const auto w_type = ffn1_weight.dtype();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user