// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cutlass/numeric_conversion.h"
#include "cutlass_kernels/w4a8_moe/cutlass_extensions/epilogue/epilogue_quant_helper.h"
#include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h"
#include "group_swiglu_with_masked.h"
#include "helper.h"
#include "moe/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
#include "w4afp8_gemm/w4afp8_gemm.h"

template <paddle::DataType T>
void MoeFFNKernel(const paddle::Tensor& permute_input,
                  const paddle::Tensor& tokens_expert_prefix_sum,
                  const paddle::Tensor& up_gate_proj_weight,
                  const paddle::Tensor& down_proj_weight,
                  const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
                  const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
                  const paddle::optional<paddle::Tensor>& down_proj_scale,
                  const paddle::optional<paddle::Tensor>& down_proj_in_scale,
                  const paddle::optional<paddle::Tensor>& expert_idx_per_token,
                  const std::string& quant_method,
                  paddle::Tensor ffn_out,
                  bool used_in_ep_low_latency,
                  const int estimate_total_token_nums) {
  using namespace phi;
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto quant_mode = cutlass::epilogue::QuantMode::PerChannelQuant;
  auto ffn_out_data = ffn_out.data<data_t>();
  auto place = permute_input.place();
  auto stream = permute_input.stream();

  auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, DataType_>();
  auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, uint8_t>();
  auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::uint4b_t>();
  auto w4a8_moe_gemm_runner = W4A8MoeGemmRunner<DataType_>();

  assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
  const int num_experts = up_gate_proj_weight.dims()[0];
  const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];

  assert(up_gate_proj_weight.dims().size() == 3);
  int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] /
                  hidden_size;

  constexpr size_t workspace_size = 1 * 1024 * 1024 * 1024;  // for nf4 stream-k
  Allocator* allocator = paddle::GetAllocator(place);
  Allocator::AllocationPtr workspace;
  if (quant_method == "weight_only_int4" || quant_method == "w4a8" ||
      quant_method == "w4afp8") {
    inter_dim = inter_dim * 2;
  }
  if (quant_method == "w4a8" || quant_method == "w4afp8") {
    workspace =
        allocator->Allocate(SizeOf(paddle::DataType::INT8) * workspace_size);
  }

  const int64_t inter_size = inter_dim;

  typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
  typedef typename traits_fp8::DataType DataType_fp8;
  typedef typename traits_fp8::data_t data_t_fp8;

  int num_experts_ = num_experts;
  int num_max_tokens_per_expert = 256;
  int expanded_active_expert_rows;
  paddle::Tensor fc1_out_tensor;
  if (permute_input.dims().size() == 3) {
    num_experts_ = permute_input.dims()[0];
    assert(num_experts == num_experts_);
    num_max_tokens_per_expert = permute_input.dims()[1];
    expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert;
    fc1_out_tensor = GetEmptyTensor(
        {num_experts_, num_max_tokens_per_expert, inter_size}, T, place);
  } else {
    expanded_active_expert_rows = permute_input.dims()[0];
    fc1_out_tensor = GetEmptyTensor(
        {expanded_active_expert_rows, inter_size}, T, place);
  }

  auto fc1_out = fc1_out_tensor.data<data_t>();
  using NvType = typename traits_::DataType;

  auto fc1_expert_biases =
      up_gate_proj_bias
          ? const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr())
                ->data<data_t>()
          : nullptr;

  // This is a trick.
  // expanded_active_expert_rows is not needed in variable group gemm,
  // but is needed to accommodate the deepep low-latency mode.
  const int64_t total_rows_in_ll_else_minus1 =
      used_in_ep_low_latency ? expanded_active_expert_rows : -1;
  // When we tune the optimal configuration, we need the actual total_rows.
  const int64_t tune_total_rows = expanded_active_expert_rows;

  if (quant_method == "weight_only_int8") {
    typename cutlass::WintQuantTraits<
        DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments
        quant_args;
    int8_moe_gemm_runner.moe_gemm_bias_act(
        reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
        reinterpret_cast<const uint8_t*>(up_gate_proj_weight.data<int8_t>()),
        reinterpret_cast<const NvType*>(
            const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
                ->data<data_t>()),
        reinterpret_cast<const NvType*>(fc1_expert_biases),
        reinterpret_cast<NvType*>(fc1_out),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        total_rows_in_ll_else_minus1,
        tune_total_rows,
        inter_size,
        hidden_size,
        num_experts,
        quant_args,
        "none",
        stream);
  } else if (quant_method == "weight_only_int4") {
    typename cutlass::WintQuantTraits<
        DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments
        quant_args;
    int4_moe_gemm_runner.moe_gemm_bias_act(
        reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
        reinterpret_cast<const cutlass::uint4b_t*>(
            up_gate_proj_weight.data<int8_t>()),
        reinterpret_cast<const NvType*>(
            const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
                ->data<data_t>()),
        reinterpret_cast<const NvType*>(fc1_expert_biases),
        reinterpret_cast<NvType*>(fc1_out),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        total_rows_in_ll_else_minus1,
        tune_total_rows,
        inter_size,
        hidden_size,
        num_experts,
        quant_args,
        "none",
        stream);
  } else if (quant_method == "w4a8") {
    w4a8_moe_gemm_runner.moe_gemm(
        reinterpret_cast<const int8_t*>(permute_input.data<int8_t>()),
        reinterpret_cast<const cutlass::uint4b_t*>(
            up_gate_proj_weight.data<int8_t>()),
        quant_mode,
        reinterpret_cast<const NvType*>(
            const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
                ->data<data_t>()),
        nullptr,  // up_gate_proj_scale_dyquant
        nullptr,  // nf4_look_up_table
        reinterpret_cast<NvType*>(fc1_out),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        total_rows_in_ll_else_minus1,
        used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
        inter_size,
        hidden_size,
        reinterpret_cast<int8_t*>(workspace->ptr()),
        workspace_size,
        num_experts,
        stream);
  } else if (quant_method == "w4afp8") {
    typedef PDTraits<paddle::DataType::FLOAT8_E4M3FN> traits_fp8;
    typedef typename traits_fp8::DataType DataType_fp8;
    typedef typename traits_fp8::data_t data_t_fp8;

    Allocator::AllocationPtr ffn1_input_row_sum;
    ffn1_input_row_sum =
        allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
    compute_row_sum(
        permute_input.data<data_t_fp8>(),
        expanded_active_expert_rows,
        hidden_size,
        reinterpret_cast<float*>(ffn1_input_row_sum->ptr()),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        num_max_tokens_per_expert,
        used_in_ep_low_latency,
        stream);

    float* row_scale = nullptr;
    DisPatchW4AFp8GemmWrapper(
        reinterpret_cast<const DataType_fp8*>(permute_input.data<data_t_fp8>()),
        reinterpret_cast<const int8_t*>(up_gate_proj_weight.data<int8_t>()),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        reinterpret_cast<const float*>(ffn1_input_row_sum->ptr()),
        row_scale,
        const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr())
            ->data<float>(),
        reinterpret_cast<NvType*>(fc1_out),
        used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
        used_in_ep_low_latency ? num_max_tokens_per_expert
                               : permute_input.dims()[0],
        num_experts,
        inter_size,
        hidden_size,
        stream);
  } else {
    typename cutlass::WintQuantTraits<
        DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
    fp16_moe_gemm_runner.moe_gemm_bias_act(
        reinterpret_cast<const NvType*>(permute_input.data<data_t>()),
        reinterpret_cast<const NvType*>(up_gate_proj_weight.data<data_t>()),
        nullptr,
        reinterpret_cast<const NvType*>(fc1_expert_biases),
        reinterpret_cast<NvType*>(fc1_out),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        total_rows_in_ll_else_minus1,
        tune_total_rows,
        inter_size,
        hidden_size,
        num_experts,
        quant_args,
        "none",
        stream);
  }

  paddle::Tensor act_out_tensor;
  if (used_in_ep_low_latency) {
    act_out_tensor =
        GroupSwigluWithMasked(fc1_out_tensor, tokens_expert_prefix_sum);
  } else {
    act_out_tensor = paddle::experimental::swiglu(fc1_out_tensor, nullptr);
  }
  auto act_out = act_out_tensor.data<data_t>();

  if (quant_method == "weight_only_int8") {
    typename cutlass::WintQuantTraits<
        DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>::Arguments
        quant_args;
    int8_moe_gemm_runner.moe_gemm(
        reinterpret_cast<const NvType*>(act_out),
        reinterpret_cast<const uint8_t*>(down_proj_weight.data<int8_t>()),
        reinterpret_cast<const NvType*>(
            const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
                ->data<data_t>()),
        reinterpret_cast<NvType*>(ffn_out_data),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        total_rows_in_ll_else_minus1,
        tune_total_rows,
        hidden_size,
        inter_size / 2,
        num_experts,
        quant_args,
        stream);
  } else if (quant_method == "weight_only_int4") {
    typename cutlass::WintQuantTraits<
        DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>::Arguments
        quant_args;
    int4_moe_gemm_runner.moe_gemm(
        reinterpret_cast<const NvType*>(act_out),
        reinterpret_cast<const cutlass::uint4b_t*>(
            down_proj_weight.data<int8_t>()),
        reinterpret_cast<const NvType*>(
            const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
                ->data<data_t>()),
        reinterpret_cast<NvType*>(ffn_out_data),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        total_rows_in_ll_else_minus1,
        tune_total_rows,
        hidden_size,
        inter_size / 2,
        num_experts,
        quant_args,
        stream);
  } else if (quant_method == "w4a8") {
    data_t* down_proj_shift = nullptr;
    data_t* down_proj_smooth = nullptr;
    Allocator::AllocationPtr int8_act_out;
    int8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
                                       act_out_tensor.numel());
    MoeFastHardamardWrapper(
        act_out_tensor.data<data_t>(),
        expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
                             : nullptr,
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        down_proj_shift,   // down_proj_shift->data<data_t>(),
        down_proj_smooth,  // down_proj_smooth->data<data_t>(),
        down_proj_in_scale
            ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
                  ->data<float>()
            : nullptr,
        1,
        127.0,
        -127.0,
        expanded_active_expert_rows,
        inter_size / 2,
        num_max_tokens_per_expert,
        used_in_ep_low_latency,
        reinterpret_cast<int8_t*>(int8_act_out->ptr()),
        stream);
    w4a8_moe_gemm_runner.moe_gemm(
        reinterpret_cast<const int8_t*>(int8_act_out->ptr()),
        reinterpret_cast<const cutlass::uint4b_t*>(
            down_proj_weight.data<int8_t>()),
        quant_mode,
        reinterpret_cast<const NvType*>(
            const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())
                ->data<data_t>()),
        nullptr,  // down_proj_scale_dyquant
        nullptr,  // reinterpret_cast<const float*>(d_nf4_look_up_table), // nf4_look_up_table
        reinterpret_cast<NvType*>(ffn_out_data),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        total_rows_in_ll_else_minus1,
        used_in_ep_low_latency ? estimate_total_token_nums : tune_total_rows,
        hidden_size,
        inter_size / 2,
        reinterpret_cast<int8_t*>(workspace->ptr()),
        workspace_size,
        num_experts,
        stream);
  } else if (quant_method == "w4afp8") {
    data_t* ffn2_shift = nullptr;
    data_t* ffn2_smooth = nullptr;
    float* row_scale = nullptr;
    Allocator::AllocationPtr fp8_act_out;
    fp8_act_out = allocator->Allocate(SizeOf(paddle::DataType::INT8) *
                                      act_out_tensor.numel());
    Allocator::AllocationPtr ffn2_input_row_sum;
    ffn2_input_row_sum =
        allocator->Allocate(sizeof(float) * expanded_active_expert_rows);
    // note(yuanxiaolan): optimize this
    MoeFastHardamardWrapper(
        act_out_tensor.data<data_t>(),
        expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
                             : nullptr,
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        ffn2_shift,   // ffn2_shift->data<data_t>(),
        ffn2_smooth,  // ffn2_smooth->data<data_t>(),
        nullptr,
        1,
        448.0f,
        -448.0f,
        expanded_active_expert_rows,
        inter_size / 2,
        num_max_tokens_per_expert,
        used_in_ep_low_latency,
        act_out_tensor.data<data_t>(),
        stream);
    quantize_moe_input(
        act_out_tensor.data<data_t>(),
        expert_idx_per_token ? expert_idx_per_token.get().data<int64_t>()
                             : nullptr,
        down_proj_in_scale
            ? const_cast<paddle::Tensor*>(down_proj_in_scale.get_ptr())
                  ->data<float>()
            : nullptr,
        448.0f,
        -448.0f,
        expanded_active_expert_rows,
        inter_size / 2,
        reinterpret_cast<float*>(ffn2_input_row_sum->ptr()),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        num_max_tokens_per_expert,
        used_in_ep_low_latency,
        reinterpret_cast<data_t_fp8*>(fp8_act_out->ptr()),
        stream);
    DisPatchW4AFp8GemmWrapper(
        reinterpret_cast<const DataType_fp8*>(fp8_act_out->ptr()),
        reinterpret_cast<const int8_t*>(down_proj_weight.data<int8_t>()),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        reinterpret_cast<const float*>(ffn2_input_row_sum->ptr()),
        row_scale,
        const_cast<paddle::Tensor*>(down_proj_scale.get_ptr())->data<float>(),
        reinterpret_cast<NvType*>(ffn_out_data),
        used_in_ep_low_latency ? num_max_tokens_per_expert : 0,
        used_in_ep_low_latency ? num_max_tokens_per_expert
                               : act_out_tensor.dims()[0],
        num_experts,
        hidden_size,
        inter_size / 2,
        stream);
  } else {
    typename cutlass::WintQuantTraits<
        DataType_, cutlass::WintQuantMethod::kNone>::Arguments quant_args;
    fp16_moe_gemm_runner.moe_gemm(
        reinterpret_cast<const NvType*>(act_out),
        reinterpret_cast<const NvType*>(down_proj_weight.data<data_t>()),
        nullptr,
        reinterpret_cast<NvType*>(ffn_out_data),
        const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
        total_rows_in_ll_else_minus1,
        tune_total_rows,
        hidden_size,
        inter_size / 2,
        num_experts,
        quant_args,
        stream);
  }
}
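// Illustrative sketch (not used by the kernel above): how tokens_expert_prefix_sum
// relates to the per-expert row ranges consumed by the grouped GEMMs. Assuming an
// inclusive prefix sum over per-expert token counts (as the [num_experts] int64
// shape documented below suggests), expert e owns rows [prefix[e - 1], prefix[e])
// of the permuted activations. Variable names here are hypothetical.
//
//   #include <numeric>
//   std::vector<int64_t> counts = {3, 0, 5};  // tokens routed to experts 0..2
//   std::vector<int64_t> prefix(counts.size());
//   std::partial_sum(counts.begin(), counts.end(), prefix.begin());
//   // prefix == {3, 3, 8}: expert 0 -> rows [0, 3), expert 1 -> [3, 3) (empty),
//   // expert 2 -> rows [3, 8); total rows == prefix.back() == 8.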

paddle::Tensor MoeExpertFFNFunc(
    const paddle::Tensor& permute_input,
    const paddle::Tensor& tokens_expert_prefix_sum,
    const paddle::Tensor& up_gate_proj_weight,
    const paddle::Tensor& down_proj_weight,
    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_in_scale,
    const paddle::optional<paddle::Tensor>& expert_idx_per_token,
    const std::string& quant_method,
    const bool used_in_ep_low_latency,
    const int estimate_total_token_nums) {
  const auto t_type = (quant_method == "w4a8")
                          ? up_gate_proj_scale.get().dtype()
                          : (quant_method == "w4afp8")
                                ? paddle::DataType::BFLOAT16
                                : permute_input.dtype();
  auto ffn_out = paddle::empty_like(permute_input, t_type);

  switch (t_type) {
    case paddle::DataType::BFLOAT16:
      MoeFFNKernel<paddle::DataType::BFLOAT16>(permute_input,
                                               tokens_expert_prefix_sum,
                                               up_gate_proj_weight,
                                               down_proj_weight,
                                               up_gate_proj_bias,
                                               up_gate_proj_scale,
                                               down_proj_scale,
                                               down_proj_in_scale,
                                               expert_idx_per_token,
                                               quant_method,
                                               ffn_out,
                                               used_in_ep_low_latency,
                                               estimate_total_token_nums);
      break;
    case paddle::DataType::FLOAT16:
      MoeFFNKernel<paddle::DataType::FLOAT16>(permute_input,
                                              tokens_expert_prefix_sum,
                                              up_gate_proj_weight,
                                              down_proj_weight,
                                              up_gate_proj_bias,
                                              up_gate_proj_scale,
                                              down_proj_scale,
                                              down_proj_in_scale,
                                              expert_idx_per_token,
                                              quant_method,
                                              ffn_out,
                                              used_in_ep_low_latency,
                                              estimate_total_token_nums);
      break;
    default:
      PD_THROW("Unsupported data type for MoeExpertFFN");
  }
  return ffn_out;
}

std::vector<paddle::Tensor> MoeExpertFFN(
    const paddle::Tensor& permute_input,
    const paddle::Tensor& tokens_expert_prefix_sum,
    const paddle::Tensor& up_gate_proj_weight,
    const paddle::Tensor& down_proj_weight,
    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_in_scale,
    const paddle::optional<paddle::Tensor>& expert_idx_per_token,
    const std::string& quant_method,
    const bool used_in_ep_low_latency,
    const int estimate_total_token_nums) {
  return {MoeExpertFFNFunc(permute_input,
                           tokens_expert_prefix_sum,
                           up_gate_proj_weight,
                           down_proj_weight,
                           up_gate_proj_bias,
                           up_gate_proj_scale,
                           down_proj_scale,
                           down_proj_in_scale,
                           expert_idx_per_token,
                           quant_method,
                           used_in_ep_low_latency,
                           estimate_total_token_nums)};
}

std::vector<std::vector<int64_t>> MoeExpertFFNInferShape(
    const std::vector<int64_t>& permute_input_shape,
    const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
    const std::vector<int64_t>& up_gate_proj_weight_shape,
    const std::vector<int64_t>& down_proj_weight_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_in_scale_shape,
    const paddle::optional<std::vector<int64_t>>& expert_idx_per_token_shape,
    const std::string& quant_method,
    const bool used_in_ep_low_latency,
    const int estimate_total_token_nums) {
  return {permute_input_shape};
}

std::vector<paddle::DataType> MoeExpertFFNInferDtype(
    const paddle::DataType& permute_input_dtype,
    const paddle::DataType& tokens_expert_prefix_sum_dtype,
    const paddle::DataType& up_gate_proj_weight_dtype,
    const paddle::DataType& down_proj_weight_dtype,
    const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
    const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
    const paddle::optional<paddle::DataType>& down_proj_scale_dtype,
    const paddle::optional<paddle::DataType>& down_proj_in_scale_dtype,
    const std::string& quant_method,
    const bool used_in_ep_low_latency,
    const int estimate_total_token_nums) {
  if (quant_method == "w4a8" || quant_method == "w4afp8") {
    return {up_gate_proj_scale_dtype.get()};
  } else {
    return {permute_input_dtype};
  }
}

/**
 * @brief Mixture of Experts (MoE) Feed-Forward Network Operator
 *
 * This operator performs the expert computation in the MoE architecture,
 * including:
 * 1. First linear transformation (up_gate_proj) with optional quantization
 * 2. SwiGLU activation function
 * 3. Second linear transformation (down_proj) with optional quantization
 *
 * Supports multiple quantization methods, including weight-only int4/int8,
 * w4a8, and w4afp8 quantization.
 *
 * Inputs:
 * - permute_input: Permuted input tensor organized by expert
 *     Shape: [total_tokens * top_k, hidden_size]
 *     dtype: bfloat16/float16 (int8 for w4a8, float8_e4m3fn for w4afp8)
 * - tokens_expert_prefix_sum: Prefix sum array of token counts per expert,
 *     used by the grouped GEMM
 *     Shape: [num_experts]
 *     dtype: int64
 * - up_gate_proj_weight: First FFN layer weights
 *     Shape: [num_experts, inter_size * 2, hidden_size]
 *     dtype: Same as input (unquantized) or int8 (quantized)
 * - down_proj_weight: Second FFN layer weights
 *     Shape: [num_experts, hidden_size, inter_size]
 *     dtype: Same as input (unquantized) or int8 (quantized)
 * - up_gate_proj_bias: Optional bias for first FFN layer
 *     Shape: [num_experts, inter_size * 2]
 *     dtype: Same as input
 * - up_gate_proj_scale: Quantization scales for first FFN layer
 *     Shape: [num_experts, inter_size * 2]
 *     dtype: Same as input
 * - down_proj_scale: Quantization scales for second FFN layer
 *     Shape: [num_experts, hidden_size]
 *     dtype: Same as input
 * - down_proj_in_scale: Optional input scales for second FFN layer
 *     (w4a8/w4afp8 only)
 *     dtype: float32
 * - expert_idx_per_token: Optional expert indices per token (w4a8/w4afp8 only)
 *     Shape: [total_tokens]
 *     dtype: int64
 *
 * Outputs:
 * - output_tensor: Output tensor after MoE FFN computation
 *     Shape: Same as permute_input
 *     dtype: Same as input (up_gate_proj_scale dtype for w4a8,
 *     bfloat16 for w4afp8)
 *
 * Attributes:
 * - quant_method: Quantization method to use
 *     Options: "none", "weight_only_int4", "weight_only_int8", "w4a8",
 *     "w4afp8"
 * - used_in_ep_low_latency: Whether running in low latency mode
 *     Affects activation function implementation
 * - estimate_total_token_nums: Estimated total token count; in low latency
 *     mode it replaces the measured row count passed to the w4a8 grouped GEMM
 *
 * Note:
 * - w4a8/w4afp8 modes require additional workspace memory allocation
 * - Low latency mode uses a specialized grouped SwiGLU implementation
 */
PD_BUILD_STATIC_OP(moe_expert_ffn)
    .Inputs({"permute_input",
             "tokens_expert_prefix_sum",
             "up_gate_proj_weight",
             "down_proj_weight",
             paddle::Optional("up_gate_proj_bias"),
             paddle::Optional("up_gate_proj_scale"),
             paddle::Optional("down_proj_scale"),
             paddle::Optional("down_proj_in_scale"),
             paddle::Optional("expert_idx_per_token")})
    .Outputs({"output_tensor"})
    .Attrs({"quant_method:std::string",
            "used_in_ep_low_latency:bool",
            "estimate_total_token_nums:int"})
    .SetKernelFn(PD_KERNEL(MoeExpertFFN))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNInferDtype));
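
// Usage sketch (illustrative only): calling the C++ entry point for the
// unquantized path. The tensor contents, shapes, and the way permute_input and
// tokens_expert_prefix_sum are produced (normally by an upstream dispatch /
// permute op) are assumptions based on the documentation above, not part of
// this file.
//
//   // hidden_size = 1024, inter_size = 4096, num_experts = 8
//   // permute_input:            [rows, 1024]     bfloat16
//   // tokens_expert_prefix_sum: [8]              int64
//   // up_gate_proj_weight:      [8, 8192, 1024]  bfloat16
//   // down_proj_weight:         [8, 1024, 4096]  bfloat16
//   paddle::optional<paddle::Tensor> none;
//   auto out = MoeExpertFFNFunc(permute_input,
//                               tokens_expert_prefix_sum,
//                               up_gate_proj_weight,
//                               down_proj_weight,
//                               none, none, none, none, none,
//                               "none",
//                               /*used_in_ep_low_latency=*/false,
//                               /*estimate_total_token_nums=*/0);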