// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"

#pragma once

#include "moe/fused_moe_helper.h"
#include "moe/fused_moe_op.h"

#pragma GCC diagnostic pop

#include "helper.h"

template <paddle::DataType T>
void MoeDispatchKernel(
    const paddle::Tensor &input,
    const paddle::Tensor &gating_output,
    const paddle::optional<paddle::Tensor> &gating_correction_bias,
    const paddle::optional<paddle::Tensor> &w4a8_in_scale,
    const int moe_topk,
    const bool group_moe,
    const bool topk_only_mode,
    const int num_rows,
    const int hidden_size,
    const int expert_num,
    paddle::Tensor *permute_input,
    paddle::Tensor *tokens_expert_prefix_sum,
    paddle::Tensor *permute_indices_per_token,
    paddle::Tensor *topk_weight,
    paddle::Tensor *topk_idx,
    paddle::Tensor *expert_idx_per_token) {
  using namespace phi;
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto stream = input.stream();
  auto place = input.place();

  if (group_moe) {
    // Check that expert_num is divisible by moe_topk, else throw an error.
    PADDLE_ENFORCE_EQ(expert_num % moe_topk,
                      0,
                      common::errors::InvalidArgument(
                          "The number of experts (expert_num) "
                          "must be divisible by moe_topk. "
                          "Got expert_num = %d and moe_topk = %d.",
                          expert_num,
                          moe_topk));
  }

  const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
  const int bytes = num_moe_inputs * sizeof(int);

  CubKeyValueSorter sorter_;
  sorter_.update_num_experts(expert_num);

  const int sorter_ws_size_bytes =
      AlignTo16(sorter_.getWorkspaceSize(moe_topk * num_rows));
  const int sort_tmp_in_out_size = num_moe_inputs * 2 * sizeof(int);

  // A single workspace buffer, carved into: source rows | sorter workspace |
  // sort input/output scratch.
  paddle::Tensor ws_ptr_tensor =
      GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
                     paddle::DataType::INT8,
                     place);

  int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();
  int *source_rows_ = reinterpret_cast<int *>(ws_ptr);
  int8_t *sorter_ws_ptr = reinterpret_cast<int8_t *>(ws_ptr + bytes);
  int *permuted_experts_ =
      reinterpret_cast<int *>(sorter_ws_ptr + sorter_ws_size_bytes);
  int *permuted_rows_ = permuted_experts_ + num_moe_inputs;

  int *topk_idx_ptr = topk_idx->data<int>();

  float *softmax_max_prob = nullptr;
  if (group_moe) {
    paddle::Tensor softmax_max_prob_tensor =
        GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
    // TODO: check fill success?
    paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
    softmax_max_prob = softmax_max_prob_tensor.data<float>();
  }

  float *softmax_out_;

  const bool is_pow_2 =
      (expert_num != 0) && ((expert_num & (expert_num - 1)) == 0);

  // A separate softmax buffer is only needed when the fused softmax/top-k
  // path cannot be used: non-power-of-two expert count, more than 256
  // experts, group mode, or a gating correction bias.
  paddle::Tensor softmax_buffer;
  if (!is_pow_2 || expert_num > 256 || group_moe || gating_correction_bias) {
    softmax_buffer = GetEmptyTensor(
        {num_rows * expert_num}, paddle::DataType::FLOAT32, place);
    softmax_out_ = softmax_buffer.data<float>();
  } else {
    softmax_out_ = nullptr;
  }
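  // Illustrative note (assumed semantics, following the usual MoE dispatch
  // pattern; not verified against the launcher's implementation): for token
  // row r and its k-th selected expert, the launcher below writes the expert
  // id into topk_idx_ptr and an "original position" key such as
  // k * num_rows + r into source_rows_, so the radix sort that follows can
  // group rows by expert while remembering where each row came from.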
  topk_gating_softmax_kernelLauncher<float>::run(
      gating_output.data<float>(),
      gating_correction_bias ? gating_correction_bias.get().data<float>()
                             : nullptr,
      topk_weight->data<float>(),
      softmax_out_,
      topk_idx_ptr,
      source_rows_,
      softmax_max_prob,
      num_rows,
      expert_num,
      moe_topk,
      group_moe,
      stream,
      topk_only_mode);

  // Sort (expert id, source row) pairs by expert id so that rows destined
  // for the same expert become contiguous.
  sorter_.run(reinterpret_cast<void *>(sorter_ws_ptr),
              sorter_ws_size_bytes,
              topk_idx_ptr,
              expert_idx_per_token->data<int>(),
              source_rows_,
              permuted_rows_,
              moe_topk * num_rows,
              false,
              stream);

  if (w4a8_in_scale) {
    if (permute_input->dtype() == paddle::DataType::INT8) {
      initialize_moe_routing_kernelLauncher<data_t, int8_t>::run(
          input.data<data_t>(),
          permute_input->data<int8_t>(),
          permuted_rows_,
          expert_idx_per_token->data<int>(),
          w4a8_in_scale->data<float>(),
          permute_indices_per_token->data<int>(),
          num_rows,
          num_rows,
          hidden_size,
          moe_topk,
          stream);
    } else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) {
      initialize_moe_routing_kernelLauncher<data_t,
                                            phi::dtype::float8_e4m3fn>::run(
          input.data<data_t>(),
          permute_input->data<phi::dtype::float8_e4m3fn>(),
          permuted_rows_,
          expert_idx_per_token->data<int>(),
          w4a8_in_scale->data<float>(),
          permute_indices_per_token->data<int>(),
          num_rows,
          num_rows,
          hidden_size,
          moe_topk,
          stream);
    }
  } else {
    initialize_moe_routing_kernelLauncher<data_t>::run(
        input.data<data_t>(),
        permute_input->data<data_t>(),
        permuted_rows_,
        expert_idx_per_token->data<int>(),
        nullptr,
        permute_indices_per_token->data<int>(),
        num_rows,
        num_rows,
        hidden_size,
        moe_topk,
        stream);
  }

  // Prefix sum of per-expert token counts, consumed by the grouped GEMM.
  compute_total_rows_before_expert(expert_idx_per_token->data<int>(),
                                   moe_topk * num_rows,
                                   expert_num,
                                   tokens_expert_prefix_sum->data<int64_t>(),
                                   stream);
}

std::vector<paddle::Tensor> MoeExpertDispatch(
    const paddle::Tensor &input,
    const paddle::Tensor &gating_output,
    const paddle::optional<paddle::Tensor> &gating_correction_bias,
    const paddle::optional<paddle::Tensor> &w4a8_in_scale,
    const int moe_topk,
    const bool group_moe,
    const std::string &moe_quant_type,
    const bool topk_only_mode) {
  const auto input_type = input.dtype();
  auto place = input.place();
  int token_rows = 0;
  auto input_dims = input.dims();
  auto gating_dims = gating_output.dims();
  const int expert_num = gating_dims[gating_dims.size() - 1];

  // A 3-D input [batch, seq_len, hidden] is flattened to
  // [batch * seq_len, hidden].
  if (input_dims.size() == 3) {
    token_rows = input_dims[0] * input_dims[1];
  } else {
    token_rows = input_dims[0];
  }
  const int num_rows = token_rows;
  const int hidden_size = input.dims()[input_dims.size() - 1];

  // For the w4a8 / w4afp8 paths the permuted activations are written out
  // already quantized; otherwise they keep the input dtype.
  auto permute_input_dtype = input_type;
  if (w4a8_in_scale) {
    if (moe_quant_type == "w4a8") {
      permute_input_dtype = paddle::DataType::INT8;
    } else if (moe_quant_type == "w4afp8") {
      permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN;
    }
  }

  auto permute_input = GetEmptyTensor(
      {moe_topk * num_rows, hidden_size}, permute_input_dtype, place);
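  // permute_input is the expert-major buffer filled by MoeDispatchKernel:
  // each token row appears moe_topk times, grouped so that rows routed to
  // the same expert are contiguous, which is the layout the grouped GEMM
  // downstream expects.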
  // topk_weight holds the weighting coefficients applied to each selected
  // expert's output when the results are recombined.
  auto topk_weight =
      GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
  auto topk_idx =
      GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT32, place);
  auto tokens_expert_prefix_sum =
      GetEmptyTensor({expert_num}, paddle::DataType::INT64, place);
  auto permute_indices_per_token =
      GetEmptyTensor({moe_topk, num_rows}, paddle::DataType::INT32, place);
  auto expert_idx_per_token =
      GetEmptyTensor({num_rows * moe_topk}, paddle::DataType::INT32, place);

  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      MoeDispatchKernel<paddle::DataType::BFLOAT16>(
          input,
          gating_output,
          gating_correction_bias,
          w4a8_in_scale,
          moe_topk,
          group_moe,
          topk_only_mode,
          num_rows,
          hidden_size,
          expert_num,
          &permute_input,
          &tokens_expert_prefix_sum,
          &permute_indices_per_token,
          &topk_weight,
          &topk_idx,
          &expert_idx_per_token);
      break;
    case paddle::DataType::FLOAT16:
      MoeDispatchKernel<paddle::DataType::FLOAT16>(
          input,
          gating_output,
          gating_correction_bias,
          w4a8_in_scale,
          moe_topk,
          group_moe,
          topk_only_mode,
          num_rows,
          hidden_size,
          expert_num,
          &permute_input,
          &tokens_expert_prefix_sum,
          &permute_indices_per_token,
          &topk_weight,
          &topk_idx,
          &expert_idx_per_token);
      break;
    default:
      PD_THROW("Unsupported data type for MoeDispatchKernel");
  }
  return {permute_input,
          tokens_expert_prefix_sum,
          permute_indices_per_token,
          topk_weight,
          topk_idx,
          expert_idx_per_token};
}

std::vector<std::vector<int64_t>> MoeExpertDispatchInferShape(
    const std::vector<int64_t> &input_shape,
    const std::vector<int64_t> &gating_output_shape,
    const paddle::optional<std::vector<int64_t>> &bias_shape,
    const int moe_topk) {
  int token_rows = -1;
  if (input_shape.size() == 3) {
    token_rows = input_shape[0] * input_shape[1];
  } else {
    token_rows = input_shape[0];
  }
  const int expert_num = gating_output_shape[gating_output_shape.size() - 1];
  const int num_rows = token_rows;
  const int hidden_size = input_shape[input_shape.size() - 1];
  const int permuted_rows = num_rows == -1 ? -1 : moe_topk * num_rows;

  return {{permuted_rows, hidden_size},
          {expert_num},
          {moe_topk, num_rows},
          {num_rows, moe_topk},
          {num_rows, moe_topk},
          {permuted_rows}};
}

std::vector<paddle::DataType> MoeExpertDispatchInferDtype(
    const paddle::DataType &input_dtype,
    const paddle::DataType &gating_output_dtype,
    const paddle::optional<paddle::DataType> &bias_type,
    const int moe_topk) {
  return {input_dtype,
          paddle::DataType::INT64,
          paddle::DataType::INT32,
          paddle::DataType::FLOAT32,
          paddle::DataType::INT32,
          paddle::DataType::INT32};
}

/**
 * @brief Mixture of Experts (MoE) Expert Dispatch Operator
 *
 * This operator performs the following key functions:
 * 1. Computes the top-k experts for each input token based on gating scores
 * 2. Permutes input tokens according to their selected experts for efficient
 *    expert processing
 * 3. Computes prefix sums of tokens per expert for the group_gemm
 *    optimization
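 *
 * Worked example (illustrative): with expert_num = 4 and per-expert token
 * counts {3, 0, 5, 2} after top-k routing, tokens_expert_prefix_sum is the
 * inclusive prefix sum {3, 3, 8, 10}; expert i then owns rows
 * [prefix_sum[i-1], prefix_sum[i]) of permute_input, with prefix_sum[-1]
 * taken as 0.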
 *
 * Inputs:
 * - input: The input tensor to be routed to experts
 *     Shape: [total_tokens, hidden_size]
 *     dtype: bfloat16 or float16
 * - gating_output: Gating network output scores for each token-expert pair
 *     Shape: [total_tokens, expert_num]
 *     dtype: must be float32
 * - gating_correction_bias: Optional bias term for gating correction
 *     Shape: [expert_num]
 * - w4a8_in_scale: Optional per-expert input quantization scale for the
 *     w4a8 / w4afp8 paths
 *
 * Outputs:
 * - permute_input: Permuted input tensor organized by expert
 *     Shape: [moe_topk * total_tokens, hidden_size]
 *     dtype: same as input (INT8 / FLOAT8_E4M3FN on the quantized paths)
 * - tokens_expert_prefix_sum: Prefix-sum array of token counts per expert,
 *     consumed by group_gemm
 *     Shape: [expert_num]
 *     dtype: int64
 * - permute_indices_per_token: Index mapping for reconstructing the original
 *     token order
 *     Shape: [moe_topk, total_tokens]
 *     dtype: int32
 * - topk_weight: Weight coefficients for combining expert outputs
 *     Shape: [total_tokens, moe_topk]
 *     dtype: float32
 * - topk_idx: Indices of the selected top-k experts for each token
 *     Shape: [total_tokens, moe_topk]
 *     dtype: int32
 * - expert_idx_per_token: Expert id assigned to each permuted row
 *     Shape: [moe_topk * total_tokens]
 *     dtype: int32
 *
 * Attributes:
 * - moe_topk: Number of experts to select for each token (the k in top-k
 *   routing)
 * - group_moe: Whether to perform group softmax within the operator
 *     (true: softmax is computed within groups of experts,
 *      false: standard softmax across all experts)
 * - moe_quant_type: Activation quantization mode for the permuted input
 *     ("w4a8" -> int8, "w4afp8" -> float8_e4m3fn; effective only when
 *     w4a8_in_scale is provided)
 * - topk_only_mode: Operation mode selector
 *     (true: only performs top-k selection without softmax,
 *      false: performs the full softmax + top-k computation)
 *
 * Notes:
 * - The operator accepts 2-D input [total_tokens, hidden_size]; a 3-D input
 *   [batch, seq_len, hidden_size] is flattened to 2-D internally
 * - For optimal performance, expert_num should be a power of 2 when
 *   possible, which enables the fused softmax/top-k path
 * - When group_moe is true, expert_num must be divisible by moe_topk
 */
PD_BUILD_STATIC_OP(moe_expert_dispatch)
    .Inputs({"input",
             "gating_output",
             paddle::Optional("gating_correction_bias"),
             paddle::Optional("w4a8_in_scale")})
    .Outputs({"permute_input",
              "tokens_expert_prefix_sum",
              "permute_indices_per_token",
              "topk_weight",
              "topk_idx",
              "expert_idx_per_token"})
    .Attrs({"moe_topk:int",
            "group_moe:bool",
            "moe_quant_type:std::string",
            "topk_only_mode:bool"})
    .SetKernelFn(PD_KERNEL(MoeExpertDispatch))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertDispatchInferDtype));
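
// Usage sketch (illustrative, not part of the op): once this library is
// loaded on the Python side, the registered op is exposed under the name
// "moe_expert_dispatch"; argument values below are hypothetical.
//
//   permute_input, tokens_expert_prefix_sum, permute_indices_per_token, \
//       topk_weight, topk_idx, expert_idx_per_token = moe_expert_dispatch(
//           input,          # [total_tokens, hidden_size], bfloat16/float16
//           gating_output,  # [total_tokens, expert_num], float32
//           None,           # gating_correction_bias
//           None,           # w4a8_in_scale
//           moe_topk=2,
//           group_moe=False,
//           moe_quant_type="",
//           topk_only_mode=False)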