diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index c65ae1c5f..048890f59 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -547,7 +547,7 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { // calculate group_idx - int32_t target_num_min = WARP_SIZE - n_group + topk_group; + int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group; if (lane_id < n_group && (isfinite(cuda_cast( group_scores[lane_id])))) // The check is necessary to avoid @@ -556,20 +556,24 @@ __global__ void group_idx_and_topk_idx_kernel( value = group_scores[lane_id]; } - int count_equal_to_top_value = WARP_SIZE - n_group; - int pre_count_equal_to_top_value = 0; + int neg_inf_num = WARP_SIZE - n_group; + int last_neg_inf_num = 0; // Use loop to find the largset top_group - while (count_equal_to_top_value < target_num_min) { + while (neg_inf_num < want_neg_inf_num) { __syncwarp(); // Ensure all threads have valid data before reduction topk_group_value = cg::reduce(tile, value, cg::greater()); if (value == topk_group_value) { value = neg_inf(); } - pre_count_equal_to_top_value = count_equal_to_top_value; - count_equal_to_top_value = + last_neg_inf_num = neg_inf_num; + + neg_inf_num = __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf()))); } - num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; + // There is a possible case: + // may have many different group holding the same score! + // but we only accept some of them! + num_equalto_topkth_group = want_neg_inf_num - last_neg_inf_num; } __syncthreads();