Mirror of https://github.com/PaddlePaddle/FastDeploy.git
Refactor moe_topk_select op to use apply_norm_weight as a template parameter (#3345)
* Refactor moe_topk_select op to use apply_norm_weight as a template parameter

* update test
@@ -150,8 +150,61 @@ __launch_bounds__(TPB) __global__
   }
 }
 
+template <typename T, int TPB>
+__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
+                                                   T* output,
+                                                   const int64_t num_cols,
+                                                   const int64_t num_rows) {
+  using BlockReduce = cub::BlockReduce<float, TPB>;
+  __shared__ typename BlockReduce::TempStorage tmpStorage;
+
+  __shared__ float normalizing_factor;
+  __shared__ float float_max;
+
+  int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
+  if (globalIdx >= num_rows) {
+    return;
+  }
+  const int64_t thread_row_offset = globalIdx * num_cols;
+
+  cub::Sum sum;
+  float threadData(-FLT_MAX);
+
+  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
+    const int idx = thread_row_offset + ii;
+    threadData = max(static_cast<float>(input[idx]), threadData);
+  }
+
+  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
+  if (threadIdx.x == 0) {
+    float_max = maxElem;
+  }
+  __syncthreads();
+
+  threadData = 0;
+
+  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
+    const int idx = thread_row_offset + ii;
+    threadData += exp((static_cast<float>(input[idx]) - float_max));
+  }
+
+  const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
+
+  if (threadIdx.x == 0) {
+    normalizing_factor = 1.f / Z;
+  }
+  __syncthreads();
+
+  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
+    const int idx = thread_row_offset + ii;
+    const float val =
+        exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
+    output[idx] = T(val);
+  }
+}
+
 template <typename T, int TPB, typename IdxT = int>
-__launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
+__launch_bounds__(TPB) __global__ void group_moe_top_k(const T* inputs_after_softmax,
                                                  T* output,
                                                  IdxT* indices,
                                                  int* source_rows,
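Note: moe_softmax above is the standard numerically stable row softmax: block-reduce the row max, subtract it before exponentiating, then scale by 1/Z. A minimal CPU reference of the same computation, handy for checking the kernel's output in tests (the helper name is illustrative, not from this repo):

#include <algorithm>
#include <cmath>
#include <vector>

// CPU reference for one row of moe_softmax: subtract the row max before
// exponentiating so exp() cannot overflow, then normalize by the sum Z.
// Assumes the row is non-empty.
std::vector<float> softmax_row_reference(const std::vector<float>& row) {
  const float row_max = *std::max_element(row.begin(), row.end());
  std::vector<float> out(row.size());
  float z = 0.f;
  for (size_t i = 0; i < row.size(); ++i) {
    out[i] = std::exp(row[i] - row_max);
    z += out[i];
  }
  for (float& v : out) v *= 1.f / z;  // matches exp(x - float_max) * normalizing_factor
  return out;
}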
@@ -208,60 +261,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
   }
 }
 
-template <typename T, int TPB>
-__launch_bounds__(TPB) __global__ void moe_softmax(const T* input,
-                                                   T* output,
-                                                   const int64_t num_cols,
-                                                   const int64_t num_rows) {
-  using BlockReduce = cub::BlockReduce<float, TPB>;
-  __shared__ typename BlockReduce::TempStorage tmpStorage;
-
-  __shared__ float normalizing_factor;
-  __shared__ float float_max;
-
-  int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
-  if (globalIdx >= num_rows) {
-    return;
-  }
-  const int64_t thread_row_offset = globalIdx * num_cols;
-
-  cub::Sum sum;
-  float threadData(-FLT_MAX);
-
-  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
-    const int idx = thread_row_offset + ii;
-    threadData = max(static_cast<float>(input[idx]), threadData);
-  }
-
-  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
-  if (threadIdx.x == 0) {
-    float_max = maxElem;
-  }
-  __syncthreads();
-
-  threadData = 0;
-
-  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
-    const int idx = thread_row_offset + ii;
-    threadData += exp((static_cast<float>(input[idx]) - float_max));
-  }
-
-  const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
-
-  if (threadIdx.x == 0) {
-    normalizing_factor = 1.f / Z;
-  }
-  __syncthreads();
-
-  for (int ii = threadIdx.x; ii < num_cols; ii += TPB) {
-    const int idx = thread_row_offset + ii;
-    const float val =
-        exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
-    output[idx] = T(val);
-  }
-}
-
-template <typename T, int TPB, typename IdxT = int>
+template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
 __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
                                                  const T* bias,
                                                  T* output,
@@ -284,6 +284,13 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
 
   const bool should_process_row = true;
   const int thread_read_offset = block_row * num_experts;
+  T weight_sum = static_cast<T>(0);
+  T* row_outputs = nullptr;
+
+  if constexpr (NormWeights) {
+    extern __shared__ char smem[];
+    row_outputs = reinterpret_cast<T*>(smem);
+  }
 
   for (int k_idx = 0; k_idx < k; ++k_idx) {
     thread_kvp.key = 0;
@@ -296,7 +303,7 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
       inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx];
 
       for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
-        const IdxT prior_winning_expert = indices[k * block_row + prior_k];
+        const int prior_winning_expert = indices[k * block_row + prior_k];
 
         if (prior_winning_expert == expert) {
           inp_kvp = thread_kvp;
@@ -310,15 +317,31 @@ __launch_bounds__(TPB) __global__ void moe_top_k(const T* inputs_after_softmax,
         BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
     if (threadIdx.x == 0) {
       const int idx = k * block_row + k_idx;
-      output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] : result_kvp.value;
       indices[idx] = should_process_row ? result_kvp.key : num_experts;
       source_rows[idx] = k_idx * num_rows + block_row;
+
+      if constexpr (NormWeights) {
+        T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] : result_kvp.value;
+        row_outputs[k_idx] = row_out;
+        weight_sum += row_out;
+      }
+      else {
+        output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] : result_kvp.value;
+      }
     }
     __syncthreads();
   }
+  if constexpr (NormWeights) {
+    if (threadIdx.x < WARP_SIZE) {
+      weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
+    }
+    if (threadIdx.x < k) {
+      output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
+    }
+  }
 }
 
-template <typename T, int TPB, typename IdxT = int>
+template <typename T, int TPB, bool NormWeights = false, typename IdxT = int>
 __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
                                                                const T* bias,
                                                                T* output,
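Note: with NormWeights enabled, moe_top_k buffers the k winning gate values in dynamic shared memory (row_outputs) and rescales them by their sum in the epilogue, so the returned top-k weights sum to 1 per row. A hedged CPU sketch of that select-then-renormalize step (helper is illustrative and not from the source; tie-breaking may differ from the kernel's argmax order):

#include <algorithm>
#include <utility>
#include <vector>

// CPU reference for one row: take the k largest gate values, then rescale so
// the selected weights sum to 1 (mirrors row_outputs[i] / weight_sum).
// Assumes 0 < k <= gates.size().
void topk_renorm_reference(const std::vector<float>& gates, int k,
                           std::vector<int>* indices, std::vector<float>* weights) {
  std::vector<std::pair<float, int>> order;
  for (int e = 0; e < static_cast<int>(gates.size()); ++e) order.push_back({gates[e], e});
  std::partial_sort(order.begin(), order.begin() + k, order.end(),
                    [](const auto& a, const auto& b) { return a.first > b.first; });
  float weight_sum = 0.f;
  for (int i = 0; i < k; ++i) weight_sum += order[i].first;
  for (int i = 0; i < k; ++i) {
    indices->push_back(order[i].second);
    weights->push_back(order[i].first / weight_sum);  // renormalized top-k weight
  }
}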
@@ -356,165 +379,6 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_fused(const T* input,
   const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
 
   if (threadIdx.x == 0) {
     normalizing_factor = 1.f / Z;
   }
   __syncthreads();
-
-  T val = T(threadDataExp * normalizing_factor);
-
-  // top_k
-  using cub_kvp = cub::KeyValuePair<int, T>;
-  using BlockReduceP = cub::BlockReduce<cub_kvp, TPB>;
-  __shared__ typename BlockReduceP::TempStorage tmpStorageP;
-
-  cub_kvp thread_kvp;
-  cub::ArgMax arg_max;
-
-  for (int k_idx = 0; k_idx < k; ++k_idx) {
-    thread_kvp.key = 0;
-    thread_kvp.value = T(-1.f);  // This is OK because inputs are probabilities
-
-    if (threadIdx.x < num_experts) {
-      cub_kvp inp_kvp;
-      int expert = threadIdx.x;
-      inp_kvp.key = expert;
-      inp_kvp.value = bias ? val + bias[expert] : val;
-
-      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
-        const IdxT prior_winning_expert = indices[k * globalIdx + prior_k];
-
-        if (prior_winning_expert == expert) {
-          inp_kvp = thread_kvp;
-        }
-      }
-      thread_kvp = arg_max(inp_kvp, thread_kvp);
-    }
-
-    const cub_kvp result_kvp =
-        BlockReduceP(tmpStorageP).Reduce(thread_kvp, arg_max);
-    if (threadIdx.x == 0) {
-      const int cur_idx = k * globalIdx + k_idx;
-      output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
-      indices[cur_idx] = result_kvp.key;
-      source_rows[cur_idx] = k_idx * num_rows + globalIdx;
-    }
-    __syncthreads();
-  }
-}
-
-template <typename T, int TPB, typename IdxT = int>
-__launch_bounds__(TPB) __global__ void moe_top_k_normed(const T* inputs_after_softmax,
-                                                        const T* bias,
-                                                        T* output,
-                                                        IdxT* indices,
-                                                        int* source_rows,
-                                                        const int64_t num_experts,
-                                                        const int64_t k,
-                                                        const int64_t num_rows) {
-  using cub_kvp = cub::KeyValuePair<int, T>;
-  using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
-  __shared__ typename BlockReduce::TempStorage tmpStorage;
-
-  cub_kvp thread_kvp;
-  cub::ArgMax arg_max;
-
-  const int block_row = blockIdx.x + blockIdx.y * gridDim.x;
-  if (block_row >= num_rows) {
-    return;
-  }
-
-  const bool should_process_row = true;
-  const int thread_read_offset = block_row * num_experts;
-  T weight_sum = static_cast<T>(0);
-
-  extern __shared__ char smem[];
-
-  T* row_outputs = reinterpret_cast<T*>(smem);
-
-  for (int k_idx = 0; k_idx < k; ++k_idx) {
-    thread_kvp.key = 0;
-    thread_kvp.value = T(-1.f);  // This is OK because inputs are probabilities
-
-    cub_kvp inp_kvp;
-    for (int expert = threadIdx.x; expert < num_experts; expert += TPB) {
-      const int idx = thread_read_offset + expert;
-      inp_kvp.key = expert;
-      inp_kvp.value = bias ? inputs_after_softmax[idx] + bias[expert] : inputs_after_softmax[idx];
-
-      for (int prior_k = 0; prior_k < k_idx; ++prior_k) {
-        const int prior_winning_expert = indices[k * block_row + prior_k];
-
-        if (prior_winning_expert == expert) {
-          inp_kvp = thread_kvp;
-        }
-      }
-
-      thread_kvp = arg_max(inp_kvp, thread_kvp);
-    }
-
-    const cub_kvp result_kvp =
-        BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
-    if (threadIdx.x == 0) {
-      const int idx = k * block_row + k_idx;
-      // output[idx] = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] : result_kvp.value;
-      indices[idx] = should_process_row ? result_kvp.key : num_experts;
-      source_rows[idx] = k_idx * num_rows + block_row;
-
-      T row_out = bias ? inputs_after_softmax[thread_read_offset + result_kvp.key] : result_kvp.value;
-      row_outputs[k_idx] = row_out;
-      weight_sum += row_out;
-    }
-    __syncthreads();
-  }
-  if (threadIdx.x < WARP_SIZE) {
-    weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
-  }
-
-  if (threadIdx.x < k) {
-    output[k * block_row + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
-  }
-}
-
-
-template <typename T, int TPB, typename IdxT = int>
-__launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
-                                                                      const T* bias,
-                                                                      T* output,
-                                                                      IdxT* indices,
-                                                                      int* source_rows,
-                                                                      const int64_t num_experts,
-                                                                      const int64_t k,
-                                                                      const int64_t num_rows) {
-  // softmax
-  using BlockReduce = cub::BlockReduce<float, TPB>;
-  __shared__ typename BlockReduce::TempStorage tmpStorage;
-
-  __shared__ float normalizing_factor;
-  __shared__ float float_max;
-
-  int globalIdx = blockIdx.x + blockIdx.y * gridDim.x;
-  if (globalIdx >= num_rows) {
-    return;
-  }
-  const int64_t thread_row_offset = globalIdx * num_experts;
-  const int64_t idx = thread_row_offset + threadIdx.x;
-
-  cub::Sum sum;
-
-  float threadData = (threadIdx.x < num_experts) ? static_cast<float>(input[idx]) : (-FLT_MAX);
-
-  const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
-  if (threadIdx.x == 0) {
-    float_max = maxElem;
-  }
-  __syncthreads();
-
-  float threadDataSub = threadData - float_max;
-  float threadDataExp = exp(threadDataSub);
-
-  const auto Z = BlockReduce(tmpStorage).Reduce(threadDataExp, sum);
-
-  if (threadIdx.x == 0) {
-    normalizing_factor = 1.f / Z;
-  }
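Note: the two *_normed kernels deleted above were near-duplicates of their plain counterparts; folding them into one body guarded by if constexpr (NormWeights) keeps a single source of truth while the compiler still emits two independent specializations with no runtime branch. A minimal C++17 illustration of the mechanism (toy function, not from the repo):

#include <iostream>

// One templated function replaces two hand-written variants; the untaken
// branch is discarded at compile time, not tested at run time.
template <bool Normalize>
float combine(float a, float b) {
  if constexpr (Normalize) {
    return (a + b) / 2.f;  // "normed" variant
  } else {
    return a + b;          // plain variant
  }
}

int main() {
  std::cout << combine<true>(1.f, 3.f) << " " << combine<false>(1.f, 3.f) << "\n";
  return 0;  // prints: 2 4
}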
@@ -532,8 +396,11 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
   cub::ArgMax arg_max;
 
   T weight_sum = static_cast<T>(0);
-  extern __shared__ char smem[];
-  T* row_outputs = reinterpret_cast<T*>(smem);
+  T* row_outputs = nullptr;
+  if constexpr (NormWeights) {
+    extern __shared__ char smem[];
+    row_outputs = reinterpret_cast<T*>(smem);
+  }
 
   for (int k_idx = 0; k_idx < k; ++k_idx) {
     thread_kvp.key = 0;
@@ -560,22 +427,28 @@ __launch_bounds__(TPB) __global__ void moe_softmax_top_k_normed_fused(const T* input,
     if (threadIdx.x == 0) {
       const int cur_idx = k * globalIdx + k_idx;
-
-      T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
-      row_outputs[k_idx] = row_out;
-      weight_sum += row_out;
-
       indices[cur_idx] = result_kvp.key;
       source_rows[cur_idx] = k_idx * num_rows + globalIdx;
+
+      if constexpr (NormWeights) {
+        T row_out = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
+        row_outputs[k_idx] = row_out;
+        weight_sum += row_out;
+      }
+      else {
+        output[cur_idx] = bias ? (result_kvp.value - bias[result_kvp.key]) : result_kvp.value;
+      }
     }
     __syncthreads();
   }
-  if (threadIdx.x < WARP_SIZE) {
-    weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
-  }
-
-  if (threadIdx.x < k) {
-    output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
-  }
+  if constexpr (NormWeights) {
+    if (threadIdx.x < WARP_SIZE) {
+      weight_sum = __shfl_sync(0xffffffff, weight_sum, 0);
+    }
+    if (threadIdx.x < k) {
+      output[k * globalIdx + threadIdx.x] = row_outputs[threadIdx.x] / weight_sum;
+    }
+  }
 }
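Note: weight_sum is accumulated only by thread 0 (inside the threadIdx.x == 0 block), so the epilogue broadcasts lane 0's value across the first warp with __shfl_sync before threads 0..k-1 each write one normalized output; this relies on k <= WARP_SIZE (the call sites also assert k <= TPB). A standalone sketch of that broadcast pattern (toy kernel, not from the source; assumes blockDim.x >= 32 so the full-warp mask is valid):

// Toy kernel: only lane 0 holds the real value (like weight_sum above);
// __shfl_sync copies it from lane 0 to every lane of the first warp
// without going through shared memory.
__global__ void broadcast_from_lane0(const float* per_row_sum, float* out, int k) {
  constexpr unsigned kFullWarpMask = 0xffffffff;
  float v = (threadIdx.x == 0) ? per_row_sum[blockIdx.x] : 0.f;
  if (threadIdx.x < 32) {   // first warp only; assumes blockDim.x >= 32
    v = __shfl_sync(kFullWarpMask, v, 0);
  }
  if (threadIdx.x < k) {    // threads 0..k-1 now all see lane 0's value
    out[blockIdx.x * k + threadIdx.x] = v;
  }
}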
@@ -1015,7 +888,7 @@ static void run(const T* input,
                              group_experts,
                              softmax_num_rows);
   const auto config_topk = Get1DBlocksAnd2DGridsMoe(num_rows);
-  moe_top_k<T, TPB>
+  group_moe_top_k<T, TPB>
       <<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
                                                        output,
                                                        indices,
@@ -102,7 +102,7 @@ void moe_redundant_topk_select_kernel(const T* input,
   else {
     assert(k <= TPB);
     if (apply_norm_weight) {
-      moe_softmax_top_k_normed_fused<T, TPB>
+      moe_softmax_top_k_fused<T, TPB, true>
           <<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
                                                                        bias,
                                                                        output,
@@ -112,7 +112,7 @@ void moe_redundant_topk_select_kernel(const T* input,
                                                                        k,
                                                                        num_rows);
     } else {
-      moe_softmax_top_k_fused<T, TPB>
+      moe_softmax_top_k_fused<T, TPB, false>
           <<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
                                                            bias,
                                                            output,
@@ -68,7 +68,7 @@ void moe_topk_select_kernel(const T* input,
     moe_softmax<T, TPB><<<config_topk.block_per_grid, TPB, 0, stream>>>(
         input, softmax, num_experts, num_rows);
     if (apply_norm_weight) {
-      moe_top_k_normed<T, TPB>
+      moe_top_k<T, TPB, true>
          <<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(softmax,
                                                                       bias,
                                                                       output,
@@ -78,7 +78,7 @@ void moe_topk_select_kernel(const T* input,
                                                                       k,
                                                                       num_rows);
     } else {
-      moe_top_k<T, TPB>
+      moe_top_k<T, TPB, false>
          <<<config_topk.block_per_grid, TPB, 0, stream>>>(softmax,
                                                           bias,
                                                           output,
@@ -93,7 +93,7 @@ void moe_topk_select_kernel(const T* input,
   else {
     assert(k <= TPB);
     if (apply_norm_weight) {
-      moe_softmax_top_k_normed_fused<T, TPB>
+      moe_softmax_top_k_fused<T, TPB, true>
          <<<config_topk.block_per_grid, TPB, k * sizeof(T), stream>>>(input,
                                                                       bias,
                                                                       output,
@@ -103,7 +103,7 @@ void moe_topk_select_kernel(const T* input,
                                                                       k,
                                                                       num_rows);
     } else {
-      moe_softmax_top_k_fused<T, TPB>
+      moe_softmax_top_k_fused<T, TPB, false>
          <<<config_topk.block_per_grid, TPB, 0, stream>>>(input,
                                                           bias,
                                                           output,
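Note: every call site above follows one pattern: map the runtime apply_norm_weight flag onto the compile-time NormWeights template argument, and request dynamic shared memory (k * sizeof(T) bytes for row_outputs) only for the normed instantiation. A condensed, self-contained sketch of that dispatch (toy kernel and launcher names are illustrative, not from the repo):

#include <cuda_runtime.h>

// Toy stand-in for moe_top_k<T, TPB, NormWeights>: the untaken branch of
// `if constexpr` is compiled out, so each specialization carries no runtime
// flag check.
template <typename T, int TPB, bool NormWeights>
__global__ void toy_top_k(const T* in, T* out, int k) {
  if constexpr (NormWeights) {
    extern __shared__ char smem[];        // sized as k * sizeof(T) at launch
    T* buf = reinterpret_cast<T*>(smem);
    if (threadIdx.x < k) buf[threadIdx.x] = in[threadIdx.x];
    __syncthreads();
    if (threadIdx.x < k) out[threadIdx.x] = buf[threadIdx.x];
  } else {
    if (threadIdx.x < k) out[threadIdx.x] = in[threadIdx.x];
  }
}

// Runtime bool -> compile-time template parameter, mirroring the branches in
// moe_topk_select_kernel above.
template <typename T, int TPB>
void launch_toy_top_k(const T* in, T* out, int k, bool apply_norm_weight,
                      dim3 grid, cudaStream_t stream) {
  if (apply_norm_weight) {
    toy_top_k<T, TPB, true><<<grid, TPB, k * sizeof(T), stream>>>(in, out, k);
  } else {
    toy_top_k<T, TPB, false><<<grid, TPB, 0, stream>>>(in, out, k);
  }
}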