// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// Copies every (token, top-k slot) row of `x` into the region of `out` that
// belongs to the expert selected for that slot.
//
// Launch layout: 1-D grid, blockDim.x == 32 * topk ("topk warps"). Warp `wid`
// of a block serves top-k slot `wid` of the block's current token; a block
// starts at token blockIdx.x and advances by gridDim.x tokens per iteration.
//
// Parameters:
//   out                       destination, indexed as
//                             [expert][slot][hidden] with `max_tokens_per_expert`
//                             slots per expert.
//   token_nums_per_expert     per-expert counters; must be zero on entry. Each
//                             routed row atomically increments its expert's
//                             counter, and the previous value is the row's slot.
//   permute_indices_per_token [token_num * topk] slot assigned to each
//                             (token, top-k slot) pair.
//   x                         source rows, `hidden` elements per token.
//   topk_idx                  [token_num * topk] expert id per pair.
//   num_vecs                  number of VecSize-wide vectors copied per row.
//
// Preconditions (NOT checked on the device):
//   * every expert id in `topk_idx` indexes a valid entry of
//     `token_nums_per_expert`;
//   * no expert receives more than `max_tokens_per_expert` rows, otherwise the
//     copy writes past that expert's region of `out`;
//   * NOTE(review): row starts are presumably required to be aligned for
//     VecSize-wide AlignedVector accesses -- confirm against helper.h.
//
// Slot assignment depends on the order in which atomics land, so the layout
// inside one expert's region is not deterministic across runs.
template <typename T, int VecSize>
__global__ void MoEDeepGEMMPermuteKernel(T* out,
                                         int* token_nums_per_expert,
                                         int* permute_indices_per_token,
                                         const T* x,
                                         const int64_t* topk_idx,
                                         const int token_num,
                                         const int topk,
                                         const int num_vecs,
                                         const int hidden,
                                         const int max_tokens_per_expert) {
    AlignedVector<T, VecSize> in_vec;

    const int bid = blockIdx.x;
    const int wid = threadIdx.x / 32;  // top-k slot handled by this warp
    const int tid = threadIdx.x % 32;  // lane inside the warp

    for (int token_idx = bid; token_idx < token_num; token_idx += gridDim.x) {
        // 64-bit flat index: token_num * topk may not fit in 32 bits.
        const int64_t route_idx = static_cast<int64_t>(token_idx) * topk + wid;
        const int tgt_expert_id = static_cast<int>(topk_idx[route_idx]);

        // Lane 0 reserves the next free slot of the target expert and records
        // it; the value is then broadcast to the rest of the warp. Initialised
        // so the other lanes never feed an indeterminate value to the shuffle.
        int tgt_expert_token = 0;
        if (tid == 0) {
            tgt_expert_token =
                atomicAdd(token_nums_per_expert + tgt_expert_id, 1);
            permute_indices_per_token[route_idx] = tgt_expert_token;
        }
        tgt_expert_token = __shfl_sync(0xFFFFFFFF, tgt_expert_token, 0);

        // Row base pointers are computed in 64-bit arithmetic; the products
        // expert * max_tokens_per_expert * hidden and token * hidden overflow
        // `int` for realistic problem sizes.
        const T* src_row = x + static_cast<int64_t>(token_idx) * hidden;
        T* dst_row =
            out + (static_cast<int64_t>(tgt_expert_id) * max_tokens_per_expert +
                   tgt_expert_token) *
                      hidden;

        // The 32 lanes copy the row cooperatively, one vector per lane per
        // iteration.
        for (int hidden_vec_id = tid; hidden_vec_id < num_vecs;
             hidden_vec_id += 32) {
            const int elem_offset = hidden_vec_id * VecSize;
            Load<T, VecSize>(src_row + elem_offset, &in_vec);
            Store<T, VecSize>(in_vec, dst_row + elem_offset);
        }
    }
}

// Allocates the three outputs and launches MoEDeepGEMMPermuteKernel for the
// element type selected by the template argument.
//
// Reads token_num and hidden from x.shape()[0] / x.shape()[1] and topk from
// topk_idx.shape()[1].
//
// Returns {permute_output, token_nums_per_expert, permute_indices_per_token}:
//   permute_output            [num_experts, max_tokens_per_expert, hidden],
//                             same dtype as x, allocated with GetEmptyTensor
//                             (this function does not clear it).
//   token_nums_per_expert     [num_experts] INT32, zeroed before the launch.
//   permute_indices_per_token [token_num, topk] INT32.
//
// Throws (PD_THROW) when topk does not map to a valid block size or when
// hidden is not a multiple of the vector width.
template <paddle::DataType T>
std::vector<paddle::Tensor> MoEDeepGEMMPermuteDispatch(
    const paddle::Tensor& x,
    const paddle::Tensor& topk_idx,
    const int num_experts,
    const int max_tokens_per_expert) {
    typedef PDTraits<T> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    const int token_num = x.shape()[0];
    const int hidden = x.shape()[1];
    const int topk = topk_idx.shape()[1];

    auto place = x.place();
    auto stream = x.stream();

    auto token_nums_per_expert =
        GetEmptyTensor({num_experts}, paddle::DataType::INT32, place);
    auto permute_indices_per_token =
        GetEmptyTensor({token_num, topk}, paddle::DataType::INT32, place);

    // The kernel uses these counters as atomic slot allocators, so they must
    // start at zero.
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaMemsetAsync(token_nums_per_expert.data<int32_t>(),
                        0,
                        num_experts * sizeof(int32_t),
                        stream));

    auto permute_output = GetEmptyTensor(
        {num_experts, max_tokens_per_expert, hidden}, x.dtype(), place);

    // 16-byte vectors: 8 elements for 2-byte types.
    constexpr int VecSize = 16 / sizeof(data_t);

    // One warp per top-k slot.
    const int blocks = 32 * topk;
    if (blocks <= 0 || blocks > 1024) {
        PD_THROW("moe_deepgemm_permute: topk must be in [1, 32]");
    }
    // The kernel copies hidden / VecSize whole vectors per row; a remainder
    // would be dropped silently.
    if (hidden % VecSize != 0) {
        PD_THROW(
            "moe_deepgemm_permute: hidden must be a multiple of the 16-byte "
            "vector width");
    }

    // A zero-sized grid is an invalid launch configuration; with no tokens the
    // zeroed counters and empty index tensor are already the correct result.
    if (token_num == 0) {
        return {permute_output, token_nums_per_expert, permute_indices_per_token};
    }

    // NOTE(review): 132 * 4 presumably targets 4 blocks per SM on a 132-SM
    // device -- confirm; the kernel's token loop is correct for any grid size.
    constexpr int kMaxGrid = 132 * 4;
    const int grids = token_num < kMaxGrid ? token_num : kMaxGrid;
    const int num_vecs = hidden / VecSize;

    auto permute_output_data = permute_output.data<data_t>();

    MoEDeepGEMMPermuteKernel<DataType_, VecSize>
        <<<grids, blocks, 0, stream>>>(
            reinterpret_cast<DataType_*>(permute_output_data),
            token_nums_per_expert.data<int32_t>(),
            permute_indices_per_token.data<int32_t>(),
            reinterpret_cast<const DataType_*>(x.data<data_t>()),
            topk_idx.data<int64_t>(),
            token_num,
            topk,
            num_vecs,
            hidden,
            max_tokens_per_expert);
    // Launch-configuration errors are only reported through the last-error
    // state, not by the launch statement itself.
    PADDLE_ENFORCE_GPU_SUCCESS(cudaGetLastError());

    return {permute_output, token_nums_per_expert, permute_indices_per_token};
}

// Op entry point: selects the dispatch instantiation from x.dtype().
// Supports BFLOAT16 and FLOAT16; any other dtype raises via PD_THROW.
std::vector<paddle::Tensor> MoEDeepGEMMPermute(
    const paddle::Tensor& x,
    const paddle::Tensor& topk_idx,
    const int num_experts,
    const int max_tokens_per_expert) {
    switch (x.dtype()) {
        case paddle::DataType::BFLOAT16:
            return MoEDeepGEMMPermuteDispatch<paddle::DataType::BFLOAT16>(
                x, topk_idx, num_experts, max_tokens_per_expert);
        case paddle::DataType::FLOAT16:
            return MoEDeepGEMMPermuteDispatch<paddle::DataType::FLOAT16>(
                x, topk_idx, num_experts, max_tokens_per_expert);
        default:
            PD_THROW("Unsupported data type");
    }
}

// Registers the op as `moe_deepgemm_permute` with inputs (x, topk_idx), the
// three outputs returned by MoEDeepGEMMPermute, and two integer attributes.
PD_BUILD_STATIC_OP(moe_deepgemm_permute)
    .Inputs({"x", "topk_idx"})
    .Outputs({"permute_output",
              "token_nums_per_expert",
              "permute_indices_per_token"})
    .Attrs({"num_experts: int", "max_tokens_per_expert: int"})
    .SetKernelFn(PD_KERNEL(MoEDeepGEMMPermute));