// /*
//  * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
//  * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
//  *
//  * Licensed under the Apache License, Version 2.0 (the "License");
//  * you may not use this file except in compliance with the License.
//  * You may obtain a copy of the License at
//  *
//  * http://www.apache.org/licenses/LICENSE-2.0
//  *
//  * Unless required by applicable law or agreed to in writing, software
//  * distributed under the License is distributed on an "AS IS" BASIS,
//  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  * See the License for the specific language governing permissions and
//  * limitations under the License.
//  */

#pragma once

#include <algorithm>
#include <cfloat>
#include <cstdint>
#include <type_traits>

#include "moe/fused_moe_imp_op.h"
#include "moe/fused_moe_helper.h"
#include "cutlass/numeric_conversion.h"
// Ignore CUTLASS warnings about type punning
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wunused-function"
// #include "paddle/phi/backends/gpu/gpu_info.h"
#pragma GCC diagnostic pop

#include "helper.h"

#define WARP_SIZE 32

namespace phi {

// Launch geometry produced by Get1DBlocksAnd2DGridsMoe. Only `block_per_grid`
// is filled in by that helper; callers pick the block size themselves.
struct GpuLaunchConfig {
  dim3 block_per_grid;
  dim3 thread_per_block;
};

// Builds a grid with at least `cols` blocks (one block per logical row).
//
// Up to 1024 rows the grid is 1-D (`cols` x 1 x 1). Above that the rows are
// spread over ceil(cols / 256) x 256 blocks, so the grid may contain a few
// surplus blocks. The kernels in this file that use this config linearize the
// block id as `blockIdx.x + blockIdx.y * gridDim.x` and return early when it
// is past their row count.
//
// All intermediate arithmetic stays in int64_t so that a large row count is
// not truncated through an `int` before it is split across the y dimension.
inline GpuLaunchConfig Get1DBlocksAnd2DGridsMoe(const int64_t cols) {
  int64_t blocks_x = cols;
  int64_t blocks_y = 1;
  const int64_t blocks_z = 1;
  if (blocks_x > 1024) {
    blocks_y = 256;
    blocks_x = (blocks_x + blocks_y - 1) / blocks_y;
  }
  GpuLaunchConfig config;
  config.block_per_grid.x = static_cast<unsigned int>(blocks_x);
  config.block_per_grid.y = static_cast<unsigned int>(blocks_y);
  config.block_per_grid.z = static_cast<unsigned int>(blocks_z);
  return config;
}

// Number of threads per block used to launch finalize_moe_routing_kernel.
constexpr static int FINALIZE_THREADS_PER_BLOCK = 256;

// Element-wise conversion between two array fragments that have the same
// element count (U::kElements == T::kElements); each element is converted with
// static_cast to U::Element.
template <typename U, typename T>
__host__ __device__ constexpr static U arrayConvert(T const& input) {
  using Type = typename U::Element;
  static_assert(T::kElements == U::kElements);
  U u;
#pragma unroll
  for (int i = 0; i < U::kElements; i++) {
    u[i] = static_cast<Type>(input[i]);
  }
  return u;
}

// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing
// the output in the softmax kernel when we extend this module to support
// expert-choice routing.

// Group-wise softmax for the "group MoE" routing path.
//
// Launch: one block of TPB threads per softmax row; the row id is
// blockIdx.x + blockIdx.y * gridDim.x and blocks past `softmax_num_rows`
// return immediately. No dynamic shared memory is used.
//
// For each row r of `num_cols` logits this writes
//   softmax_max_prob[r] = 1 / max_j softmax(input[r, :])_j
//   output[r, :]        = softmax(input[r, :]) * softmax_max_prob[r]
// i.e. each row is rescaled so that its largest probability becomes ~1.
// The single cub TempStorage is reused for three reductions; each reuse is
// separated by a __syncthreads().
template <typename T, int TPB>
__launch_bounds__(TPB) __global__
    void group_moe_softmax(const T* input,
                           T* output,
                           T* softmax_max_prob,
                           const int64_t num_cols,
                           const int64_t softmax_num_rows) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  __shared__ float normalizing_factor;
  __shared__ float float_max;
  __shared__ float max_out;

  int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
  if (globalIdx >= softmax_num_rows) {
    return;
  }
  const int64_t thread_row_offset = globalIdx * num_cols;

  cub::Sum sum;
  float threadData(-FLT_MAX);

  // Pass 1: row max, for a numerically stable exp.
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int64_t idx = thread_row_offset + ii;
    threadData = max(static_cast<float>(input[idx]), threadData);
  }

  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  if (threadIdx.x == 0) {
    float_max = maxElem;
  }
  __syncthreads();

  // Pass 2: softmax denominator.
  threadData = 0;
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int64_t idx = thread_row_offset + ii;
    threadData += exp((static_cast<float>(input[idx]) - float_max));
  }

  const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
  if (threadIdx.x == 0) {
    normalizing_factor = 1.f / Z;
  }
  __syncthreads();

  // Pass 3: write the softmax and track the max of the values as stored in T.
  threadData = 0;
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int64_t idx = thread_row_offset + ii;
    const float val =
        exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
    output[idx] = T(val);
    threadData = max(static_cast<float>(T(val)), threadData);
  }

  const float maxOut = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  if (threadIdx.x == 0) {
    // group max probs
    max_out = 1.f / maxOut;
    softmax_max_prob[globalIdx] = T(max_out);
  }
  __syncthreads();

  // Pass 4: group softmax normalization. Each thread rescales exactly the
  // entries it wrote in pass 3.
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int64_t idx = thread_row_offset + ii;
    output[idx] = output[idx] * static_cast<T>(max_out);
  }
}

// Row-wise softmax: output[r, :] = softmax(input[r, :]) for each of
// `num_rows` rows of `num_cols` values.
//
// Launch: one block of TPB threads per row (row id = blockIdx.x +
// blockIdx.y * gridDim.x, out-of-range blocks return). Math is done in float.
template <typename T, int TPB>
__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
                                                   T* output,
                                                   const int64_t num_cols,
                                                   const int64_t num_rows) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  __shared__ float normalizing_factor;
  __shared__ float float_max;

  int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
  if (globalIdx >= num_rows) {
    return;
  }
  const int64_t thread_row_offset = globalIdx * num_cols;

  cub::Sum sum;
  float threadData(-FLT_MAX);

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int64_t idx = thread_row_offset + ii;
    threadData = max(static_cast<float>(input[idx]), threadData);
  }

  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  if (threadIdx.x == 0) {
    float_max = maxElem;
  }
  __syncthreads();

  threadData = 0;
  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int64_t idx = thread_row_offset + ii;
    threadData += exp((static_cast<float>(input[idx]) - float_max));
  }

  const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
  if (threadIdx.x == 0) {
    normalizing_factor = 1.f / Z;
  }
  __syncthreads();

  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
    const int64_t idx = thread_row_offset + ii;
    const float val =
        exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
    output[idx] = T(val);
  }
}

// Top-k selection for the group MoE path (consumes group_moe_softmax output).
//
// Launch: one block of TPB threads per token row; blocks past `num_rows`
// return. For each k_idx the block arg-maxes the row's `num_experts` values,
// skipping experts already chosen (read back from `indices`), and thread 0
// writes, at idx = row * k + k_idx:
//   output[idx]      = value / softmax_max_prob[idx]
//   indices[idx]     = winning expert
//   source_rows[idx] = k_idx * num_rows + row
// NOTE(review): indexing softmax_max_prob with row * k + k_idx assumes the
// k_idx-th winner comes from group k_idx (one ~1.0 winner per group, ties
// resolved toward the lower expert id) — confirm.
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__
    void group_moe_top_k(const T* inputs_after_softmax,
                         T* output,
                         IdxT* indices,
                         int* source_rows,
                         T* softmax_max_prob,
                         const int64_t num_experts,
                         const int64_t k,
                         const int64_t num_rows) {
  using cub_kvp = cub::KeyValuePair<int, T>;
  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
  if (block_row >= num_rows) {
    return;
  }

  const bool should_process_row = true;
  const int64_t thread_read_offset =
      static_cast<int64_t>(block_row) * num_experts;

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    thread_kvp.value = T(-1.f);  // This is OK because inputs are probabilities

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
      const int64_t idx = thread_read_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = inputs_after_softmax[idx];

      // Neutralize experts picked in earlier iterations.
      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const IdxT prior_winning_expert = indices[k * block_row + prior_k];

        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }

      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp =
        BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      const int64_t idx = k * block_row + k_idx;
      // restore normalized probs
      output[idx] = result_kvp.value / T(softmax_max_prob[idx]);
      indices[idx] = should_process_row ? result_kvp.key : num_experts;
      source_rows[idx] = k_idx * num_rows + block_row;
    }
    __syncthreads();
  }
}

// Top-k over precomputed expert scores, with an optional per-expert
// correction `bias` that only affects the ranking.
//
// Launch: one block of TPB threads per token row; blocks past `num_rows`
// return. When NormWeights is true the kernel needs k * sizeof(T) bytes of
// dynamic shared memory and writes output[row, j] = w_j / sum_i w_i;
// otherwise output[row, j] = w_j, where w_j is the un-biased score of the
// j-th selected expert.
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
                                                 const T* bias,
                                                 T* output,
                                                 IdxT* indices,
                                                 int* source_rows,
                                                 const int64_t num_experts,
                                                 const int64_t k,
                                                 const int64_t num_rows) {
  using cub_kvp = cub::KeyValuePair<int, T>;
  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
  if (block_row >= num_rows) {
    return;
  }

  const bool should_process_row = true;
  const int64_t thread_read_offset =
      static_cast<int64_t>(block_row) * num_experts;
  T* row_outputs = nullptr;

  if constexpr (NormWeights) {
    extern __shared__ char smem[];
    row_outputs = reinterpret_cast<T*>(smem);
  }

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    // Lowest-value sentinel: with a correction bias the ranking scores are no
    // longer probabilities and may be < -1, so -1 is not a safe "empty" value.
    thread_kvp.value = T(-FLT_MAX);

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
      const int64_t idx = thread_read_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
                           : inputs_after_softmax[idx];

      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const int prior_winning_expert = indices[k * block_row + prior_k];

        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }

      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp =
        BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      const int64_t idx = k * block_row + k_idx;
      indices[idx] = should_process_row ? result_kvp.key : num_experts;
      source_rows[idx] = k_idx * num_rows + block_row;

      // Report the un-biased score of the winner.
      const T row_out =
          bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
               : result_kvp.value;
      if constexpr (NormWeights) {
        row_outputs[k_idx] = row_out;
      } else {
        output[idx] = row_out;
      }
    }
    __syncthreads();
  }

  if constexpr (NormWeights) {
    if (threadIdx.x < k) {
      // Each writer re-derives the sum from shared memory in the order the
      // entries were produced, so every thread gets the same value without a
      // warp broadcast (which would only reach the first WARP_SIZE threads).
      T weight_sum = static_cast<T>(0);
      for (int64_t j = 0; j < k; ++j) {
        weight_sum += row_outputs[j];
      }
      output[k * block_row + threadIdx.x] =
          row_outputs[threadIdx.x] / weight_sum;
    }
  }
}

// Fused softmax + top-k, with an optional correction `bias` used only for
// ranking.
//
// Launch: one block of TPB threads per token row; blocks past `num_rows`
// return. Each thread owns at most one expert (expert id == threadIdx.x), so
// this kernel requires num_experts <= TPB; experts at index >= TPB are never
// read. When NormWeights is true it needs k * sizeof(T) bytes of dynamic
// shared memory and the k outputs of a row are normalized to sum to 1.
template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
__launch_bounds__(TPB) __global__
    void moe_softmax_top_k_fused(const T* input,
                                 const T* bias,
                                 T* output,
                                 IdxT* indices,
                                 int* source_rows,
                                 const int64_t num_experts,
                                 const int64_t k,
                                 const int64_t num_rows) {
  // softmax
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  __shared__ float normalizing_factor;
  __shared__ float float_max;

  int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
  if (globalIdx >= num_rows) {
    return;
  }
  const int64_t thread_row_offset = globalIdx * num_experts;
  const int64_t idx = thread_row_offset + threadIdx.x;

  cub::Sum sum;

  // Threads without an expert contribute -FLT_MAX, i.e. exp(...) == 0.
  float threadData = (threadIdx.x < num_experts)
                         ? static_cast<float>(input[idx])
                         : (-FLT_MAX);

  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
  if (threadIdx.x == 0) {
    float_max = maxElem;
  }
  __syncthreads();

  float threadDataSub = threadData - float_max;
  float threadDataExp = exp(threadDataSub);

  const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);

  if (threadIdx.x == 0) {
    normalizing_factor = 1.f / Z;
  }
  __syncthreads();

  T val = T(threadDataExp * normalizing_factor);

  // top_k
  using cub_kvp = cub::KeyValuePair<int, T>;
  using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduceP::TempStorage tmpStorageP;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  T* row_outputs = nullptr;

  if constexpr (NormWeights) {
    extern __shared__ char smem[];
    row_outputs = reinterpret_cast<T*>(smem);
  }

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    // Lowest-value sentinel (biased scores may be < -1).
    thread_kvp.value = T(-FLT_MAX);

    if (threadIdx.x < num_experts) {
      cub_kvp inp_kvp;
      int expert = threadIdx.x;
      inp_kvp.key = expert;
      inp_kvp.value = bias ? val + bias[expert] : val;

      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];

        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }
      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp =
        BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      const int64_t cur_idx = k * globalIdx + k_idx;

      indices[cur_idx] = result_kvp.key;
      source_rows[cur_idx] = k_idx * num_rows + globalIdx;

      // Remove the bias again so the reported weight is the softmax value.
      const T row_out = bias ? (result_kvp.value - bias[result_kvp.key])
                             : result_kvp.value;
      if constexpr (NormWeights) {
        row_outputs[k_idx] = row_out;
      } else {
        output[cur_idx] = row_out;
      }
    }
    __syncthreads();
  }

  if constexpr (NormWeights) {
    if (threadIdx.x < k) {
      // Same-order re-summation: identical value in every writer thread.
      T weight_sum = static_cast<T>(0);
      for (int64_t j = 0; j < k; ++j) {
        weight_sum += row_outputs[j];
      }
      output[k * globalIdx + threadIdx.x] =
          row_outputs[threadIdx.x] / weight_sum;
    }
  }
}

// Advances a 32-bit xorshift-style generator (shifts >>7, <<9, >>13) in place
// and returns the new state. A state of 0 is a fixed point (it stays 0).
inline __device__ unsigned int xorwow_moe(unsigned int &state) {
  state ^= state >> 7;
  state ^= state << 9;
  state ^= state >> 13;
  return state;
}

// Normalized top-k routing with pseudo-random choice among redundant slots.
//
// Launch: one block of TPB threads per token row (rows past `num_rows`
// return); needs k * sizeof(T) bytes of dynamic shared memory. For each row:
//  * selects the top-k experts by score (+ optional bias for ranking), using
//    `indices_tmp` to remember the expert ids already chosen;
//  * for winner e, picks `select` in [0, expert_in_rank_num_list[e]) and
//    writes expert_id_to_ep_rank_array[e * redundant_ep_rank_num_plus_one +
//    select] into `indices`; the expert id itself goes to `indices_tmp`;
//  * atomically increments tokens_per_expert_stats_list[e];
//  * writes output[row, j] = w_j / sum_i w_i with w_j the un-biased score.
// Assumes expert_in_rank_num_list[e] >= 1 for every selectable expert.
template <typename T, int TPB, typename IdxT = int>
__launch_bounds__(TPB) __global__
    void moe_redundant_top_k_normed(const T* inputs_after_softmax,
                                    const T* bias,
                                    const int* expert_id_to_ep_rank_array,
                                    const int* expert_in_rank_num_list,
                                    int* tokens_per_expert_stats_list,
                                    T* output,
                                    IdxT* indices,
                                    IdxT* indices_tmp,
                                    int* source_rows,
                                    const int64_t num_experts,
                                    const int64_t k,
                                    const int64_t num_rows,
                                    const int redundant_ep_rank_num_plus_one) {
  using cub_kvp = cub::KeyValuePair<int, T>;
  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
  __shared__ typename BlockReduce::TempStorage tmpStorage;

  cub_kvp thread_kvp;
  cub::ArgMax arg_max;

  const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
  // unsigned int state = block_row + blockIdx.x * blockDim.x + *kernel_call_num;
  // NOTE(review): this seed is 0 for the block at blockIdx (0, 0), and 0 is a
  // fixed point of xorwow_moe, so that row always picks slot 0.
  unsigned int state = block_row + blockIdx.x * blockDim.x;
  if (block_row >= num_rows) {
    return;
  }

  const bool should_process_row = true;
  const int64_t thread_read_offset =
      static_cast<int64_t>(block_row) * num_experts;

  extern __shared__ char smem[];
  T* row_outputs = reinterpret_cast<T*>(smem);

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    thread_kvp.key = 0;
    // Lowest-value sentinel (biased scores may be < -1).
    thread_kvp.value = T(-FLT_MAX);

    cub_kvp inp_kvp;
    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
      const int64_t idx = thread_read_offset + expert;
      inp_kvp.key = expert;
      inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert]
                           : inputs_after_softmax[idx];

      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
        const int prior_winning_expert = indices_tmp[k * block_row + prior_k];

        if (prior_winning_expert == expert) {
          inp_kvp = thread_kvp;
        }
      }

      thread_kvp = arg_max(inp_kvp, thread_kvp);
    }

    const cub_kvp result_kvp =
        BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
    if (threadIdx.x == 0) {
      const int64_t idx = k * block_row + k_idx;
      source_rows[idx] = k_idx * num_rows + block_row;
      int expert_topk = should_process_row ? result_kvp.key : num_experts;

      // Redundancy: choose one of the `len` slots listed for this expert.
      const int len = expert_in_rank_num_list[expert_topk];
      // Reduce in unsigned arithmetic. Casting the raw 32-bit state to int
      // first can produce a negative value, and a negative `%` result would
      // index before this expert's slot list.
      const int select = static_cast<int>(
          xorwow_moe(state) % static_cast<unsigned int>(len));
      int selected_rank =
          expert_id_to_ep_rank_array[expert_topk *
                                         redundant_ep_rank_num_plus_one +
                                     select];

      indices[idx] = (IdxT)selected_rank;
      indices_tmp[idx] = result_kvp.key;

      atomicAdd(&tokens_per_expert_stats_list[result_kvp.key], 1);

      const T row_out =
          bias ? inputs_after_softmax[thread_read_offset + result_kvp.key]
               : result_kvp.value;
      row_outputs[k_idx] = row_out;
    }
    __syncthreads();
  }

  if (threadIdx.x < k) {
    // Same-order re-summation: identical value in every writer thread.
    T weight_sum = static_cast<T>(0);
    for (int64_t j = 0; j < k; ++j) {
      weight_sum += row_outputs[j];
    }
    output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
  }
}

// ====================== TopK softmax things ===============================

/*
  A Top-K gating softmax written to exploit the case where the number of
  experts in the MoE layer is a small power of 2. This allows us to cleanly
  share the rows among the threads in a single warp and eliminate
  communication between warps (so no need to use shared mem).

  It fuses the softmax, max and argmax into a single kernel.

  Limitations:
  1) This implementation is intended for when the number of experts is a small
  power of 2.
  2) This implementation assumes k is small, but will work for any k.
*/

// Launch: blockDim = (WARP_SIZE, WARPS_PER_CTA); each warp handles
// ROWS_PER_WARP rows and each row is split across THREADS_PER_ROW lanes that
// each hold VPT experts in registers. Rows >= num_rows exit early. No shared
// memory is used. For each row it writes k (value, expert, source_row)
// triples at index row * k + k_idx.
template <typename T,
          int VPT,
          int NUM_EXPERTS,
          int WARPS_PER_CTA,
          int BYTES_PER_LDG,
          typename IdxT = int>
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
    void topk_gating_softmax(const T* input,
                             T* output,
                             const int64_t num_rows,
                             IdxT* indices,
                             int* source_rows,
                             const int64_t k) {
  // We begin by enforcing compile time assertions and setting up compile time
  // constants.
  static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
  static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS),
                "NUM_EXPERTS must be power of 2");
  static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
                "BYTES_PER_LDG must be power of 2");
  static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

  // Number of elements each thread pulls in per load
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
  static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
  static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
  static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;

  // Restrictions based on previous section.
  static_assert(
      VPT % ELTS_PER_LDG == 0,
      "The elements per thread must be a multiple of the elements per ldg");
  static_assert(WARP_SIZE % THREADS_PER_ROW == 0,
                "The threads per row must cleanly divide the threads per warp");
  static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
                "THREADS_PER_ROW must be power of 2");
  static_assert(THREADS_PER_ROW <= WARP_SIZE,
                "THREADS_PER_ROW can be at most warp size");

  // We have NUM_EXPERTS elements per row. We specialize for small #experts
  static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
  static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
  static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

  // Restrictions for previous section.
  static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
                "The elts per row must cleanly divide the total elt per warp");

  // ===================== From this point, we finally start computing run-time
  // variables. ========================

  // Compute CTA and warp rows. We pack multiple rows into a single warp, and a
  // block contains WARPS_PER_CTA warps. Thus, each block processes a chunk of
  // rows. We start by computing the start row for each block.
  const int cta_base_row = blockIdx.x * ROWS_PER_CTA;

  // Now, using the base row per thread block, we compute the base row per warp.
  const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;

  // The threads in a warp are split into sub-groups that will work on a row.
  // We compute row offset for each thread sub-group
  const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
  const int thread_row = warp_base_row + thread_row_in_warp;

  // Threads with indices out of bounds should early exit here.
  if (thread_row >= num_rows) return;
  const bool should_process_row = true;

  // We finally start setting up the read pointers for each thread. First, each
  // thread jumps to the start of the row it will read. The offset is formed in
  // 64-bit so that num_rows * NUM_EXPERTS beyond INT_MAX does not overflow.
  const T* thread_row_ptr =
      input + static_cast<int64_t>(thread_row) * ELTS_PER_ROW;

  // Now, we compute the group each thread belong to in order to determine the
  // first column to start loads.
  const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
  const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
  const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;

  // Determine the pointer type to use to read in the data depending on the
  // BYTES_PER_LDG template param. In theory, this can support all powers of 2
  // up to 16.
  using AccessType = cutlass::AlignedArray<T, ELTS_PER_LDG>;

  // Finally, we pull in the data from global mem
  cutlass::Array<T, VPT> row_chunk_input;
  AccessType* row_chunk_vec_ptr =
      reinterpret_cast<AccessType*>(&row_chunk_input);
  const AccessType* vec_thread_read_ptr =
      reinterpret_cast<const AccessType*>(thread_read_ptr);
#pragma unroll
  for (int ii = 0; ii < LDG_PER_THREAD; ++ii) {
    row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
  }

  using ComputeType = float;
  using Converter = cutlass::NumericArrayConverter<ComputeType, T, VPT>;
  Converter compute_type_converter;
  cutlass::Array<ComputeType, VPT> row_chunk =
      compute_type_converter(row_chunk_input);

  // First, we perform a max reduce within the thread. The chunk has already
  // been converted to float above, so the max is taken in float.
  ComputeType thread_max = row_chunk[0];
#pragma unroll
  for (int ii = 1; ii < VPT; ++ii) {
    thread_max = max(thread_max, row_chunk[ii]);
  }

  // Now, we find the max within the thread group and distribute among the
  // threads. We use a butterfly reduce.
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    thread_max = max(
        thread_max,
        __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
  }

  // From this point, thread max in all the threads have the max within the row.
  // Now, we subtract the max from each element in the thread and take the exp.
  // We also compute the thread local sum.
  float row_sum = 0;
#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = expf(row_chunk[ii] - thread_max);
    row_sum += row_chunk[ii];
  }

  // Now, we perform the sum reduce within each thread group. Similar to the max
  // reduce, we use a butterfly pattern.
#pragma unroll
  for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
    row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
  }

  // From this point, all threads have the max and the sum for their rows in the
  // thread_max and row_sum variables respectively. Finally, we can scale the
  // rows for the softmax. Technically, for top-k gating we don't need to
  // compute the entire softmax row. We can likely look at the maxes and only
  // compute for the top-k values in the row. However, this kernel will likely
  // not be a bottleneck and it seems better to closer match torch and find the
  // argmax after computing the softmax.
  const float reciprocal_row_sum = 1.f / row_sum;

#pragma unroll
  for (int ii = 0; ii < VPT; ++ii) {
    row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
  }

  // Now, softmax_res contains the softmax of the row chunk. Now, I want to find
  // the topk elements in each row, along with the max index.
  int start_col = first_elt_read_by_thread;
  static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

  for (int k_idx = 0; k_idx < k; ++k_idx) {
    // First, each thread does the local argmax
    float max_val = row_chunk[0];
    int expert = start_col;
#pragma unroll
    for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD;
         ++ldg, col += COLS_PER_GROUP_LDG) {
#pragma unroll
      for (int ii = 0; ii < ELTS_PER_LDG; ++ii) {
        float val = row_chunk[ldg * ELTS_PER_LDG + ii];

        // No check on the experts here since columns with the smallest index
        // are processed first and only updated if > (not >=)
        if (val > max_val) {
          max_val = val;
          expert = col + ii;
        }
      }
    }

    // Now, we perform the argmax reduce. We use the butterfly pattern so threads
    // reach consensus about the max. This will be useful for K > 1 so that the
    // threads can agree on "who" had the max value. That thread can then blank
    // out their max with -inf and the warp can run more iterations...
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
      float other_max =
          __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
      int other_expert =
          __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);

      // We want lower indices to "win" in every thread so we break ties this
      // way
      if (other_max > max_val ||
          (other_max == max_val && other_expert < expert)) {
        max_val = other_max;
        expert = other_expert;
      }
    }

    // Write the max for this k iteration to global memory.
    if (thread_group_idx == 0) {
      // The lead thread from each sub-group will write out the final results to
      // global memory. (This will be a single thread per row of the
      // input/output matrices.)
      const int64_t idx = k * thread_row + k_idx;
      output[idx] = T(max_val);
      indices[idx] = should_process_row ? expert : NUM_EXPERTS;
      source_rows[idx] = k_idx * num_rows + thread_row;
    }

    // Finally, we clear the value in the thread with the current max if there
    // is another iteration to run.
    if (k_idx + 1 < k) {
      const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
      const int thread_to_clear_in_group =
          (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

      // Only the thread in the group which produced the max will reset the
      // "winning" value to -inf.
      if (thread_group_idx == thread_to_clear_in_group) {
        const int offset_for_expert = expert % ELTS_PER_LDG;
        // Safe to set to any negative value since row_chunk values must be
        // between 0 and 1.
        row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] =
            ComputeType(-10000.f);
      }
    }
  }
}

namespace detail {
// Constructs some constants needed to partition the work across threads at
// compile time.
template <typename T, int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants {
  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T);
  static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 ||
                    EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0,
                "EXPERTS must be smaller than, or a multiple of, "
                "ELTS_PER_LDG * WARP_SIZE");
  static constexpr int VECs_PER_THREAD =
      std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
  static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
  static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
  static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
}  // namespace detail

// Launches topk_gating_softmax for a compile-time expert count EXPERTS with
// WARPS_PER_TB warps per block. The per-load width is
// min(16 bytes, sizeof(T) * EXPERTS). Does nothing for num_rows == 0 (an empty
// grid would be an invalid launch configuration).
template <typename T, int EXPERTS, int WARPS_PER_TB, typename IdxT = int>
void topk_gating_softmax_launcher_helper(const T* input,
                                         T* output,
                                         IdxT* indices,
                                         int* source_row,
                                         const int64_t num_rows,
                                         const int64_t num_experts,
                                         const int64_t k,
                                         cudaStream_t stream) {
  static constexpr uint64_t MAX_BYTES_PER_LDG = 16;

  static constexpr int BYTES_PER_LDG =
      std::min<uint64_t>(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS);
  using Constants = detail::TopkConstants<T, EXPERTS, BYTES_PER_LDG>;
  static constexpr int VPT = Constants::VPT;
  static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;

  if (num_rows == 0) {
    return;
  }

  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
  const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

  dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
  topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG>
      <<<num_blocks, block_dim, 0, stream>>>(
          input, output, num_rows, indices, source_row, k);
}

// Host-side dispatcher for MoE gating (softmax + top-k) on `stream`.
//
//  * topk_only_mode: `input` is ranked directly with moe_top_k (optional
//    gating_correction_bias); no softmax is applied.
//  * No bias and num_experts in {2, 4, ..., 256}: the fused warp-level
//    topk_gating_softmax kernel is used.
//  * Otherwise (bias present, or any other expert count): a block-level
//    softmax into `softmax`, then a top-k kernel. With group_moe the experts
//    are split into k groups of num_experts / k (NOTE(review): assumes
//    num_experts is divisible by k) and `softmax_max_prob` receives one
//    rescale factor per (row, group).
// Returns immediately when num_rows == 0.
template <typename T, typename IdxT = int>
struct topk_gating_softmax_kernelLauncher {
  static void run(const T* input,
                  const T* gating_correction_bias,
                  T* output,
                  T* softmax,
                  IdxT* indices,
                  int* source_row,
                  T* softmax_max_prob,
                  const int64_t num_rows,
                  const int64_t num_experts,
                  const int64_t k,
                  const bool group_moe,
                  cudaStream_t stream,
                  const bool topk_only_mode = false) {
    // Nothing to route; an empty grid would also be rejected by the runtime.
    if (num_rows == 0) {
      return;
    }

    if (topk_only_mode) {
      static constexpr int TPB = 256;
      const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
      moe_top_k<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
          input, gating_correction_bias, output, indices, source_row,
          num_experts, k, num_rows);
      return;
    }
    static constexpr int WARPS_PER_TB = 4;

#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N)                          \
  case N: {                                                           \
    topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>(          \
        input, output, indices, source_row, num_rows, num_experts, k, \
        stream);                                                      \
    break;                                                            \
  }

    // The fused warp-level kernel has no bias support, so any biased gating
    // is forced onto the generic (default) path below.
    int64_t tem_num_experts = num_experts;
    if (gating_correction_bias != nullptr) tem_num_experts = 0;
    switch (tem_num_experts) {
      LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
      LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)

      LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
      LAUNCH_TOPK_GATING_SOFTMAX_HELPER(16)
      LAUNCH_TOPK_GATING_SOFTMAX_HELPER(32)
      LAUNCH_TOPK_GATING_SOFTMAX_HELPER(64)
      LAUNCH_TOPK_GATING_SOFTMAX_HELPER(128)
      LAUNCH_TOPK_GATING_SOFTMAX_HELPER(256)

      default: {
        static constexpr int TPB = 256;
        if (group_moe) {
          const int64_t group_experts = num_experts / k;
          const int64_t softmax_num_rows = num_rows * k;
          const auto config_softmax =
              Get1DBlocksAnd2DGridsMoe(softmax_num_rows);
          group_moe_softmax<T, TPB>
              <<<config_softmax.block_per_grid, TPB, 0, stream>>>(
                  input, softmax, softmax_max_prob, group_experts,
                  softmax_num_rows);
          const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
          group_moe_top_k<T, TPB>
              <<<config_topk.block_per_grid, TPB, 0, stream>>>(
                  softmax, output, indices, source_row, softmax_max_prob,
                  num_experts, k, num_rows);
        } else {
          const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
          moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
              input, softmax, num_experts, num_rows);
          moe_top_k<T, TPB>
              <<<config_topk.block_per_grid, TPB, 0, stream>>>(
                  softmax, gating_correction_bias, output, indices,
                  source_row, num_experts, k, num_rows);
        }
      }
    }
#undef LAUNCH_TOPK_GATING_SOFTMAX_HELPER
  }
};

// ========================== Permutation things
// =======================================

// Duplicates and permutes rows for MoE. In addition, reverses the permutation
// map to help with finalizing routing.

// "expanded_x_row" simply means that the number of values is num_rows x k. It
// is "expanded" since we will have to duplicate some rows in the input matrix
// to match the dimensions. Duplicates will always get routed to separate
// experts in the end.

// Note that the expanded_dest_row_to_expanded_source_row map referred to here
// has indices in the range [0, k*rows_in_input - 1]. However, it is set up so
// that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input all map
// to row 0 in the original matrix.
// Thus, to know where to read in the source matrix, we simply take the
// modulus of the expanded index.
//
// Launch: one block per expanded destination row (row id = blockIdx.x +
// blockIdx.y * gridDim.x); blocks past `num_rows_k` return. Each block:
//  * thread 0 records the reverse map entry
//    expanded_source_row_to_expanded_dest_row[src] = dest;
//  * if dest < active_rows, copies source row (src % num_rows) of
//    `unpermuted_input` into row `dest` of `permuted_output`, VecSize elements
//    at a time (cols must be a multiple of VecSize). When OutT is int8_t the
//    values are quantized as round(clamp(127 * scale * x, -127, 127)) with
//    scale = w4a8_in_scale[expert_idx_per_token[dest]].
// NOTE(review): when w4a8_in_scale is null the scale falls back to -1, which
// would negate int8 output; presumably the int8 path always gets a scale.
template <typename T, int VecSize, typename OutT = T>
__global__ void initialize_moe_routing_kernel(
    const T* unpermuted_input,
    OutT* permuted_output,
    const int* expanded_dest_row_to_expanded_source_row,
    const int* expert_idx_per_token,
    const float* w4a8_in_scale,
    int* expanded_source_row_to_expanded_dest_row,
    const int64_t num_rows,
    const int64_t active_rows,
    const int64_t cols,
    const int64_t num_rows_k) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;

  // Reverse permutation map.
  // I do this so that later, we can use the source -> dest map to do the k-way
  // reduction and unpermuting. I need the reverse map for that reduction to
  // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
  // thread block will be responsible for all k summations.
  const int expanded_dest_row = blockIdx.x + blockIdx.y * gridDim.x;
  if (expanded_dest_row >= num_rows_k) return;
  const int expanded_source_row =
      expanded_dest_row_to_expanded_source_row[expanded_dest_row];
  if (threadIdx.x == 0) {
    expanded_source_row_to_expanded_dest_row[expanded_source_row] =
        expanded_dest_row;
  }

  if (expanded_dest_row < active_rows) {
    const int expert_idx = expert_idx_per_token[expanded_dest_row];
    const float scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1;

    // Duplicate and permute rows
    const int source_row = expanded_source_row % num_rows;

    const T* source_row_ptr = unpermuted_input + source_row * cols;
    OutT* dest_row_ptr = permuted_output + expanded_dest_row * cols;

    for (int tid = threadIdx.x * VecSize; tid < cols;
         tid += blockDim.x * VecSize) {
      // dest_row_ptr[tid] = source_row_ptr[tid];
      Load(&source_row_ptr[tid], &src_vec);
      if constexpr (std::is_same<OutT, int8_t>::value) {
        using StoreT = AlignedVector<OutT, VecSize>;
        StoreT dest_vec;
        const float max_bound = 127.f;
        const float min_bound = -127.f;
        for (int j = 0; j < VecSize; j++) {
          float quant_value =
              max_bound * scale * static_cast<float>(src_vec[j]);
          quant_value = quant_value > max_bound ? max_bound : quant_value;
          quant_value = quant_value < min_bound ? min_bound : quant_value;
          dest_vec[j] = static_cast<OutT>(round(quant_value));
        }
        Store(dest_vec, &dest_row_ptr[tid]);
      } else {
        Store(src_vec, &dest_row_ptr[tid]);
      }
    }
  }
}

// Host launcher for initialize_moe_routing_kernel: one block per expanded row
// (num_rows * k of them), min(cols, 1024) threads (at least 1), 16-byte
// vector accesses when cols is a multiple of 16 / sizeof(T), scalar otherwise.
// Returns without launching when there are no expanded rows, since an empty
// grid is an invalid launch configuration. Clamping the block size to >= 1
// keeps the launch valid for cols == 0 so the reverse map is still written.
template <typename T, typename OutT = T>
struct initialize_moe_routing_kernelLauncher {
  static void run(const T* unpermuted_input,
                  OutT* permuted_output,
                  const int* expanded_dest_row_to_expanded_source_row,
                  const int* expert_idx_per_token,
                  const float* w4a8_in_scale,
                  int* expanded_source_row_to_expanded_dest_row,
                  const int64_t num_rows,
                  const int64_t active_rows,
                  const int64_t cols,
                  const int64_t k,
                  cudaStream_t stream) {
    if (num_rows == 0 || k == 0) {
      return;
    }
    const int threads = static_cast<int>(
        std::max<int64_t>(1, std::min(cols, int64_t(1024))));
    constexpr int max_pack_size = 16 / sizeof(T);
    const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k);
    if (cols % max_pack_size == 0) {
      initialize_moe_routing_kernel<T, max_pack_size, OutT>
          <<<config_initialize.block_per_grid, threads, 0, stream>>>(
              unpermuted_input,
              permuted_output,
              expanded_dest_row_to_expanded_source_row,
              expert_idx_per_token,
              w4a8_in_scale,
              expanded_source_row_to_expanded_dest_row,
              num_rows,
              k * active_rows,
              cols,
              num_rows * k);
    } else {
      initialize_moe_routing_kernel<T, 1, OutT>
          <<<config_initialize.block_per_grid, threads, 0, stream>>>(
              unpermuted_input,
              permuted_output,
              expanded_dest_row_to_expanded_source_row,
              expert_idx_per_token,
              w4a8_in_scale,
              expanded_source_row_to_expanded_dest_row,
              num_rows,
              k * active_rows,
              cols,
              num_rows * k);
    }
  }
};

// ============================== Infer GEMM sizes
// =================================

// Binary search returning the number of leading entries of
// sorted_indices[0, arr_length) that are <= target (one past the index of the
// last such entry, or 0 if none). Assumes the array is sorted ascending.
__device__ inline int find_total_elts_leq_target(int* sorted_indices,
                                                 const int64_t arr_length,
                                                 const int64_t target) {
  int64_t low = 0, high = arr_length - 1, target_location = -1;
  while (low <= high) {
    int64_t mid = (low + high) / 2;

    if (sorted_indices[mid] > target) {
      high = mid - 1;
    } else {
      low = mid + 1;
      target_location = mid;
    }
  }
  return target_location + 1;
}

// Declared here; the definition lives in another translation unit.
void compute_total_rows_before_expert(int* sorted_indices,
                                      const int64_t total_indices,
                                      const int64_t num_experts,
                                      int64_t* total_rows_before_expert,
                                      cudaStream_t stream);

// Final kernel to unpermute and scale.
// This kernel unpermutes the expert outputs and does the k-way reduction for
// every original token row:
//
//   out[row, c] = routed_scaling_factor *
//                 sum_j s_j * (expert_out[dest_j, c] + bias[e_j, c]) /
//                 (norm_topk_prob ? sum_j s_j : 1)
//
// with s_j = scales[row * k + j], e_j = expert_for_source_row[row * k + j] and
// dest_j = expanded_source_row_to_expanded_dest_row[row + j * num_rows]. The
// bias term is dropped when `bias` is null. No skip-connection term is added.
// `compute_bias` is accepted but not read.
//
// Launch: one block per original row (blockIdx.x), FINALIZE_THREADS_PER_BLOCK
// threads. When cols is a multiple of the 128-bit vector width, rows are
// processed with 128-bit fragments (base pointers assumed 16-byte aligned);
// otherwise a scalar path is used so that no tail column is skipped and no
// misaligned vector access is issued.
template <typename T>
__global__ void finalize_moe_routing_kernel(
    const T* expanded_permuted_rows,
    T* reduced_unpermuted_output,
    const T* bias,
    const float* scales,
    const int* expanded_source_row_to_expanded_dest_row,
    const int* expert_for_source_row,
    const int64_t cols,
    const int64_t k,
    const int64_t compute_bias,
    const bool norm_topk_prob,
    const float routed_scaling_factor,
    const int64_t num_rows) {
  const int original_row = blockIdx.x;
  auto const offset = original_row * cols;
  T* reduced_row_ptr = reduced_unpermuted_output + offset;

  constexpr int64_t FINALIZE_ELEM_PER_THREAD =
      128 / cutlass::sizeof_bits<T>::value;

  int64_t const start_offset = threadIdx.x;
  int64_t const stride = FINALIZE_THREADS_PER_BLOCK;

  // Scalar fallback for widths that are not a multiple of the vector size.
  if (cols % FINALIZE_ELEM_PER_THREAD != 0) {
    for (int64_t col = start_offset; col < cols; col += stride) {
      float acc = 0.f;
      float row_rescale = 0.f;
      for (int k_idx = 0; k_idx < k; ++k_idx) {
        int64_t const expanded_original_row = original_row + k_idx * num_rows;
        int64_t const expanded_permuted_row =
            expanded_source_row_to_expanded_dest_row[expanded_original_row];
        int64_t const k_offset = original_row * k + k_idx;
        const float row_scale = scales[k_offset];
        row_rescale = row_rescale + row_scale;

        int const expert_idx = expert_for_source_row[k_offset];
        const float bias_value =
            bias ? static_cast<float>(bias[expert_idx * cols + col]) : 0.f;
        const float expert_result = static_cast<float>(
            expanded_permuted_rows[expanded_permuted_row * cols + col]);
        acc = acc + row_scale * (expert_result + bias_value);
      }
      acc = acc / (norm_topk_prob ? row_rescale : 1.0f) *
            routed_scaling_factor;
      reduced_row_ptr[col] = static_cast<T>(acc);
    }
    return;
  }

  int64_t const num_elems_in_col = cols / FINALIZE_ELEM_PER_THREAD;

  using BiasElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
  using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
  using OutputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
  using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
  auto const* bias_v = reinterpret_cast<BiasElem const*>(bias);
  auto const* expanded_permuted_rows_v =
      reinterpret_cast<InputElem const*>(expanded_permuted_rows);
  auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);

#pragma unroll
  for (int64_t elem_index = start_offset; elem_index < num_elems_in_col;
       elem_index += stride) {
    ComputeElem thread_output;
    thread_output.fill(0);
    float row_rescale{0.f};
    for (int k_idx = 0; k_idx < k; ++k_idx) {
      int64_t const expanded_original_row = original_row + k_idx * num_rows;
      int64_t const expanded_permuted_row =
          expanded_source_row_to_expanded_dest_row[expanded_original_row];

      int64_t const k_offset = original_row * k + k_idx;
      const float row_scale = scales[k_offset];
      row_rescale = row_rescale + row_scale;

      auto const* expanded_permuted_rows_row_ptr =
          expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;

      int const expert_idx = expert_for_source_row[k_offset];

      auto const* bias_ptr = bias_v + expert_idx * num_elems_in_col;
      ComputeElem bias_value;
      if (bias) {
        bias_value = arrayConvert<ComputeElem>(bias_ptr[elem_index]);
      } else {
        bias_value.fill(0);
      }

      ComputeElem expert_result = arrayConvert<ComputeElem>(
          expanded_permuted_rows_row_ptr[elem_index]);
      thread_output = thread_output + row_scale * (expert_result + bias_value);
    }

    for (auto& elem : thread_output) {
      elem = elem / (norm_topk_prob ? row_rescale : 1.0f) *
             routed_scaling_factor;
    }

    OutputElem output_elem = arrayConvert<OutputElem>(thread_output);
    reduced_row_ptr_v[elem_index] = output_elem;
  }
}

// Host launcher for finalize_moe_routing_kernel: one block per original row,
// FINALIZE_THREADS_PER_BLOCK threads, on `stream`. Returns without launching
// when num_rows == 0 (an empty grid is an invalid launch configuration).
template <typename T>
struct finalize_moe_routing_kernelLauncher {
  static void run(const T* expanded_permuted_rows,
                  T* reduced_unpermuted_output,
                  const T* bias,
                  const float* scales,
                  const int* expanded_source_row_to_expanded_dest_row,
                  const int* expert_for_source_row,
                  const int64_t num_rows,
                  const int64_t cols,
                  const int64_t k,
                  const int64_t compute_bias,
                  const bool norm_topk_prob,
                  const float routed_scaling_factor,
                  cudaStream_t stream) {
    if (num_rows == 0) {
      return;
    }
    const int blocks = num_rows;
    const int threads = FINALIZE_THREADS_PER_BLOCK;

    finalize_moe_routing_kernel<T><<<blocks, threads, 0, stream>>>(
        expanded_permuted_rows,
        reduced_unpermuted_output,
        bias,
        scales,
        expanded_source_row_to_expanded_dest_row,
        expert_for_source_row,
        cols,
        k,
        compute_bias,
        norm_topk_prob,
        routed_scaling_factor,
        num_rows);
  }
};
}  // namespace phi