[Feature] Support noaux for eplb (#5143)

* support noaux eplb

* noaux_eplb

* noaux_eplb

* noaux_eplb
This commit is contained in:
xiaoxiaohehe001
2025-11-21 14:10:32 +08:00
committed by GitHub
parent e70e2279ce
commit 6ca2651995
8 changed files with 616 additions and 23 deletions

View File

@@ -420,6 +420,13 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
}; // end class WarpSelect
} // namespace warp_topk
inline __device__ unsigned int xorwow_moe(unsigned int& state) {
state ^= state >> 7;
state ^= state << 9;
state ^= state >> 13;
return state;
}
template <typename T>
__device__ void topk_with_k2(T* output,
T const* input,
@@ -656,6 +663,195 @@ __global__ void group_idx_and_topk_idx_kernel(
#endif
}
template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_redundant_kernel(
T* scores,
T const* group_scores,
T* topk_values,
IdxT* topk_indices,
T* scores_with_bias,
int32_t* expert_id_to_ep_rank_array,
int32_t* expert_in_rank_num_list,
int32_t* tokens_per_expert_stats_list,
int64_t const num_tokens,
int64_t const n_group,
int64_t const topk_group,
int64_t const topk,
bool const renormalize,
int64_t const num_experts,
int64_t const num_experts_per_group,
double routed_scaling_factor,
int64_t const redundant_ep_rank_num_plus_one) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id =
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
unsigned int state = case_id;
scores_with_bias += case_id * num_experts;
scores += case_id * num_experts;
group_scores += case_id * n_group;
topk_values += case_id * topk;
topk_indices += case_id * topk;
int32_t align_num_experts_per_group =
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
// store the target topk idx
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
T* s_topk_value =
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
warp_id * topk;
s_topk_idx += warp_id * topk;
T value = neg_inf<T>();
T topk_group_value = neg_inf<T>();
int32_t num_equalto_topkth_group;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
// acqbulk because it's ptr arithmetic
#endif
if (case_id < num_tokens) {
// calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
if (lane_id < n_group &&
(isfinite(cuda_cast<float, T>(
group_scores[lane_id])))) // The check is necessary to avoid
// abnormal input
{
value = group_scores[lane_id];
}
int count_equal_to_top_value = WARP_SIZE - n_group;
int pre_count_equal_to_top_value = 0;
// Use loop to find the largset top_group
while (count_equal_to_top_value < target_num_min) {
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = neg_inf<T>();
}
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value =
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
}
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
}
__syncthreads();
warp_topk::WarpSelect</*capability*/ WARP_SIZE,
/*greater*/ true,
T,
int32_t,
/* is_stable */ true>
queue((int32_t)topk, neg_inf<T>());
int count_equalto_topkth_group = 0;
bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i_group = 0; i_group < n_group; i_group++) {
if ((group_scores[i_group] > topk_group_value) ||
((group_scores[i_group] == topk_group_value) &&
(count_equalto_topkth_group < num_equalto_topkth_group))) {
int32_t offset = i_group * num_experts_per_group;
for (int32_t i = lane_id; i < align_num_experts_per_group;
i += WARP_SIZE) {
T candidates =
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
scores_with_bias[offset + i]))
? scores_with_bias[offset + i]
: neg_inf<T>();
queue.add(candidates, offset + i);
}
if (group_scores[i_group] == topk_group_value) {
count_equalto_topkth_group++;
}
}
}
queue.done();
__syncwarp();
// Get the topk_idx
queue.dumpIdx(s_topk_idx);
__syncwarp();
}
// Load the valid score value
// Calculate the summation
float topk_sum = 1e-20;
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i = lane_id;
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
i += WARP_SIZE) {
T value = i < topk ? scores[s_topk_idx[i]]
: 0.0f; // Load the valid value of expert
if (i < topk) {
s_topk_value[i] = value;
}
topk_sum +=
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}
__syncthreads();
// Note(ZKK): a little trick.
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
scores[i] = 0;
}
}
__syncwarp();
if (case_id < num_tokens) {
if (if_proceed_next_topk) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float value;
if (renormalize) {
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
routed_scaling_factor;
} else {
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
}
scores[s_topk_idx[i]] = value;
int expert_topk = s_topk_idx[i];
int len = expert_in_rank_num_list[expert_topk];
int select = (int)xorwow_moe(state) % len;
// int select = 0;
int selected_rank =
expert_id_to_ep_rank_array[expert_topk *
redundant_ep_rank_num_plus_one +
select];
atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
topk_indices[i] = (IdxT)selected_rank;
topk_values[i] = cuda_cast<T, float>(value);
}
} else {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
int expert_topk = i;
int len = expert_in_rank_num_list[expert_topk];
int select = (int)xorwow_moe(state) % len;
// int select = 0;
int selected_rank =
expert_id_to_ep_rank_array[expert_topk *
redundant_ep_rank_num_plus_one +
select];
atomicAdd(&tokens_per_expert_stats_list[expert_topk], 1);
topk_indices[i] = (IdxT)selected_rank;
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
}
}
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
// default result.
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores,
T* group_scores,
@@ -752,6 +948,111 @@ void invokeNoAuxTc(T* scores,
#endif
}
template <typename T, typename IdxT>
void invokeNoAuxTcRedundant(T* scores,
T* group_scores,
T* topk_values,
IdxT* topk_indices,
T* scores_with_bias,
int32_t* expert_id_to_ep_rank_array,
int32_t* expert_in_rank_num_list,
int32_t* tokens_per_expert_stats_list,
int64_t const num_tokens,
int64_t const num_experts,
int64_t const n_group,
int64_t const topk_group,
int64_t const topk,
bool const renormalize,
double const routed_scaling_factor,
int64_t const redundant_ep_rank_num_plus_one,
cudaStream_t const stream) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
topk_with_k2_kernel<T><<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
group_scores,
scores_with_bias,
num_cases,
n_group,
num_experts / n_group);
#else
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks;
config.blockDim = BLOCK_SIZE;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config,
kernel_instance1,
group_scores,
scores_with_bias,
num_cases,
n_group,
num_experts / n_group);
#endif
int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
size_t dynamic_smem_in_bytes =
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
group_idx_and_topk_idx_redundant_kernel<T>
<<<topk_with_k_group_num_blocks,
BLOCK_SIZE,
dynamic_smem_in_bytes,
stream>>>(scores,
group_scores,
topk_values,
topk_indices,
scores_with_bias,
expert_id_to_ep_rank_array,
expert_in_rank_num_list,
tokens_per_expert_stats_list,
num_tokens,
n_group,
topk_group,
topk,
renormalize,
num_experts,
num_experts / n_group,
routed_scaling_factor,
redundant_ep_rank_num_plus_one);
#else
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
config.dynamicSmemBytes = dynamic_smem_in_bytes;
config.stream = stream;
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = false;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config,
kernel_instance2,
scores,
group_scores,
topk_values,
topk_indices,
scores_with_bias,
num_tokens,
n_group,
topk_group,
topk,
num_experts,
num_experts / n_group,
renormalize,
routed_scaling_factor);
#endif
}
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
template void invokeNoAuxTc<T, IdxT>(T * scores, \
T * group_scores, \
@@ -768,3 +1069,25 @@ void invokeNoAuxTc(T* scores,
cudaStream_t const stream);
INSTANTIATE_NOAUX_TC(float, int32_t);
#define INSTANTIATE_NOAUX_TC_Redundant(T, IdxT) \
template void invokeNoAuxTcRedundant<T, IdxT>( \
T * scores, \
T * group_scores, \
T * topk_values, \
IdxT * topk_indices, \
T * scores_with_bias, \
int32_t * expert_id_to_ep_rank_array, \
int32_t * expert_in_rank_num_list, \
int32_t * tokens_per_expert_stats_list, \
int64_t const num_tokens, \
int64_t const num_experts, \
int64_t const n_group, \
int64_t const topk_group, \
int64_t const topk, \
bool const renormalize, \
double const routed_scaling_factor, \
int64_t const redundant_ep_rank_num_plus_one, \
cudaStream_t const stream);
INSTANTIATE_NOAUX_TC_Redundant(float, int32_t);