[Others] clean code (#5691)

This commit is contained in:
周周周
2025-12-24 11:28:47 +08:00
committed by GitHub
parent 23d488c488
commit 922a73ddd6

View File

@@ -547,7 +547,7 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) { if (case_id < num_tokens) {
// calculate group_idx // calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group; int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group;
if (lane_id < n_group && if (lane_id < n_group &&
(isfinite(cuda_cast<float, T>( (isfinite(cuda_cast<float, T>(
group_scores[lane_id])))) // The check is necessary to avoid group_scores[lane_id])))) // The check is necessary to avoid
@@ -556,20 +556,24 @@ __global__ void group_idx_and_topk_idx_kernel(
value = group_scores[lane_id]; value = group_scores[lane_id];
} }
int count_equal_to_top_value = WARP_SIZE - n_group; int neg_inf_num = WARP_SIZE - n_group;
int pre_count_equal_to_top_value = 0; int last_neg_inf_num = 0;
// Use loop to find the largest top_group // Use loop to find the largest top_group
while (count_equal_to_top_value < target_num_min) { while (neg_inf_num < want_neg_inf_num) {
__syncwarp(); // Ensure all threads have valid data before reduction __syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>()); topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) { if (value == topk_group_value) {
value = neg_inf<T>(); value = neg_inf<T>();
} }
pre_count_equal_to_top_value = count_equal_to_top_value; last_neg_inf_num = neg_inf_num;
count_equal_to_top_value =
neg_inf_num =
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>()))); __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
} }
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; // There is a possible case:
// several different groups may hold the same score,
// but we only accept some of them.
num_equalto_topkth_group = want_neg_inf_num - last_neg_inf_num;
} }
__syncthreads(); __syncthreads();