Files
FastDeploy/custom_ops/gpu_ops/moe/moe_deepgemm_permute.cu
2025-07-19 23:19:27 +08:00

116 lines
4.4 KiB
Plaintext

// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>

#include "helper.h"
// Thread-block layout: one warp per top-k slot (blockDim.x == 32 * topk);
// within a block, warp w routes the current token to expert topk_idx[token, w].
// Scatters each token's hidden vector into per-expert contiguous buffers.
//
// Launch layout: blockDim.x == 32 * topk (one warp per top-k routing slot),
// grid-stride over tokens along blockIdx.x. For token t and slot w, warp w:
//   1. lane 0 atomically claims the next free row of the destination expert's
//      buffer and records it in permute_indices_per_token[t * topk + w],
//   2. the claimed row index is broadcast to all 32 lanes via __shfl_sync,
//   3. the lanes cooperatively copy the token's hidden vector, VecSize
//      elements (16 bytes) at a time, into out[expert][row][:].
//
// Preconditions (enforced by the host wrapper):
//   - token_nums_per_expert is zero-initialized before launch
//   - hidden == num_vecs * VecSize and x rows are 16-byte aligned
//   - topk <= 32 (so blockDim.x <= 1024)
//
// After the kernel, token_nums_per_expert[e] holds the number of tokens
// routed to expert e; it may exceed max_tokens_per_expert when capacity is
// insufficient (see the overflow guard below), letting callers detect drops.
template<typename T, int VecSize>
__global__ void MoEDeepGEMMPermuteKernel(T* out, int* token_nums_per_expert, int* permute_indices_per_token, const T* x, const int64_t* topk_idx, const int token_num, const int topk, const int num_vecs, const int hidden, const int max_tokens_per_expert) {
    AlignedVector<T, VecSize> in_vec;
    const int bid = blockIdx.x;
    const int wid = threadIdx.x / 32;  // warp id == top-k slot handled by this warp
    const int tid = threadIdx.x % 32;  // lane id within the warp
    for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
        const int tgt_expert_id = topk_idx[token_idx * topk + wid];
        int tgt_expert_token;
        if (tid == 0) {
            // Claim the next free row in this expert's output buffer.
            tgt_expert_token = atomicAdd(token_nums_per_expert + tgt_expert_id, 1);
            permute_indices_per_token[token_idx * topk + wid] = tgt_expert_token;
        }
        // Broadcast lane 0's claimed row index to the whole warp.
        tgt_expert_token = __shfl_sync(0xFFFFFFFF, tgt_expert_token, 0);
        // Overflow guard: if this expert received more tokens than its buffer
        // holds, skip the copy instead of writing past the buffer into the
        // next expert's rows. The counter keeps the true demand so the host
        // side can detect the overflow.
        if (tgt_expert_token >= max_tokens_per_expert) {
            continue;
        }
        // 64-bit offsets: expert_id * max_tokens_per_expert * hidden can
        // overflow 32-bit int for large capacities / hidden sizes.
        const int64_t src_base = static_cast<int64_t>(token_idx) * hidden;
        const int64_t dst_base =
            (static_cast<int64_t>(tgt_expert_id) * max_tokens_per_expert + tgt_expert_token) *
            hidden;
        for (int hidden_vec_id = tid; hidden_vec_id < num_vecs; hidden_vec_id += 32) {
            Load<T, VecSize>(x + src_base + hidden_vec_id * VecSize, &in_vec);
            Store<T, VecSize>(in_vec, out + dst_base + hidden_vec_id * VecSize);
        }
    }
}
// Allocates the outputs and launches MoEDeepGEMMPermuteKernel for dtype D.
//
// Inputs:
//   x         [token_num, hidden] activations
//   topk_idx  [token_num, topk]   int64 expert id per (token, slot)
// Returns:
//   permute_output            [num_experts, max_tokens_per_expert, hidden]
//   token_nums_per_expert     [num_experts] int32 — tokens routed per expert
//   permute_indices_per_token [token_num, topk] int32 — destination row of
//                             each (token, slot) inside its expert's buffer
template <paddle::DataType D>
std::vector<paddle::Tensor> MoEDeepGEMMPermuteDispatch(
    const paddle::Tensor& x,
    const paddle::Tensor& topk_idx,
    const int num_experts,
    const int max_tokens_per_expert
) {
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;
    const int token_num = x.shape()[0];
    const int hidden = x.shape()[1];
    const int topk = topk_idx.shape()[1];
    auto place = x.place();
    auto stream = x.stream();
    auto token_nums_per_expert = GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
    auto permute_indices_per_token = GetEmptyTensor({token_num, topk}, paddle::DataType::INT32, place);
    // The kernel claims rows with atomicAdd, so counters must start at zero.
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(token_nums_per_expert.data<int32_t>(), 0, num_experts * sizeof(int32_t), stream));
    auto permute_output = GetEmptyTensor({num_experts, max_tokens_per_expert, hidden}, x.dtype(), place);
    auto permute_output_data = permute_output.data<data_t>();
    // Each lane moves one 16-byte aligned vector per iteration.
    constexpr int VecSize = 16 / sizeof(data_t);
    // One warp per top-k slot. Unlike the original assert() (a no-op under
    // NDEBUG), PD_CHECK also fires in release builds: CUDA caps blockDim.x at
    // 1024, so topk must be <= 32.
    const int blocks = 32 * topk;
    PD_CHECK(blocks <= 1024,
             "moe_deepgemm_permute requires topk <= 32, got topk = ", topk);
    // The kernel copies whole 16-byte vectors; a ragged tail would be
    // silently dropped by the truncating division below.
    PD_CHECK(hidden % VecSize == 0,
             "moe_deepgemm_permute requires hidden (", hidden,
             ") to be divisible by ", VecSize);
    const int num_vecs = hidden / VecSize;
    if (token_num > 0) {
        // grids = 0 would be an invalid launch configuration, hence the guard.
        const int grids = std::min(132 * 4, token_num);
        MoEDeepGEMMPermuteKernel<DataType_, VecSize><<<grids, blocks, 0, stream>>>(
            reinterpret_cast<DataType_*>(permute_output_data),
            token_nums_per_expert.data<int32_t>(),
            permute_indices_per_token.data<int32_t>(),
            reinterpret_cast<const DataType_ *>(x.data<data_t>()),
            topk_idx.data<int64_t>(),
            token_num, topk, num_vecs,
            hidden, max_tokens_per_expert
        );
    }
    return {permute_output, token_nums_per_expert, permute_indices_per_token};
}
// Entry point for the moe_deepgemm_permute custom op: selects the typed
// implementation from the activation dtype (bf16 / fp16 supported).
std::vector<paddle::Tensor> MoEDeepGEMMPermute(
    const paddle::Tensor& x,
    const paddle::Tensor& topk_idx,
    const int num_experts,
    const int max_tokens_per_expert
) {
    const auto dtype = x.dtype();
    if (dtype == paddle::DataType::BFLOAT16) {
        return MoEDeepGEMMPermuteDispatch<paddle::DataType::BFLOAT16>(
            x, topk_idx, num_experts, max_tokens_per_expert);
    }
    if (dtype == paddle::DataType::FLOAT16) {
        return MoEDeepGEMMPermuteDispatch<paddle::DataType::FLOAT16>(
            x, topk_idx, num_experts, max_tokens_per_expert);
    }
    PD_THROW("Unsupported data type");
}
// Registers the custom op `moe_deepgemm_permute` with Paddle.
// Outputs mirror MoEDeepGEMMPermute's return order: the permuted activation
// buffer, the per-expert token counts, and the per-(token, slot) row indices.
PD_BUILD_STATIC_OP(moe_deepgemm_permute)
.Inputs({"x", "topk_idx"})
.Outputs({"permute_output", "token_nums_per_expert", "permute_indices_per_token"})
.Attrs({"num_experts: int", "max_tokens_per_expert: int"})
.SetKernelFn(PD_KERNEL(MoEDeepGEMMPermute));