mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 16:22:57 +08:00
revise get_moe_scores (#3164)
This commit is contained in:
@@ -372,10 +372,12 @@ __global__ void topk_with_k2_kernel(T* output,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename IdxT>
|
||||
__global__ void group_idx_and_topk_idx_kernel(
|
||||
T* scores,
|
||||
T const* group_scores,
|
||||
T* topk_values,
|
||||
IdxT* topk_indices,
|
||||
T* scores_with_bias,
|
||||
int64_t const num_tokens,
|
||||
int64_t const n_group,
|
||||
@@ -391,6 +393,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
scores_with_bias += case_id * num_experts;
|
||||
scores += case_id * num_experts;
|
||||
group_scores += case_id * n_group;
|
||||
topk_values += case_id * topk;
|
||||
topk_indices += case_id * topk;
|
||||
int32_t align_num_experts_per_group =
|
||||
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
||||
|
||||
@@ -436,6 +440,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
|
||||
if (case_id < num_tokens) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
@@ -490,13 +495,23 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
|
||||
scores[s_topk_idx[i]] = value;
|
||||
if (if_proceed_next_topk) {
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = static_cast<T>(value);
|
||||
}
|
||||
else {
|
||||
topk_indices[i] = i;
|
||||
topk_values[i] = static_cast<float>(1.0f / topk);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename IdxT>
|
||||
void invokeNoAuxTc(T* scores,
|
||||
T* group_scores,
|
||||
T* topk_values,
|
||||
IdxT* topk_indices,
|
||||
T* scores_with_bias,
|
||||
int64_t const num_tokens,
|
||||
int64_t const num_experts,
|
||||
@@ -526,6 +541,8 @@ void invokeNoAuxTc(T* scores,
|
||||
dynamic_smem_in_bytes,
|
||||
stream>>>(scores,
|
||||
group_scores,
|
||||
topk_values,
|
||||
topk_indices,
|
||||
scores_with_bias,
|
||||
num_tokens,
|
||||
n_group,
|
||||
@@ -536,9 +553,11 @@ void invokeNoAuxTc(T* scores,
|
||||
routed_scaling_factor);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_NOAUX_TC(T) \
|
||||
template void invokeNoAuxTc<T>(T * scores, \
|
||||
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
||||
template void invokeNoAuxTc<T, IdxT>(T * scores, \
|
||||
T * group_scores, \
|
||||
T* topk_values, \
|
||||
IdxT* topk_indices, \
|
||||
T * scores_with_bias, \
|
||||
int64_t const num_tokens, \
|
||||
int64_t const num_experts, \
|
||||
@@ -548,4 +567,4 @@ void invokeNoAuxTc(T* scores,
|
||||
double const routed_scaling_factor, \
|
||||
cudaStream_t const stream);
|
||||
|
||||
INSTANTIATE_NOAUX_TC(float);
|
||||
INSTANTIATE_NOAUX_TC(float, int32_t);
|
||||
|
Reference in New Issue
Block a user