mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Others] clean code (#5691)
This commit is contained in:
@@ -547,7 +547,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group;
|
||||
if (lane_id < n_group &&
|
||||
(isfinite(cuda_cast<float, T>(
|
||||
group_scores[lane_id])))) // The check is necessary to avoid
|
||||
@@ -556,20 +556,24 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
int count_equal_to_top_value = WARP_SIZE - n_group;
|
||||
int pre_count_equal_to_top_value = 0;
|
||||
int neg_inf_num = WARP_SIZE - n_group;
|
||||
int last_neg_inf_num = 0;
|
||||
// Use loop to find the largset top_group
|
||||
while (count_equal_to_top_value < target_num_min) {
|
||||
while (neg_inf_num < want_neg_inf_num) {
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = neg_inf<T>();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value =
|
||||
last_neg_inf_num = neg_inf_num;
|
||||
|
||||
neg_inf_num =
|
||||
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
// There is a possible case:
|
||||
// may have many different group holding the same score!
|
||||
// but we only accept some of them!
|
||||
num_equalto_topkth_group = want_neg_inf_num - last_neg_inf_num;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user