Add bias support to topk_gating_softmax (#3405)

This commit is contained in:
Sunny-bot1
2025-08-15 11:57:45 +08:00
committed by GitHub
parent 5a84324798
commit 6c1f3ff897
3 changed files with 13 additions and 13 deletions

View File

@@ -574,6 +574,7 @@ template <typename T,
typename IdxT = int>
__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
void topk_gating_softmax(const T* input,
const T* bias,
T* output,
const int64_t num_rows,
IdxT* indices,
@@ -716,7 +717,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
row_chunk[ii] = bias ? row_chunk[ii] * reciprocal_row_sum + bias[first_elt_read_by_thread + ii] : row_chunk[ii] * reciprocal_row_sum;
}
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find
@@ -765,6 +766,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
}
// Write the max for this k iteration to global memory.
T final_val = bias ? T(max_val) - bias[expert] : T(max_val);
if (thread_group_idx == 0) {
// The lead thread from each sub-group will write out the final results to
// global memory. (This will be a single) thread per row of the
@@ -772,11 +774,11 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
const int idx = k * thread_row + k_idx;
if constexpr (Norm_Weights) {
const int idx_in_cta = k * thread_row_in_cta + k_idx;
row_output[idx_in_cta] = T(max_val);
weight_sum += T(max_val);
row_output[idx_in_cta] = final_val;
weight_sum += final_val;
}
else {
output[idx] = T(max_val);
output[idx] = final_val;
}
indices[idx] = should_process_row ? expert : NUM_EXPERTS;
source_rows[idx] = k_idx * num_rows + thread_row;
@@ -831,6 +833,7 @@ struct TopkConstants {
template <typename T, int EXPERTS, int WARPS_PER_TB, bool Norm_Weights = false, typename IdxT = int>
void topk_gating_softmax_launcher_helper(const T* input,
const T* bias,
T* output,
IdxT* indices,
int* source_row,
@@ -851,7 +854,7 @@ void topk_gating_softmax_launcher_helper(const T* input,
static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, Norm_Weights>
<<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof(T), stream>>>(
input, output, num_rows, indices, source_row, k);
input, bias, output, num_rows, indices, source_row, k);
}
template <typename T, typename IdxT = int>
@@ -882,7 +885,7 @@ static void run(const T* input,
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
input, gating_correction_bias, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;

View File

@@ -51,7 +51,7 @@ void moe_redundant_topk_select_kernel(const T* input,
#define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N) \
case N: { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
break; \
}
int64_t tem_num_experts = num_experts;

View File

@@ -47,17 +47,14 @@ void moe_topk_select_kernel(const T* input,
case N: { \
if (apply_norm_weight) { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
} else { \
topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>( \
input, output, indices, source_row, num_rows, num_experts, k, stream); \
input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
} \
break; \
}
int64_t tem_num_experts = num_experts;
// when bias is not none, set tem_num_experts to 0 to follow the default branch
if(bias != nullptr) tem_num_experts = 0;
switch (tem_num_experts) {
switch (num_experts) {
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)