[Others] clean code (#5691)

This commit is contained in:
周周周
2025-12-24 11:28:47 +08:00
committed by GitHub
parent 23d488c488
commit 922a73ddd6

View File

@@ -547,7 +547,7 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) {
// calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group;
if (lane_id < n_group &&
(isfinite(cuda_cast<float, T>(
group_scores[lane_id])))) // The check is necessary to avoid
@@ -556,20 +556,24 @@ __global__ void group_idx_and_topk_idx_kernel(
value = group_scores[lane_id];
}
int count_equal_to_top_value = WARP_SIZE - n_group;
int pre_count_equal_to_top_value = 0;
int neg_inf_num = WARP_SIZE - n_group;
int last_neg_inf_num = 0;
// Use loop to find the largset top_group
while (count_equal_to_top_value < target_num_min) {
while (neg_inf_num < want_neg_inf_num) {
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = neg_inf<T>();
}
pre_count_equal_to_top_value = count_equal_to_top_value;
count_equal_to_top_value =
last_neg_inf_num = neg_inf_num;
neg_inf_num =
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
}
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
// There is a possible case:
// may have many different group holding the same score!
// but we only accept some of them!
num_equalto_topkth_group = want_neg_inf_num - last_neg_inf_num;
}
__syncthreads();