// FastDeploy/custom_ops/gpu_ops/moe/fused_moe_helper.h
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass_extensions/wint_type_traits.h"
#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "moe/fused_moe_op.h"
namespace phi {
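
// Masks the gating logits with per-token type ids: a type-0 token keeps its
// first logit and a type-1 token keeps its second one (the other logit is
// pushed to -1e10 so it cannot be selected by the top-k). Assumes two gating
// logits per token.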
template <typename T, int VecSize>
__global__ void moe_token_type_ids_kernel(T *gating_output,
                                          const int *moe_token_type_ids_out,
                                          const int num_rows,
                                          const int num_experts, const int k) {
  const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (moe_token_index >= num_rows) {
    return;
  }
  gating_output[moe_token_index * 2] =
      gating_output[moe_token_index * 2] +
      (moe_token_type_ids_out[moe_token_index]) * -1e10;
  gating_output[moe_token_index * 2 + 1] =
      gating_output[moe_token_index * 2 + 1] +
      (1 - moe_token_type_ids_out[moe_token_index]) * -1e10;
}

template <typename T>
void moe_token_type_ids_kernelLauncher(T *gating_output,
                                       const int *moe_token_type_ids_out,
                                       const int num_rows,
                                       const int num_experts, const int k,
                                       cudaStream_t stream) {
  const int blocks = num_rows * k / 512 + 1;
  const int threads = 512;
  moe_token_type_ids_kernel<T, 1><<<blocks, threads, 0, stream>>>(
      gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
}
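
// Wraps the fused-MoE FFN forward. The caller supplies one grouped-GEMM
// runner per weight format; ComputeFFN dispatches on gemm_method_
// ("weight_only_int8", "weight_only_int4", or anything else for the
// fp16/bf16 runner).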
template <typename T, typename NvType> class MoeHelper {
 public:
  using Fp16Traits =
      cutlass::WintQuantTraits<NvType, cutlass::WintQuantMethod::kNone>;
  using Int8Traits =
      cutlass::WintQuantTraits<NvType,
                               cutlass::WintQuantMethod::kWeightOnlyInt8>;
  using Int4Traits =
      cutlass::WintQuantTraits<NvType,
                               cutlass::WintQuantMethod::kWeightOnlyInt4>;

  MoeHelper(const std::string gemm_method,
            MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner,
            MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner,
            MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner,
            int layernum = 0)
      : gemm_method_(gemm_method),
        fp16_moe_gemm_runner_(fp16_moe_gemm_runner),
        int8_moe_gemm_runner_(int8_moe_gemm_runner),
        int4_moe_gemm_runner_(int4_moe_gemm_runner),
        layernum_(layernum) {}

  // -------- getWorkspaceSize -------- //
  template <typename KeyT>
  size_t getWorkspaceSize(const int64_t num_rows, const int64_t hidden_size,
                          const int64_t inter_size, const int64_t num_experts,
                          const int64_t k) {
    const size_t buf_size = AlignTo16(k * num_rows * hidden_size);
    const size_t interbuf_size = AlignTo16(k * num_rows * inter_size);
    const size_t padded_experts = AlignTo16(num_experts);
    const size_t num_moe_inputs = AlignTo16(k * num_rows);
    // softmax output, permuted_rows and permuted_experts have moved to outside
    // of moe kernel, allocate them in Encoder or Decoder before invoking
    // FfnLayer forward.
    size_t total_ws_bytes =
        5 * num_moe_inputs *
        sizeof(int);  // expert_for_source_row, source_rows_, permuted_rows_,
                      // permuted_experts_,
                      // expanded_source_row_to_expanded_dest_row
    total_ws_bytes += buf_size * sizeof(KeyT);  // permuted_data
    total_ws_bytes +=
        padded_experts * sizeof(int64_t);  // Hold total_rows_before_expert_
    const size_t bytes_for_fc1_result = interbuf_size * sizeof(KeyT);
    const size_t sorter_ws_size_bytes =
        AlignTo16(sorter_.getWorkspaceSize(num_rows));
    sorter_.update_num_experts(num_experts);
    int64_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result;
    if (sorter_ws_size_bytes > bytes_for_fc1_result) {
      int64_t remaining_bytes =
          AlignTo16(sorter_ws_size_bytes - bytes_for_fc1_result);
      bytes_for_intermediate_and_sorting += remaining_bytes;
    }
    total_ws_bytes +=
        bytes_for_intermediate_and_sorting;  // intermediate (fc1) output + cub
                                             // sorting workspace
    int64_t num_softmax_outs = 0;
    const bool is_pow_2 =
        (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    if (!is_pow_2 || num_experts > 256) {
      num_softmax_outs = AlignTo16(num_rows * num_experts);
    }
    total_ws_bytes += num_softmax_outs * sizeof(float);
    return total_ws_bytes;
  }
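
  // -------- ComputeFFN -------- //
  // Full MoE FFN forward: gate matmul, top-k softmax routing, token
  // permutation, grouped up_gate_proj GEMM, and (for moe_type == "ffn")
  // SwiGLU + grouped down_proj GEMM, followed by un-permutation into
  // `output` with the routing scales applied.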
  void ComputeFFN(const paddle::Tensor *input,
                  const paddle::Tensor *gate_weight,
                  const paddle::Tensor *up_gate_proj_weight,
                  const paddle::Tensor *up_gate_proj_scale,
                  const paddle::Tensor *up_gate_proj_bias,
                  const paddle::Tensor *down_proj_weight,
                  const paddle::Tensor *down_proj_scale,
                  const paddle::Tensor *down_proj_bias,
                  const paddle::Tensor *moe_token_type_ids, const int moe_topk,
                  const bool group_moe, const bool norm_topk_prob,
                  const float routed_scaling_factor, const std::string moe_type,
                  paddle::Tensor *output) {
    auto *input_activations = input->data<T>();
    auto *gating_weights = gate_weight->data<float>();
    const T *fc1_expert_biases =
        up_gate_proj_bias ? up_gate_proj_bias->data<T>() : nullptr;
    const T *fc2_expert_biases =
        down_proj_bias ? down_proj_bias->data<T>() : nullptr;
    auto *output_ = output->data<T>();
    auto stream = input->stream();
    auto place = input->place();
    auto input_type = input->dtype();
    auto input_dims = input->dims();
    auto up_gate_proj_dims = up_gate_proj_weight->dims();
    int64_t token_num = 0;
    if (input_dims.size() == 3) {
      token_num = input_dims[0] * input_dims[1];
    } else {
      token_num = input_dims[0];
    }
    const int64_t num_rows = token_num;
    const int64_t hidden_size = up_gate_proj_dims[1];
    int64_t inter_dim = 0;
    if (moe_type == "qkv") {
      inter_dim =
          up_gate_proj_dims[2] * up_gate_proj_dims[3] * up_gate_proj_dims[4];
    } else {
      inter_dim = up_gate_proj_dims[2];
    }
    if (gemm_method_ == "weight_only_int4") {
      // int4 weights pack two values per stored byte, so the logical
      // intermediate dim is twice the stored dim.
      inter_dim = inter_dim * 2;
    }
    const int64_t inter_size = inter_dim;
    const int64_t num_experts = up_gate_proj_dims[0];
    const int64_t k = moe_topk;
    int64_t bytes =
        getWorkspaceSize<T>(num_rows, hidden_size, inter_size, num_experts, k);
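
    // Carve the raw INT8 workspace into the routing index buffers, the
    // permuted activations, the per-expert row offsets, the FC1 result /
    // sort scratch, and (for large or non-power-of-two expert counts) the
    // softmax output.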
    // Pointers
    int *expert_for_source_row;
    int *source_rows_;
    int *permuted_rows_;
    int *permuted_experts_;
    int *expanded_source_row_to_expanded_dest_row;
    T *permuted_data_;
    int64_t *total_rows_before_expert_;
    T *fc1_result_;
    float *softmax_out_;

    paddle::Tensor ws_ptr_tensor =
        GetEmptyTensor({bytes}, paddle::DataType::INT8, place);
    int8_t *ws_ptr = ws_ptr_tensor.data<int8_t>();

    const int64_t buf_size = AlignTo16(k * num_rows * hidden_size);
    const int64_t interbuf_size = AlignTo16(k * num_rows * inter_size);
    const int64_t padded_experts = AlignTo16(num_experts);
    const int64_t num_moe_inputs = AlignTo16(k * num_rows);

    expert_for_source_row = reinterpret_cast<int *>(ws_ptr);
    source_rows_ = expert_for_source_row + num_moe_inputs;
    permuted_rows_ = source_rows_ + num_moe_inputs;
    permuted_experts_ = permuted_rows_ + num_moe_inputs;
    expanded_source_row_to_expanded_dest_row =
        permuted_experts_ + num_moe_inputs;
    permuted_data_ = reinterpret_cast<T *>(
        expanded_source_row_to_expanded_dest_row + num_moe_inputs);
    total_rows_before_expert_ =
        reinterpret_cast<int64_t *>(permuted_data_ + buf_size);
    fc1_result_ =
        reinterpret_cast<T *>(total_rows_before_expert_ + padded_experts);
    const bool is_pow_2 =
        (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    if (!is_pow_2 || num_experts > 256) {
      softmax_out_ = reinterpret_cast<float *>(fc1_result_ + interbuf_size);
    } else {
      softmax_out_ = nullptr;
    }
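
    // Per-token top-k routing scales, the optional group-MoE softmax maxima,
    // and the FC1 output are allocated as separate tensors rather than taken
    // from the raw workspace.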
    paddle::Tensor expert_scales_float_tensor =
        GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
    float *expert_scales_float = expert_scales_float_tensor.data<float>();

    float *softmax_max_prob = nullptr;
    if (group_moe) {
      paddle::Tensor softmax_max_prob_tensor = GetEmptyTensor(
          {num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
      // TODO: check that the fill succeeds.
      paddle::experimental::fill(softmax_max_prob_tensor, 0.f);
      softmax_max_prob = softmax_max_prob_tensor.data<float>();
    }

    paddle::Tensor fc1_out_tensor =
        GetEmptyTensor({num_rows * k, inter_size}, input_type, place);
    T *fc1_out = fc1_out_tensor.data<T>();
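
    // Gating: compute per-token expert logits in fp32, then (optionally) mask
    // them with the token type ids.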
    auto input_cast_tensor =
        paddle::experimental::cast(*input, paddle::DataType::FLOAT32);
    auto gate_tensor =
        paddle::experimental::matmul(input_cast_tensor, *gate_weight);
    float *gating_output = gate_tensor.data<float>();
    if (moe_token_type_ids) {
      auto *moe_token_type_ids_out = moe_token_type_ids->data<int>();
      moe_token_type_ids_kernelLauncher<float>(gating_output,
                                               moe_token_type_ids_out,
                                               num_rows, num_experts, k,
                                               stream);
    }
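
    // Top-k softmax over the gating logits: writes the selected expert id and
    // routing scale for every (token, k) slot, then sorts the slots by expert
    // so that rows belonging to the same expert become contiguous.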
    topk_gating_softmax_kernelLauncher<float, int>(
        gating_output, nullptr, expert_scales_float, softmax_out_,
        expert_for_source_row, source_rows_, softmax_max_prob, num_rows,
        num_experts, k, group_moe, stream);

    const int64_t sorter_ws_size_bytes =
        AlignTo16(sorter_.getWorkspaceSize(int64_t(k * num_rows)));
    sorter_.run(fc1_result_, sorter_ws_size_bytes, expert_for_source_row,
                permuted_experts_, source_rows_, permuted_rows_, k * num_rows,
                false, stream);
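
    // Gather the input rows into expert-sorted order and record, for every
    // expert, how many rows precede it (the group boundaries for the grouped
    // GEMMs).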
    initialize_moe_routing_kernelLauncher(
        input_activations, permuted_data_, permuted_rows_, nullptr, nullptr,
        expanded_source_row_to_expanded_dest_row, num_rows, num_rows,
        hidden_size, k, stream);

    const int64_t expanded_active_expert_rows = k * num_rows;
    compute_total_rows_before_expert(permuted_experts_,
                                     expanded_active_expert_rows, num_experts,
                                     total_rows_before_expert_, stream);
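
    // FC1 (up_gate_proj) grouped GEMM with fused bias, dispatched to the
    // runner that matches the weight format.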
if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments up_gate_proj_quant_args;
int8_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const uint8_t *>(up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
up_gate_proj_quant_args, "none", stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments up_gate_proj_quant_args;
int4_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const cutlass::uint4b_t *>(
up_gate_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(up_gate_proj_scale->data<T>()),
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
up_gate_proj_quant_args, "none", stream);
} else {
typename Fp16Traits::Arguments up_gate_proj_quant_args;
fp16_moe_gemm_runner_->moe_gemm_bias_act(
reinterpret_cast<NvType *>(permuted_data_),
reinterpret_cast<const NvType *>(up_gate_proj_weight->data<T>()), nullptr,
reinterpret_cast<const NvType *>(fc1_expert_biases),
reinterpret_cast<NvType *>(fc1_out), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, inter_size, hidden_size, num_experts,
up_gate_proj_quant_args, "none", stream);
}
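
    // "ffn" path: SwiGLU halves the intermediate width, the down_proj grouped
    // GEMM projects back to hidden_size, and the results are un-permuted and
    // reduced over the top-k experts into `output`.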
if (moe_type == "ffn") {
auto act_out_tensor =
paddle::experimental::swiglu(fc1_out_tensor, nullptr);
auto act_out = act_out_tensor.data<T>();
paddle::Tensor fc2_output_tensor =
GetEmptyTensor({k * num_rows, hidden_size}, input_type, place);
T *fc2_result = fc2_output_tensor.data<T>();
if (gemm_method_ == "weight_only_int8") {
typename Int8Traits::Arguments down_proj_quant_args;
int8_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const uint8_t *>(down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, down_proj_quant_args, stream);
} else if (gemm_method_ == "weight_only_int4") {
typename Int4Traits::Arguments down_proj_quant_args;
int4_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const cutlass::uint4b_t *>(
down_proj_weight->data<int8_t>()),
reinterpret_cast<const NvType *>(down_proj_scale->data<T>()),
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, down_proj_quant_args, stream);
} else {
typename Fp16Traits::Arguments down_proj_quant_args;
fp16_moe_gemm_runner_->moe_gemm(
reinterpret_cast<NvType *>(act_out),
reinterpret_cast<const NvType *>(down_proj_weight->data<T>()), nullptr,
reinterpret_cast<NvType *>(fc2_result), total_rows_before_expert_,
-1, // useless
expanded_active_expert_rows, hidden_size, inter_size / 2,
num_experts, down_proj_quant_args, stream);
}
finalize_moe_routing_kernelLauncher(
fc2_result, output_, fc2_expert_biases,
reinterpret_cast<float *>(expert_scales_float),
expanded_source_row_to_expanded_dest_row, expert_for_source_row,
num_rows, hidden_size, k, static_cast<int>(1), norm_topk_prob,
routed_scaling_factor, stream);
} else {
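      // Non-"ffn" path (e.g. moe_type == "qkv"): the up_gate_proj output is
      // returned directly, without activation or down projection; only the
      // un-permutation and scaling are applied.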
      finalize_moe_routing_kernelLauncher(
          fc1_out,            // fc2_result
          output_,
          fc1_expert_biases,  // fc2_expert_biases
          reinterpret_cast<float *>(expert_scales_float),
          expanded_source_row_to_expanded_dest_row, expert_for_source_row,
          num_rows, inter_size, k, static_cast<int>(0), norm_topk_prob,
          routed_scaling_factor, stream);
    }
  }

 private:
  std::string gemm_method_;
  MoeGemmRunner<NvType, Fp16Traits> *fp16_moe_gemm_runner_;
  MoeGemmRunner<NvType, Int8Traits> *int8_moe_gemm_runner_;
  MoeGemmRunner<NvType, Int4Traits> *int4_moe_gemm_runner_;
  int layernum_;
  CubKeyValueSorter sorter_;
};
} // namespace phi