From 922a73ddd6f0220a121d4f206f4843425a60c4c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com> Date: Wed, 24 Dec 2025 11:28:47 +0800 Subject: [PATCH] [Others] clean code (#5691) --- custom_ops/gpu_ops/noauxtc_kernel.h | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index c65ae1c5f..048890f59 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -547,7 +547,7 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { // calculate group_idx - int32_t target_num_min = WARP_SIZE - n_group + topk_group; + int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group; if (lane_id < n_group && (isfinite(cuda_cast( group_scores[lane_id])))) // The check is necessary to avoid @@ -556,20 +556,24 @@ __global__ void group_idx_and_topk_idx_kernel( value = group_scores[lane_id]; } - int count_equal_to_top_value = WARP_SIZE - n_group; - int pre_count_equal_to_top_value = 0; + int neg_inf_num = WARP_SIZE - n_group; + int last_neg_inf_num = 0; // Use loop to find the largset top_group - while (count_equal_to_top_value < target_num_min) { + while (neg_inf_num < want_neg_inf_num) { __syncwarp(); // Ensure all threads have valid data before reduction topk_group_value = cg::reduce(tile, value, cg::greater()); if (value == topk_group_value) { value = neg_inf(); } - pre_count_equal_to_top_value = count_equal_to_top_value; - count_equal_to_top_value = + last_neg_inf_num = neg_inf_num; + + neg_inf_num = __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf()))); } - num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + // There is a possible case: + // may have many different group holding the same score! + // but we only accept some of them! + num_equalto_topkth_group = want_neg_inf_num - last_neg_inf_num; } __syncthreads();