Files
FastDeploy/custom_ops/gpu_ops/moe/moe_topk_select.cu
2025-06-29 23:29:37 +00:00

236 lines
9.5 KiB
Plaintext

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
#pragma once
#include "moe/fused_moe_helper.h"
#include "moe/fused_moe_op.h"
#pragma GCC diagnostic pop
#include "helper.h"
using namespace phi;
template <typename T>
void moe_topk_select_kernel(const T* input,
const T* bias,
T* output,
T* softmax,
int64_t* indices,
int* source_row,
T* softmax_max_prob,
const int64_t num_rows,
const int64_t num_experts,
const int64_t k,
cudaStream_t stream,
const bool apply_norm_weight = false,
const bool enable_softmax_top_k_fused = false
) {
static constexpr int WARPS_PER_TB = 4;
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;
if(bias != nullptr || apply_norm_weight) tem_num_experts = 0;
switch (tem_num_experts) {
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)
default: {
static constexpr int TPB = 256;
const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
if (!enable_softmax_top_k_fused) {
moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
input, softmax, num_experts, num_rows);
if (apply_norm_weight) {
moe_top_k_normed<T, TPB>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(softmax,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
} else {
moe_top_k<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
cudaGetLastError();
}
else {
assert(k<=TPB);
if (apply_norm_weight) {
moe_softmax_top_k_normed_fused<T, TPB>
<<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
} else {
moe_softmax_top_k_fused<T, TPB>
<<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
bias,
output,
indices,
source_row,
num_experts,
k,
num_rows);
}
}
}
}
}
std::vector<paddle::Tensor> MoETopKSelectKernel(
const paddle::Tensor& gating_logits,
const paddle::optional<paddle::Tensor>& bias,
const int moe_topk,
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused) {
auto stream = gating_logits.stream();
auto place = gating_logits.place();
int token_rows = 0;
auto gating_dims = gating_logits.dims();
const int expert_num = gating_dims[gating_dims.size() - 1];
if (gating_dims.size() == 3) {
token_rows = gating_dims[0] * gating_dims[1];
} else {
token_rows = gating_dims[0];
}
const int num_rows = token_rows;
// correspond to the weighted coefficients of the results from each expert.
auto topk_ids =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::INT64, place);
auto topk_weights =
GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place);
const int num_moe_inputs = AlignTo16(num_rows * moe_topk);
const int bytes = num_moe_inputs * sizeof(int);
CubKeyValueSorter sorter_;
sorter_.update_num_experts(expert_num);
const int sorter_ws_size_bytes =
AlignTo16(sorter_.getWorkspaceSize(moe_topk * num_rows));
const int sort_tmp_in_out_size = num_moe_inputs * 2 * sizeof(int);
paddle::Tensor ws_ptr_tensor =
GetEmptyTensor({bytes + sorter_ws_size_bytes + sort_tmp_in_out_size},
paddle::DataType::INT8,
place);
int8_t* ws_ptr = ws_ptr_tensor.data<int8_t>();
int* source_rows_ = reinterpret_cast<int*>(ws_ptr);
int64_t* topk_ids_data = topk_ids.data<int64_t>();
float* softmax_max_prob = nullptr;
float* softmax_out_;
const bool is_pow_2 =
(expert_num != 0) && ((expert_num & (expert_num - 1)) == 0);
paddle::Tensor softmax_buffer;
if (!is_pow_2 || expert_num > 256 || bias || apply_norm_weight) {
softmax_buffer = GetEmptyTensor(
{num_rows * expert_num}, paddle::DataType::FLOAT32, place);
softmax_out_ = softmax_buffer.data<float>();
} else {
softmax_out_ = nullptr;
}
moe_topk_select_kernel<float>(gating_logits.data<float>(),
bias ? bias.get().data<float>() : nullptr,
topk_weights.data<float>(),
softmax_out_,
topk_ids_data,
source_rows_,
softmax_max_prob,
num_rows,
expert_num,
moe_topk,
stream,
apply_norm_weight,
enable_softmax_top_k_fused);
return {topk_ids, topk_weights};
}
std::vector<std::vector<int64_t>> MoETopKSelectKernelInferShape(
const std::vector<int64_t>& gating_logits_shape,
const paddle::optional<std::vector<int64_t>>& bias_shape,
const int moe_topk,
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused) {
int token_rows = -1;
if (gating_logits_shape.size() == 3) {
token_rows = gating_logits_shape[0] * gating_logits_shape[1];
} else {
token_rows = gating_logits_shape[0];
}
const int num_rows = token_rows;
return {{num_rows, moe_topk},
{num_rows, moe_topk}};
}
std::vector<paddle::DataType> MoETopKSelectKernelInferDtype(
const paddle::DataType& gating_logits_dtype,
const paddle::optional<paddle::DataType>& bias_type,
const int moe_topk,
const bool apply_norm_weight,
const bool enable_softmax_top_k_fused) {
return {paddle::DataType::INT64,
paddle::DataType::FLOAT32};
}
PD_BUILD_STATIC_OP(moe_topk_select)
.Inputs({"gating_logits", paddle::Optional("bias")})
.Outputs({"topk_ids",
"topk_weights"})
.Attrs({"moe_topk:int", "apply_norm_weight:bool", "enable_softmax_top_k_fused:bool"})
.SetKernelFn(PD_KERNEL(MoETopKSelectKernel))
.SetInferShapeFn(PD_INFER_SHAPE(MoETopKSelectKernelInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MoETopKSelectKernelInferDtype));