FastDeploy/custom_ops/gpu_ops/moe/fused_moe.cu

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Ignore CUTLASS warnings about type punning
#pragma once
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/wint_type_traits.h"
#include "helper.h"
#include "moe/fused_moe_helper.h"
namespace phi {
__global__ void compute_total_rows_before_expert_kernel(
    int* sorted_experts,
    const int64_t sorted_experts_len,
    const int64_t num_experts,
    int64_t* total_rows_before_expert) {
  // First, compute the global tid. We only need 1 thread per expert.
  const int expert = blockIdx.x * blockDim.x + threadIdx.x;
  if (expert >= num_experts) return;

  // This should construct the last index where each expert occurs.
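  // Illustrative example: for sorted_experts = [0, 0, 1, 2, 2, 2] and
  // num_experts = 3, each expert's entry is the count of rows routed to
  // experts with id <= its own, i.e. total_rows_before_expert = [2, 3, 6].
  // These inclusive prefix counts act as per-expert row offsets downstream.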
  total_rows_before_expert[expert] =
      find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
}

void compute_total_rows_before_expert(int* sorted_indices,
                                      const int64_t total_indices,
                                      const int64_t num_experts,
                                      int64_t* total_rows_before_expert,
                                      cudaStream_t stream) {
  const int threads = std::min(int64_t(1024), num_experts);
  const int blocks = (num_experts + threads - 1) / threads;
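  // One thread per expert: the grid is a ceil-div of num_experts over the
  // block size, and the kernel's bounds check discards the overshoot threads.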
  compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
      sorted_indices, total_indices, num_experts, total_rows_before_expert);
}

} // namespace phi
template <paddle::DataType T>
void FusedMoeKernel(const paddle::Tensor& input,
                    const paddle::Tensor& gate_weight,
                    const paddle::Tensor& up_gate_proj_weight,
                    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
                    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
                    const paddle::Tensor& down_proj_weight,
                    const paddle::optional<paddle::Tensor>& down_proj_scale,
                    const paddle::optional<paddle::Tensor>& down_proj_bias,
                    const std::string& quant_method,
                    const int moe_topk,
                    const bool group_moe,
                    const bool norm_topk_prob,
                    paddle::Tensor* output) {
  using namespace phi;
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto* output_data = output->data<data_t>();

  auto fp16_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kNone>>();
  auto int8_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt8>>();
  auto int4_moe_gemm_runner = MoeGemmRunner<DataType_, cutlass::WintQuantTraits<DataType_, cutlass::WintQuantMethod::kWeightOnlyInt4>>();
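  // The helper below receives all three runners and selects among them
  // according to `quant_method` (dense weights vs. weight-only int8 / int4).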
  using NvType = typename traits_::DataType;

  auto moe_compute = MoeHelper<data_t, NvType>(quant_method,
                                               &fp16_moe_gemm_runner,
                                               &int8_moe_gemm_runner,
                                               &int4_moe_gemm_runner);

  moe_compute.ComputeFFN(&input,
                         &gate_weight,
                         &up_gate_proj_weight,
                         up_gate_proj_scale ? up_gate_proj_scale.get_ptr() : nullptr,
                         up_gate_proj_bias ? up_gate_proj_bias.get_ptr() : nullptr,
                         &down_proj_weight,
                         down_proj_scale ? down_proj_scale.get_ptr() : nullptr,
                         down_proj_bias ? down_proj_bias.get_ptr() : nullptr,
                         nullptr,
                         moe_topk,
                         group_moe,
                         norm_topk_prob,
                         1.0,  // ComputeFFN
                         "ffn",
                         output);
}

paddle::Tensor FusedExpertMoeFunc(
    const paddle::Tensor& input,
    const paddle::Tensor& gate_weight,
    const paddle::Tensor& up_gate_proj_weight,
    const paddle::Tensor& down_proj_weight,
    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_bias,
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const std::string& quant_method,
    const int moe_topk,
    const bool norm_topk_prob,
    const bool group_moe) {
  const auto input_type = input.dtype();
  auto output = paddle::empty_like(input);

  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      FusedMoeKernel<paddle::DataType::BFLOAT16>(input,
                                                 gate_weight,
                                                 up_gate_proj_weight,
                                                 up_gate_proj_scale,
                                                 up_gate_proj_bias,
                                                 down_proj_weight,
                                                 down_proj_scale,
                                                 down_proj_bias,
                                                 quant_method,
                                                 moe_topk,
                                                 group_moe,
                                                 norm_topk_prob,
                                                 &output);
      break;
    case paddle::DataType::FLOAT16:
      FusedMoeKernel<paddle::DataType::FLOAT16>(input,
                                                gate_weight,
                                                up_gate_proj_weight,
                                                up_gate_proj_scale,
                                                up_gate_proj_bias,
                                                down_proj_weight,
                                                down_proj_scale,
                                                down_proj_bias,
                                                quant_method,
                                                moe_topk,
                                                group_moe,
                                                norm_topk_prob,
                                                &output);
      break;
    default:
      PD_THROW("Unsupported data type for FusedMoeKernel");
  }
  return output;
}

std::vector<paddle::Tensor> FusedExpertMoe(
    const paddle::Tensor& input,
    const paddle::Tensor& gate_weight,
    const paddle::Tensor& up_gate_proj_weight,
    const paddle::Tensor& down_proj_weight,
    const paddle::optional<paddle::Tensor>& up_gate_proj_bias,
    const paddle::optional<paddle::Tensor>& up_gate_proj_scale,
    const paddle::optional<paddle::Tensor>& down_proj_bias,
    const paddle::optional<paddle::Tensor>& down_proj_scale,
    const std::string& quant_method,
    const int moe_topk,
    const bool norm_topk_prob,
    const bool group_moe) {
  return {FusedExpertMoeFunc(input,
                             gate_weight,
                             up_gate_proj_weight,
                             down_proj_weight,
                             up_gate_proj_bias,
                             up_gate_proj_scale,
                             down_proj_bias,
                             down_proj_scale,
                             quant_method,
                             moe_topk,
                             norm_topk_prob,
                             group_moe)};
}

std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& gate_weight_shape,
    const std::vector<int64_t>& up_gate_proj_weight_shape,
    const std::vector<int64_t>& down_proj_weight_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_bias_shape,
    const paddle::optional<std::vector<int64_t>>& up_gate_proj_scale_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_bias_shape,
    const paddle::optional<std::vector<int64_t>>& down_proj_scale_shape) {
  return {input_shape};
}

std::vector<paddle::DataType> FusedExpertMoeInferDtype(
    const paddle::DataType& input_dtype,
    const paddle::DataType& gate_weight_dtype,
    const paddle::DataType& up_gate_proj_weight_dtype,
    const paddle::DataType& down_proj_weight_dtype,
    const paddle::optional<paddle::DataType>& up_gate_proj_bias_dtype,
    const paddle::optional<paddle::DataType>& up_gate_proj_scale_dtype,
    const paddle::optional<paddle::DataType>& down_proj_bias_dtype,
    const paddle::optional<paddle::DataType>& down_proj_scale_dtype) {
  return {input_dtype};
}

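// Note: both inference helpers simply forward the input's metadata, so the
// fused MoE op is shape- and dtype-preserving with respect to "input".
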
/**
 * @brief Fused Mixture-of-Experts (MoE) Operator
 *
 * This operator combines three key MoE operations into a single optimized kernel:
 * 1. moe_dispatch - Routes tokens to their top-k experts using the gating network
 * 2. moe_ffn - Processes tokens through the selected expert FFNs in parallel
 * 3. moe_reduce - Combines expert outputs with the routing weights
 *
 * Key Features:
 * - Supports both dense and quantized expert weights
 * - Optimized for GPU execution with fused operations
 *
 * Mathematical Formulation:
 * output = ∑_{i=1}^{topk} softmax(gate(x))_i * FFN_i(x)
 *
 * Reference Components:
 * moe_dispatch: Selects the top-k experts per token and generates permutation indices
 * moe_ffn: Applies the expert FFNs (SwiGLU activation) in parallel
 * moe_reduce: Combines the weighted expert outputs and restores the original token order
 *
 * Performance Notes:
 * - hidden_size should be a multiple of 128 for optimal memory alignment
 * - For best throughput, num_experts should be a power of 2
 *
 * (An illustrative usage sketch follows the op registration below.)
 */
PD_BUILD_STATIC_OP(fused_expert_moe)
    .Inputs({"input",
             "gate_weight",
             "up_gate_proj_weight",
             "down_proj_weight",
             paddle::Optional("up_gate_proj_bias"),
             paddle::Optional("up_gate_proj_scale"),
             paddle::Optional("down_proj_bias"),
             paddle::Optional("down_proj_scale")})
    .Outputs({"output"})
    .Attrs({"quant_method:std::string",
            "moe_topk:int",
            "norm_topk_prob:bool",
            "group_moe:bool"})
    .SetKernelFn(PD_KERNEL(FusedExpertMoe))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedExpertMoeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedExpertMoeInferDtype));
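
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only). The shapes and the "none" quant_method
// value below are assumptions for the dense path; the real callers are
// FastDeploy's Python MoE layers, which invoke the registered custom op.
//
//   // x:        [num_tokens, hidden_size]
//   // gate_w:   [hidden_size, num_experts]
//   // up_gate_w / down_w: per-expert fused gate/up and down projections
//   paddle::Tensor out = FusedExpertMoeFunc(
//       x, gate_w, up_gate_w, down_w,
//       /*up_gate_proj_bias=*/paddle::optional<paddle::Tensor>(),
//       /*up_gate_proj_scale=*/paddle::optional<paddle::Tensor>(),
//       /*down_proj_bias=*/paddle::optional<paddle::Tensor>(),
//       /*down_proj_scale=*/paddle::optional<paddle::Tensor>(),
//       /*quant_method=*/"none",
//       /*moe_topk=*/2,
//       /*norm_topk_prob=*/true,
//       /*group_moe=*/false);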