mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
topk_gating_softmax support bias (#3405)
This commit is contained in:
@@ -574,6 +574,7 @@ template <typename T,
           typename IdxT = int>
 __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
     void topk_gating_softmax(const T* input,
+                             const T* bias,
                              T* output,
                              const int64_t num_rows,
                              IdxT* indices,
@@ -716,7 +717,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__

 #pragma unroll
   for (int ii = 0; ii < VPT; ++ii) {
-    row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
+    row_chunk[ii] = bias ? row_chunk[ii] * reciprocal_row_sum + bias[first_elt_read_by_thread + ii] : row_chunk[ii] * reciprocal_row_sum;
   }

 // Now, softmax_res contains the softmax of the row chunk. Now, I want to find
@@ -765,6 +766,7 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
     }

     // Write the max for this k iteration to global memory.
+    T final_val = bias ? T(max_val) - bias[expert] : T(max_val);
     if (thread_group_idx == 0) {
       // The lead thread from each sub-group will write out the final results to
       // global memory. (This will be a single) thread per row of the
@@ -772,11 +774,11 @@ __launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
       const int idx = k * thread_row + k_idx;
       if constexpr (Norm_Weights) {
         const int idx_in_cta = k * thread_row_in_cta + k_idx;
-        row_output[idx_in_cta] = T(max_val);
-        weight_sum += T(max_val);
+        row_output[idx_in_cta] = final_val;
+        weight_sum += final_val;
       }
       else {
-        output[idx] = T(max_val);
+        output[idx] = final_val;
       }
       indices[idx] = should_process_row ? expert : NUM_EXPERTS;
       source_rows[idx] = k_idx * num_rows + thread_row;
@@ -831,6 +833,7 @@ struct TopkConstants {

 template <typename T, int EXPERTS, int WARPS_PER_TB, bool Norm_Weights = false, typename IdxT = int>
 void topk_gating_softmax_launcher_helper(const T* input,
+                                         const T* bias,
                                          T* output,
                                          IdxT* indices,
                                          int* source_row,
@@ -851,7 +854,7 @@ void topk_gating_softmax_launcher_helper(const T* input,
   static constexpr int ROWS_PER_CTA = WARPS_PER_TB * ROWS_PER_WARP;
   topk_gating_softmax<T, VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, Norm_Weights>
       <<<num_blocks, block_dim, ROWS_PER_CTA * k * sizeof(T), stream>>>(
-          input, output, num_rows, indices, source_row, k);
+          input, bias, output, num_rows, indices, source_row, k);
 }

 template <typename T, typename IdxT = int>
@@ -882,7 +885,7 @@ static void run(const T* input,
 #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N)                                   \
   case N: {                                                                    \
     topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>(                   \
-        input, output, indices, source_row, num_rows, num_experts, k, stream); \
+        input, gating_correction_bias, output, indices, source_row, num_rows, num_experts, k, stream); \
     break;                                                                     \
   }
   int64_t tem_num_experts = num_experts;
@@ -51,7 +51,7 @@ void moe_redundant_topk_select_kernel(const T* input,
 #define LAUNCH_TOPK_GATING_SOFTMAX_HELPER(N)                                   \
   case N: {                                                                    \
     topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB>(                   \
-        input, output, indices, source_row, num_rows, num_experts, k, stream); \
+        input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
     break;                                                                     \
   }
   int64_t tem_num_experts = num_experts;
@@ -47,17 +47,14 @@ void moe_topk_select_kernel(const T* input,
   case N: {                                                                          \
     if (apply_norm_weight) {                                                         \
       topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, true>(                 \
-          input, output, indices, source_row, num_rows, num_experts, k, stream);     \
+          input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
     } else {                                                                         \
       topk_gating_softmax_launcher_helper<T, N, WARPS_PER_TB, false>(                \
-          input, output, indices, source_row, num_rows, num_experts, k, stream);     \
+          input, bias, output, indices, source_row, num_rows, num_experts, k, stream); \
     }                                                                                \
     break;                                                                           \
   }
+  int64_t tem_num_experts = num_experts;
+  // when bias is not none, set tem_num_experts to 0 to follow the default branch
+  if(bias != nullptr) tem_num_experts = 0;
-  switch (num_experts) {
+  switch (tem_num_experts) {
     LAUNCH_TOPK_GATING_SOFTMAX_HELPER(2)
     LAUNCH_TOPK_GATING_SOFTMAX_HELPER(4)
     LAUNCH_TOPK_GATING_SOFTMAX_HELPER(8)
||||
Reference in New Issue
Block a user