mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
fix ep prefill (#2762)
This commit is contained in:
@@ -158,7 +158,8 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
const paddle::Tensor &input, const paddle::Tensor &scale,
|
||||
const paddle::Tensor &topk_ids, const paddle::Tensor &topk_weights,
|
||||
const paddle::Tensor &token_nums_per_expert,
|
||||
const paddle::Tensor &token_nums_per_expert_padded);
|
||||
const paddle::Tensor &token_nums_per_expert_padded,
|
||||
const bool use_in_ep, const int token_nums_this_rank_padded);
|
||||
|
||||
std::vector<paddle::Tensor> PerTokenQuant(paddle::Tensor &input,
|
||||
const int block_size);
|
||||
|
||||
@@ -870,7 +870,9 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
const paddle::Tensor& topk_ids,
|
||||
const paddle::Tensor& topk_weights,
|
||||
const paddle::Tensor& num_experts_per_rank_tensor,
|
||||
const paddle::Tensor& num_experts_per_rank_padded_tensor) {
|
||||
const paddle::Tensor& num_experts_per_rank_padded_tensor,
|
||||
const bool use_in_ep,
|
||||
const int token_nums_this_rank_padded) {
|
||||
const auto input_type = input.dtype();
|
||||
const int moe_topk = topk_ids.dims()[1];
|
||||
auto place = input.place();
|
||||
@@ -886,22 +888,21 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchFP8(
|
||||
const int hidden_size = input.dims()[input_dims.size() - 1];
|
||||
const int num_experts_per_rank = num_experts_per_rank_tensor.dims()[0];
|
||||
|
||||
int32_t token_nums_this_rank_padded = token_rows * moe_topk + num_experts_per_rank * (128-1);
|
||||
// token_nums_this_rank_padded = token_nums_this_rank_padded_useless;
|
||||
int32_t token_nums_feed_to_ffn = use_in_ep ? token_nums_this_rank_padded : token_rows * moe_topk + num_experts_per_rank * (128-1);
|
||||
|
||||
auto permute_input = GetEmptyTensor(
|
||||
{token_nums_this_rank_padded, hidden_size},
|
||||
{token_nums_feed_to_ffn, hidden_size},
|
||||
input_type,
|
||||
place);
|
||||
auto permute_scale = GetEmptyTensor(
|
||||
{token_nums_this_rank_padded, hidden_size / 128},
|
||||
{token_nums_feed_to_ffn, hidden_size / 128},
|
||||
paddle::DataType::FLOAT32,
|
||||
place);
|
||||
|
||||
auto m_indices = paddle::full({token_nums_this_rank_padded}, -1, paddle::DataType::INT32, place);
|
||||
auto m_indices = paddle::full({token_nums_feed_to_ffn}, -1, paddle::DataType::INT32, place);
|
||||
auto token_nums_per_expert_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
|
||||
auto token_nums_per_expert_padded_cumsum = GetEmptyTensor({num_experts_per_rank}, paddle::DataType::INT64, place);
|
||||
auto dst_weights = GetEmptyTensor({token_nums_this_rank_padded}, paddle::DataType::FLOAT32, place);
|
||||
auto dst_weights = GetEmptyTensor({token_nums_feed_to_ffn}, paddle::DataType::FLOAT32, place);
|
||||
auto dst_indices = GetEmptyTensor({num_rows, num_experts_per_rank}, paddle::DataType::INT32, place);
|
||||
auto permute_indices_per_token = paddle::full({num_experts_per_rank, num_rows}, -1, paddle::DataType::INT32, place);
|
||||
auto cumsum_idx_gpu = paddle::full({num_experts_per_rank}, 0, paddle::DataType::INT32, place);
|
||||
@@ -949,4 +950,5 @@ PD_BUILD_STATIC_OP(ep_moe_expert_dispatch_fp8)
|
||||
"dst_indices",
|
||||
"cumsum_idx_gpu",
|
||||
"m_indices"})
|
||||
.Attrs({"use_in_ep:bool", "token_nums_this_rank_padded:int"})
|
||||
.SetKernelFn(PD_KERNEL(EPMoeExpertDispatchFP8));
|
||||
|
||||
Reference in New Issue
Block a user