revise get_moe_scores (#3164)

Author: Yuan Xiaolan
Date: 2025-08-05 16:43:07 +08:00
Committed by: GitHub
Parent: e24929efa3
Commit: af543b7f0f
6 changed files with 165 additions and 23 deletions


@@ -33,10 +33,14 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
   auto input_type = scores_with_bias.dtype();
   auto place = scores_with_bias.place();
   auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
+  auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
+  auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT32, place);
   auto stream = scores_with_bias.stream();
-  invokeNoAuxTc<float>(reinterpret_cast<float*>(scores.data<float>()),
+  invokeNoAuxTc<float, int32_t>(reinterpret_cast<float*>(scores.data<float>()),
                        reinterpret_cast<float*>(group_scores.data<float>()),
+                       reinterpret_cast<float*>(topk_values.data<float>()),
+                       reinterpret_cast<int32_t*>(topk_indices.data<int32_t>()),
                        reinterpret_cast<float*>(scores_with_bias.data<float>()),
                        num_tokens,
                        num_experts,
@@ -46,19 +50,23 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                        routed_scaling_factor,
                        stream);
-  return {scores};
+  return {scores, topk_values, topk_indices};
 }
 
 std::vector<paddle::DataType> NoauxTcInferDtype(
     const paddle::DataType& scores_dtype,
     const paddle::DataType& scores_with_bias_dtype) {
-  return {scores_dtype};
+  return {scores_dtype, scores_dtype, paddle::DataType::INT32};
 }
 
 std::vector<std::vector<int64_t>> NoauxTcInferShape(
     const std::vector<int64_t>& scores_shape,
-    const std::vector<int64_t>& gating_output_shape) {
-  return {scores_shape};
+    const std::vector<int64_t>& ,
+    const int topk) {
+  auto num_tokens = scores_shape[0];
+  auto topk_values_shape = std::vector<int64_t>{num_tokens, topk};
+  auto topk_indices_shape = std::vector<int64_t>{num_tokens, topk};
+  return {scores_shape, topk_values_shape, topk_indices_shape};
 }
 
 PD_BUILD_STATIC_OP(noaux_tc)
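
The hunk above cuts off right where the static-op registration begins, so the full PD_BUILD_STATIC_OP(noaux_tc) chain is not part of this diff. As a rough sketch of how the revised three-output interface could be registered with Paddle's custom-op API (input, output, and attribute names below are assumptions, not taken from the commit):

// Sketch only: names of inputs, outputs, and attrs are guesses, not from the diff.
PD_BUILD_STATIC_OP(noaux_tc)
    .Inputs({"scores", "scores_with_bias"})                    // assumed input names
    .Outputs({"output_scores", "topk_values", "topk_indices"}) // three outputs after this change
    .Attrs({"n_group: int",
            "topk_group: int",
            "topk: int",
            "routed_scaling_factor: float"})                   // attrs inferred from kernel arguments
    .SetKernelFn(PD_KERNEL(NoauxTc))
    .SetInferShapeFn(PD_INFER_SHAPE(NoauxTcInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(NoauxTcInferDtype));

The infer-shape and infer-dtype hooks matter here because the op now produces three tensors, so NoauxTcInferShape and NoauxTcInferDtype each have to report three entries, which is exactly what the hunk above changes.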


@@ -372,10 +372,12 @@ __global__ void topk_with_k2_kernel(T* output,
   }
 }
 
-template <typename T>
+template <typename T, typename IdxT>
 __global__ void group_idx_and_topk_idx_kernel(
     T* scores,
     T const* group_scores,
+    T* topk_values,
+    IdxT* topk_indices,
     T* scores_with_bias,
     int64_t const num_tokens,
     int64_t const n_group,
@@ -391,6 +393,8 @@ __global__ void group_idx_and_topk_idx_kernel(
   scores_with_bias += case_id * num_experts;
   scores += case_id * num_experts;
   group_scores += case_id * n_group;
+  topk_values += case_id * topk;
+  topk_indices += case_id * topk;
 
   int32_t align_num_experts_per_group =
       warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
@@ -436,6 +440,7 @@ __global__ void group_idx_and_topk_idx_kernel(
       queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
   int count_equalto_topkth_group = 0;
+  bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
   if (case_id < num_tokens) {
     for (int i_group = 0; i_group < n_group; i_group++) {
       if ((group_scores[i_group] > topk_group_value) ||
@@ -490,13 +495,23 @@ __global__ void group_idx_and_topk_idx_kernel(
     for (int i = lane_id; i < topk; i += WARP_SIZE) {
       float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
       scores[s_topk_idx[i]] = value;
+      if (if_proceed_next_topk) {
+        topk_indices[i] = s_topk_idx[i];
+        topk_values[i] = static_cast<T>(value);
+      }
+      else {
+        topk_indices[i] = i;
+        topk_values[i] = static_cast<float>(1.0f / topk);
+      }
     }
   }
 }
 
-template <typename T>
+template <typename T, typename IdxT>
 void invokeNoAuxTc(T* scores,
                    T* group_scores,
+                   T* topk_values,
+                   IdxT* topk_indices,
                    T* scores_with_bias,
                    int64_t const num_tokens,
                    int64_t const num_experts,
@@ -526,6 +541,8 @@ void invokeNoAuxTc(T* scores,
                          dynamic_smem_in_bytes,
                          stream>>>(scores,
                                    group_scores,
+                                   topk_values,
+                                   topk_indices,
                                    scores_with_bias,
                                    num_tokens,
                                    n_group,
@@ -536,9 +553,11 @@ void invokeNoAuxTc(T* scores,
                          routed_scaling_factor);
 }
 
-#define INSTANTIATE_NOAUX_TC(T)                                         \
-  template void invokeNoAuxTc<T>(T * scores,                            \
+#define INSTANTIATE_NOAUX_TC(T, IdxT)                                   \
+  template void invokeNoAuxTc<T, IdxT>(T * scores,                      \
                                  T * group_scores,                      \
+                                 T* topk_values,                        \
+                                 IdxT* topk_indices,                    \
                                  T * scores_with_bias,                  \
                                  int64_t const num_tokens,              \
                                  int64_t const num_experts,             \
@@ -548,4 +567,4 @@ void invokeNoAuxTc(T* scores,
                                  double const routed_scaling_factor,   \
                                  cudaStream_t const stream);
 
-INSTANTIATE_NOAUX_TC(float);
+INSTANTIATE_NOAUX_TC(float, int32_t);
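
When if_proceed_next_topk is false, the kernel above now writes identity indices and a uniform 1/topk weight into the new output buffers instead of leaving them unset. A minimal host-side C++ sketch of that fallback (illustrative only, not part of the commit) shows the resulting routing weights still form a distribution that sums to one:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int topk = 8;  // example value; the real topk is an op attribute
  std::vector<int32_t> topk_indices(topk);
  std::vector<float> topk_values(topk);
  // Mirror of the kernel's fallback branch: identity indices, uniform weights.
  for (int i = 0; i < topk; ++i) {
    topk_indices[i] = i;
    topk_values[i] = 1.0f / topk;
  }
  float sum = 0.0f;
  for (float v : topk_values) sum += v;
  assert(sum > 0.99f && sum < 1.01f);  // weights still behave like a valid distribution
  return 0;
}

On the normal path the same slots instead receive s_topk_idx[i] and the renormalized, scaled score computed just above the branch.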