Files
FastDeploy/custom_ops/gpu_ops/moe/moe_ffn_wint2.cu
Yiqun Liu 8f426c1690 Optimize the performance of moe_expert_ffn_wint2 (#2990)
* Change wint2 to ColumnMajor.

Change-Id: I6b44d02946a685f8fe24d9f2c7be258b51e16da2

* Unify default_wint2x_mma.

Change-Id: I9e77b0e8e6cecab01fedc0b24b536ee0a1a89ff7

* Change wint2 to ColumnMajorTileInterleave.

Change-Id: I593cbe36f991c0c5044989d65f0014087587c624

* Enable async copy for B.

Change-Id: Ia3ac37ad162a8cf3ccce4f268e81bd06c8ac3c46

* Add wint2x Dequantizer

* Remove TileDequanterB related codes.

Change-Id: Id8e65703b72a8984d367f584ff41b7726017fbb8

* Implement FastInterleavedAndBiasedNumericArrayConverter for wint2.

Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca

* Implement Wint2ParamsAccessor to load extra quant params from global memory.

Change-Id: Ic3750cd9b767df8893501820880c3342a4b47233

* Implement FastInterleavedAndBiasedNumericArrayConverter for wint2.

Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca

* Use async copy for local_scale.

Change-Id: Ib882ba41c3d2354bda4d25b40e2408ad3b2f7893

* Check and correct the load and dequantize of weights.

Change-Id: Ie8dca505b39987144964fe6407d465b3b5953790

* Change for performance tuning.

Change-Id: I1da026fb1d1533a9d70350c7ba23c27e896cfc29

* Optimize the global memory access size of local_scale reading.

Change-Id: I4cbe3a2ef5951723d415c2d3252ce912394beaf5

* Specialize mma_tensor_op for wint2 to enable fine-grained pipeline.

Change-Id: Icbb4d48f90a41136f42d6ffff42d68de32f408da

* Minor fix.

Change-Id: I14d4ac9d267ee05442a3b47f00c26bee13d79e6f

* optimizing dequant performance with LOP3

* optimizing dequant performance with LOP3

* Avoid redundant dequantization of local_scale and use bf16 as computing type.

Change-Id: I63239ebc8f8e4a92d6281af59840ba50600b4334

* Add Multiplier and remove some logs.

Change-Id: Ifa199d81e6aeb472d2247c63f85ef30213684bcd

* optimizing dequant performance with LOP3

* Use __byte_perm to implement int8 to float32 conversion for performance improvement.

* Use lop3 to optimize the dequantize of local_scale.

Change-Id: I6189759970cb5b8dcbef769724784b8a7533b63c

* Minor fix and remove some logs.

Change-Id: I6279ba9926d5041093b1c6aea200acf2e4c49d46

* Fix stages for test.

Change-Id: I6f7b7cac612ef2c678e9d49f5ffa60eb53d3ae29

* Fix stages for test and add clock64 to profile.

Change-Id: Iffaf7324beaa910ce9ee56f47ae289de98f1a267

* Use __byte_perm to replace shift-and-or operations for faster integer merging.

* Split the uint2b convert.

Change-Id: I78da672ce8968e21f685285140ba546a161521b4

* Optimize convert of unscale.

Change-Id: I6795da1cdf5e8ab38ddaa9836240921b5312913a

* Minor optimization.

Change-Id: I1800aec34c3f4621abb02658208108f54da44d88

* Optimize mma pipeline and refine codes.

Change-Id: Id3075cf7b88f2813a11ccd1d3b49c62c978f36b8

* Add missing support.

Change-Id: Id65b7bc2c25fbb1a5b232c6bc9fb8c9093f691a8

* Accelerate FP16 dequantization performance

* Support tile shape as Xx64x64.

Change-Id: Ib8fd37e1ba1d06f7d11f2956e7f1367b0a92bcac

* Remove debugging codes and minor optimization.

Change-Id: I6b79bd56a6e8dd823efc169967ecd3cc9a43baf4

* Fix offset bug.

Change-Id: Id7aeb91e99d6f51836f2aff22187b4f79607395e

* Fix typo.

Change-Id: I19dde93fc1c1f7e19605905c90dc46298e203952

* Restore some codes and remove some debugging logs.

Change-Id: I8d44daf82ad1c6f8174134d195e7b3fe9a3afdfb

---------

Co-authored-by: baoqiwen <baoqiwen@baidu.com>
2025-07-28 10:32:43 +08:00

379 lines
19 KiB
Plaintext

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "cutlass/numeric_conversion.h"
#include "group_swiglu_with_masked.h"
#include "helper.h"
#include "moe/fast_hardamard_kernel.h"
#include "moe/fused_moe_helper.h"
// Runs the two weight-only-quantized grouped GEMMs of a MoE FFN:
//   1. up_gate_proj GEMM (optionally with bias) into fc1_out,
//   2. SwiGLU activation (grouped/masked variant in EP low-latency mode),
//   3. down_proj GEMM into ffn_out.
//
// Template parameters:
//   DataT        - paddle-visible element type of the activations.
//   NvType       - CUDA-native element type corresponding to DataT.
//   WeightSavedT - storage element type of the packed weight tensors
//                  (uint8_t for the wint2 path; see the caller).
//   QuantMethod  - cutlass weight-only quantization method selector.
//
// The *_local_scale / *_code_scale / *_code_zp tensors are dereferenced only
// when QuantMethod == kWeightOnlyInt2, so callers of other quant methods may
// pass null pointers for them.
template <typename DataT, typename NvType, typename WeightSavedT, cutlass::WintQuantMethod QuantMethod>
void WeightOnlyMoeFFNKernel(const paddle::Tensor& permute_input,
                            const paddle::Tensor& tokens_expert_prefix_sum,
                            const paddle::Tensor& up_gate_proj_weight,
                            const paddle::Tensor& down_proj_weight,
                            const paddle::Tensor* up_gate_proj_bias,
                            const paddle::Tensor* up_gate_proj_super_scale,
                            const paddle::Tensor* down_proj_super_scale,
                            const paddle::Tensor* up_gate_proj_local_scale,
                            const paddle::Tensor* up_gate_proj_code_scale,
                            const paddle::Tensor* up_gate_proj_code_zp,
                            const paddle::Tensor* down_proj_local_scale,
                            const paddle::Tensor* down_proj_code_scale,
                            const paddle::Tensor* down_proj_code_zp,
                            paddle::Tensor fc1_out,
                            paddle::Tensor ffn_out,
                            const int64_t total_rows_in_ll_else_minus1,
                            const int64_t actual_total_rows,
                            const int64_t inter_size,
                            const int64_t hidden_size,
                            const int num_experts,
                            bool used_in_ep_low_latency) {
  using namespace phi;
  using WeightOnlyTraits = cutlass::WintQuantTraits<NvType, QuantMethod>;
  using WeightType = typename WeightOnlyTraits::WeightType;

  // Extra quantization parameters exist only for the 2-bit path; for other
  // quant methods the default-constructed Arguments are used as-is.
  typename WeightOnlyTraits::Arguments up_gate_proj_quant_args;
  typename WeightOnlyTraits::Arguments down_proj_quant_args;
  if constexpr (QuantMethod == cutlass::WintQuantMethod::kWeightOnlyInt2) {
    // const_cast is required because the runner's Arguments hold non-const
    // pointers, but these tensors are never written here.
    up_gate_proj_quant_args.local_scale_ptr = const_cast<uint8_t*>(up_gate_proj_local_scale->data<uint8_t>());
    up_gate_proj_quant_args.code_scale_ptr = const_cast<float*>(up_gate_proj_code_scale->data<float>());
    up_gate_proj_quant_args.code_zp_ptr = const_cast<float*>(up_gate_proj_code_zp->data<float>());
    down_proj_quant_args.local_scale_ptr = const_cast<uint8_t*>(down_proj_local_scale->data<uint8_t>());
    down_proj_quant_args.code_scale_ptr = const_cast<float*>(down_proj_code_scale->data<float>());
    down_proj_quant_args.code_zp_ptr = const_cast<float*>(down_proj_code_zp->data<float>());
  }

  auto moe_gemm_runner = MoeGemmRunner<NvType, WeightOnlyTraits>();
  auto stream = permute_input.stream();

  // First GEMM: [rows, hidden_size] x W1 -> [rows, inter_size], with optional
  // bias and no fused activation ("none"); activation is applied separately.
  moe_gemm_runner.moe_gemm_bias_act(
      reinterpret_cast<const NvType*>(permute_input.data<DataT>()),
      reinterpret_cast<const WeightType*>(up_gate_proj_weight.data<WeightSavedT>()),
      reinterpret_cast<const NvType*>(up_gate_proj_super_scale ? up_gate_proj_super_scale->data<DataT>() : nullptr),
      reinterpret_cast<const NvType*>(up_gate_proj_bias ? up_gate_proj_bias->data<DataT>() : nullptr),
      reinterpret_cast<NvType*>(fc1_out.data<DataT>()),
      const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
      total_rows_in_ll_else_minus1,
      actual_total_rows,
      inter_size,
      hidden_size,
      num_experts,
      up_gate_proj_quant_args,
      "none",
      stream);

  // SwiGLU halves the inner dimension: [rows, inter_size] -> [rows, inter_size/2].
  // The low-latency EP path uses the grouped/masked variant, which handles
  // per-expert padding in the 3-D token layout.
  paddle::Tensor act_out;
  if (used_in_ep_low_latency) {
    act_out = GroupSwigluWithMasked(fc1_out, tokens_expert_prefix_sum);
  } else {
    act_out = paddle::experimental::swiglu(fc1_out, nullptr);
  }

  // Second GEMM: [rows, inter_size/2] x W2 -> [rows, hidden_size].
  moe_gemm_runner.moe_gemm(
      reinterpret_cast<const NvType*>(act_out.data<DataT>()),
      reinterpret_cast<const WeightType*>(down_proj_weight.data<WeightSavedT>()),
      reinterpret_cast<const NvType*>(down_proj_super_scale ? down_proj_super_scale->data<DataT>() : nullptr),
      reinterpret_cast<NvType*>(ffn_out.data<DataT>()),
      const_cast<int64_t*>(tokens_expert_prefix_sum.data<int64_t>()),
      total_rows_in_ll_else_minus1,
      actual_total_rows,
      hidden_size,
      inter_size / 2,
      num_experts,
      down_proj_quant_args,
      stream);
}
// Dtype-templated entry for the wint2 MoE FFN. Derives the problem sizes
// from the input/weight tensor shapes, allocates the intermediate fc1
// output, and forwards everything to WeightOnlyMoeFFNKernel with the
// kWeightOnlyInt2 quantization method.
//
// permute_input may be:
//   - 3-D [num_experts, num_max_tokens_per_expert, hidden_size]
//     (EP low-latency layout, padded per expert), or
//   - 2-D [expanded_active_expert_rows, hidden_size].
template <paddle::DataType T>
void MoeFFNWint2Kernel(const paddle::Tensor& permute_input,
                       const paddle::Tensor& tokens_expert_prefix_sum,
                       const paddle::Tensor& up_gate_proj_weight,
                       const paddle::Tensor& down_proj_weight,
                       const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
                       const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
                       const paddle::optional<paddle::Tensor>& down_proj_scale,
                       const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
                       const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
                       const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
                       const paddle::optional<paddle::Tensor>& down_proj_local_scale,
                       const paddle::optional<paddle::Tensor>& down_proj_code_scale,
                       const paddle::optional<paddle::Tensor>& down_proj_code_zp,
                       paddle::Tensor ffn_out,
                       bool used_in_ep_low_latency) {
  using namespace phi;
  using data_t = typename PDTraits<T>::data_t;
  using NvType = typename PDTraits<T>::DataType;

  auto place = permute_input.place();
  // NOTE: assert() is compiled out in release builds; these only guard
  // debug builds.
  assert(permute_input.dims().size() == 3 || permute_input.dims().size() == 2);
  assert(up_gate_proj_weight.dims().size() == 3);

  const int num_experts = up_gate_proj_weight.dims()[0];
  const int hidden_size = permute_input.dims()[permute_input.dims().size() - 1];
  // inter_dim is the packed inner dimension of the stored weights; the *4
  // factor presumably recovers the logical size from the 2-bit packing
  // (4 weights per stored byte) -- confirm against the wint2 weight layout.
  int inter_dim = up_gate_proj_weight.dims()[1] * up_gate_proj_weight.dims()[2] / hidden_size;
  const int64_t inter_size = inter_dim * 4;

  int num_experts_ = num_experts;
  int num_max_tokens_per_expert = 0;
  int expanded_active_expert_rows = 0;
  paddle::Tensor fc1_out_tensor;
  if (permute_input.dims().size() == 3) {
    // Low-latency layout: leading dim is experts, second is padded tokens.
    num_experts_ = permute_input.dims()[0];
    assert(num_experts == num_experts_);
    num_max_tokens_per_expert = permute_input.dims()[1];
    expanded_active_expert_rows = num_experts_ * num_max_tokens_per_expert;
    fc1_out_tensor = GetEmptyTensor(
        {num_experts_, num_max_tokens_per_expert, inter_size}, T, place);
  } else {
    expanded_active_expert_rows = permute_input.dims()[0];
    fc1_out_tensor = GetEmptyTensor(
        {expanded_active_expert_rows, inter_size}, T, place);
  }

  // This is a trick.
  // expanded_active_expert_rows is not needed in variable group gemm.
  // but is needed in accommodating deepep low latency mode
  const int64_t total_rows_in_ll_else_minus1 = used_in_ep_low_latency ? expanded_active_expert_rows : -1;
  // When we tune the optimal configuration, we need the actual total_rows.
  const int64_t actual_total_rows = expanded_active_expert_rows;

  // Optional tensors are unwrapped to raw Tensor pointers (nullptr when
  // absent); the wint2 path requires the local/code scale tensors to be set.
  WeightOnlyMoeFFNKernel<data_t, NvType, uint8_t, cutlass::WintQuantMethod::kWeightOnlyInt2>(
      permute_input,
      tokens_expert_prefix_sum,
      up_gate_proj_weight,
      down_proj_weight,
      const_cast<paddle::Tensor*>(up_gate_proj_bias.get_ptr()),
      const_cast<paddle::Tensor*>(up_gate_proj_scale.get_ptr()),
      const_cast<paddle::Tensor*>(down_proj_scale.get_ptr()),
      const_cast<paddle::Tensor*>(up_gate_proj_local_scale.get_ptr()),
      const_cast<paddle::Tensor*>(up_gate_proj_code_scale.get_ptr()),
      const_cast<paddle::Tensor*>(up_gate_proj_code_zp.get_ptr()),
      const_cast<paddle::Tensor*>(down_proj_local_scale.get_ptr()),
      const_cast<paddle::Tensor*>(down_proj_code_scale.get_ptr()),
      const_cast<paddle::Tensor*>(down_proj_code_zp.get_ptr()),
      fc1_out_tensor,
      ffn_out,
      total_rows_in_ll_else_minus1,
      actual_total_rows,
      inter_size,
      hidden_size,
      num_experts,
      used_in_ep_low_latency);
}
// Dispatches the wint2 MoE FFN to the kernel instantiation matching the
// activation dtype. Supports bfloat16 and float16 activations; any other
// dtype raises an error. Returns the FFN output, allocated with the same
// shape and dtype as permute_input.
paddle::Tensor MoeExpertFFNWint2Func(
    const paddle::Tensor& permute_input,
    const paddle::Tensor& tokens_expert_prefix_sum,
    const paddle::Tensor& up_gate_proj_weight,
    const paddle::Tensor& down_proj_weight,
    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
    const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
    const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
    const paddle::optional<paddle::Tensor>& down_proj_local_scale,
    const paddle::optional<paddle::Tensor>& down_proj_code_scale,
    const paddle::optional<paddle::Tensor>& down_proj_code_zp,
    const bool used_in_ep_low_latency) {
  const auto dtype = permute_input.dtype();
  // The kernel writes into this pre-allocated output in place.
  auto ffn_out = paddle::empty_like(permute_input, dtype);

  if (dtype == paddle::DataType::BFLOAT16) {
    MoeFFNWint2Kernel<paddle::DataType::BFLOAT16>(permute_input,
                                                  tokens_expert_prefix_sum,
                                                  up_gate_proj_weight,
                                                  down_proj_weight,
                                                  up_gate_proj_bias,
                                                  up_gate_proj_scale,
                                                  down_proj_scale,
                                                  up_gate_proj_local_scale,
                                                  up_gate_proj_code_scale,
                                                  up_gate_proj_code_zp,
                                                  down_proj_local_scale,
                                                  down_proj_code_scale,
                                                  down_proj_code_zp,
                                                  ffn_out,
                                                  used_in_ep_low_latency);
  } else if (dtype == paddle::DataType::FLOAT16) {
    MoeFFNWint2Kernel<paddle::DataType::FLOAT16>(permute_input,
                                                 tokens_expert_prefix_sum,
                                                 up_gate_proj_weight,
                                                 down_proj_weight,
                                                 up_gate_proj_bias,
                                                 up_gate_proj_scale,
                                                 down_proj_scale,
                                                 up_gate_proj_local_scale,
                                                 up_gate_proj_code_scale,
                                                 up_gate_proj_code_zp,
                                                 down_proj_local_scale,
                                                 down_proj_code_scale,
                                                 down_proj_code_zp,
                                                 ffn_out,
                                                 used_in_ep_low_latency);
  } else {
    PD_THROW("Unsupported data type for MoeExpertFFN");
  }
  return ffn_out;
}
// Paddle kernel entry point for moe_expert_ffn_wint2: forwards all inputs to
// the functional implementation and wraps the single output tensor in the
// vector form the custom-op framework expects.
std::vector<paddle::Tensor> MoeExpertFFNWint2(
    const paddle::Tensor& permute_input,
    const paddle::Tensor& tokens_expert_prefix_sum,
    const paddle::Tensor& up_gate_proj_weight,
    const paddle::Tensor& down_proj_weight,
    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const paddle::optional<paddle::Tensor>& up_gate_proj_local_scale,
    const paddle::optional<paddle::Tensor>& up_gate_proj_code_scale,
    const paddle::optional<paddle::Tensor>& up_gate_proj_code_zp,
    const paddle::optional<paddle::Tensor>& down_proj_local_scale,
    const paddle::optional<paddle::Tensor>& down_proj_code_scale,
    const paddle::optional<paddle::Tensor>& down_proj_code_zp,
    const bool used_in_ep_low_latency) {
  paddle::Tensor output = MoeExpertFFNWint2Func(permute_input,
                                                tokens_expert_prefix_sum,
                                                up_gate_proj_weight,
                                                down_proj_weight,
                                                up_gate_proj_bias,
                                                up_gate_proj_scale,
                                                down_proj_scale,
                                                up_gate_proj_local_scale,
                                                up_gate_proj_code_scale,
                                                up_gate_proj_code_zp,
                                                down_proj_local_scale,
                                                down_proj_code_scale,
                                                down_proj_code_zp,
                                                used_in_ep_low_latency);
  return {output};
}
// Static shape inference for moe_expert_ffn_wint2: the single output tensor
// always has the shape of permute_input; the remaining shape arguments are
// part of the op signature but do not affect the result.
std::vector<std::vector<int64_t>> MoeExpertFFNWint2InferShape(
    const std::vector<int64_t>& permute_input_shape,
    const std::vector<int64_t>& tokens_expert_prefix_sum_shape,
    const std::vector<int64_t>& up_gate_proj_weight_shape,
    const std::vector<int64_t>& down_proj_weight_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_local_scale_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_scale_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_code_zp_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_local_scale_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_code_scale_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_code_zp_shape,
    const bool used_in_ep_low_latency) {
  std::vector<std::vector<int64_t>> output_shapes;
  output_shapes.push_back(permute_input_shape);
  return output_shapes;
}
// Static dtype inference for moe_expert_ffn_wint2: the output dtype always
// matches the activation input dtype (bf16/fp16); the remaining dtype
// arguments are part of the op signature but do not affect the result.
std::vector<paddle::DataType> MoeExpertFFNWint2InferDtype(
    const paddle::DataType &permute_input_dtype,
    const paddle::DataType &tokens_expert_prefix_sum_dtype,
    const paddle::DataType &up_gate_proj_weight_dtype,
    const paddle::DataType &down_proj_weight_dtype,
    const paddle::optional<paddle::DataType> &up_gate_proj_bias_dtype,
    const paddle::optional<paddle::DataType> &up_gate_proj_scale_dtype,
    const paddle::optional<paddle::DataType> &down_proj_scale_dtype,
    const paddle::optional<paddle::DataType> &up_gate_proj_local_scale_dtype,
    const paddle::optional<paddle::DataType> &up_gate_proj_code_scale_dtype,
    const paddle::optional<paddle::DataType> &up_gate_proj_code_zp_dtype,
    const paddle::optional<paddle::DataType> &down_proj_local_scale_dtype,
    const paddle::optional<paddle::DataType> &down_proj_code_scale_dtype,
    const paddle::optional<paddle::DataType> &down_proj_code_zp_dtype,
    const bool used_in_ep_low_latency) {
  std::vector<paddle::DataType> output_dtypes{permute_input_dtype};
  return output_dtypes;
}
/**
* @brief Weight-Only Quantized Mixture of Experts (MoE) Feed-Forward Network Operator
*
* This operator performs the expert computation in MoE architecture, including:
* 1. First linear transformation (up_gate_proj) with optional quantization
* 2. SwiGLU activation function
* 3. Second linear transformation (down_proj) with optional quantization
*
* This operator uses 2-bit weight-only (wint2) quantization, with per-group local scales, code scales, and code zero-points supplied as optional inputs.
*
* Inputs:
* - permute_input: Permuted input tensor organized by expert
* Shape: [total_tokens * top_k, hidden_size]
* dtype: bfloat16/float16 (or int8 for w4a8)
* - tokens_expert_prefix_sum: Prefix sum array of token counts per expert for group_gemm
* Shape: [num_experts]
* dtype: int64
* - up_gate_proj_weight: First FFN layer weights
* Shape: [num_experts, inter_size * 2, hidden_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - down_proj_weight: Second FFN layer weights
* Shape: [num_experts, hidden_size, inter_size]
* dtype: Same as input (unquantized) or int8 (quantized)
* - up_gate_proj_bias: Optional bias for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - up_gate_proj_scale: Quantization scales for first FFN layer
* Shape: [num_experts, inter_size * 2]
* dtype: Same as input
* - down_proj_scale: Quantization scales for second FFN layer
* Shape: [num_experts, hidden_size]
* dtype: Same as input
*
* Outputs:
* - output_tensor: Output tensor after MoE FFN computation
* Shape: Same as permute_input
* dtype: Same as input (or up_gate_proj_scale dtype for w4a8)
*
* Attributes:
* - used_in_ep_low_latency: Whether running in low latency mode
* Affects activation function implementation
*
* Note:
* - Low latency mode uses specialized grouped SwiGLU implementation
*/
// Registers the moe_expert_ffn_wint2 custom operator with Paddle.
// Input order here must match the parameter order of MoeExpertFFNWint2,
// and the optional inputs must stay wrapped in paddle::Optional.
PD_BUILD_STATIC_OP(moe_expert_ffn_wint2)
    .Inputs({"permute_input",
             "tokens_expert_prefix_sum",
             "up_gate_proj_weight",
             "down_proj_weight",
             paddle::Optional("up_gate_proj_bias"),
             paddle::Optional("up_gate_proj_scale"),
             paddle::Optional("down_proj_scale"),
             paddle::Optional("up_gate_proj_local_scale"),
             paddle::Optional("up_gate_proj_code_scale"),
             paddle::Optional("up_gate_proj_code_zp"),
             paddle::Optional("down_proj_local_scale"),
             paddle::Optional("down_proj_code_scale"),
             paddle::Optional("down_proj_code_zp")})
    .Outputs({"output_tensor"})
    .Attrs({"used_in_ep_low_latency:bool"})
    .SetKernelFn(PD_KERNEL(MoeExpertFFNWint2))
    .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertFFNWint2InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MoeExpertFFNWint2InferDtype));