mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Others] clean code (#5691)
This commit is contained in:
@@ -547,7 +547,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
|
|
||||||
if (case_id < num_tokens) {
|
if (case_id < num_tokens) {
|
||||||
// calculate group_idx
|
// calculate group_idx
|
||||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
int32_t want_neg_inf_num = WARP_SIZE - n_group + topk_group;
|
||||||
if (lane_id < n_group &&
|
if (lane_id < n_group &&
|
||||||
(isfinite(cuda_cast<float, T>(
|
(isfinite(cuda_cast<float, T>(
|
||||||
group_scores[lane_id])))) // The check is necessary to avoid
|
group_scores[lane_id])))) // The check is necessary to avoid
|
||||||
@@ -556,20 +556,24 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
value = group_scores[lane_id];
|
value = group_scores[lane_id];
|
||||||
}
|
}
|
||||||
|
|
||||||
int count_equal_to_top_value = WARP_SIZE - n_group;
|
int neg_inf_num = WARP_SIZE - n_group;
|
||||||
int pre_count_equal_to_top_value = 0;
|
int last_neg_inf_num = 0;
|
||||||
// Use loop to find the largset top_group
|
// Use loop to find the largset top_group
|
||||||
while (count_equal_to_top_value < target_num_min) {
|
while (neg_inf_num < want_neg_inf_num) {
|
||||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||||
if (value == topk_group_value) {
|
if (value == topk_group_value) {
|
||||||
value = neg_inf<T>();
|
value = neg_inf<T>();
|
||||||
}
|
}
|
||||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
last_neg_inf_num = neg_inf_num;
|
||||||
count_equal_to_top_value =
|
|
||||||
|
neg_inf_num =
|
||||||
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
|
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
|
||||||
}
|
}
|
||||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
// There is a possible case:
|
||||||
|
// may have many different group holding the same score!
|
||||||
|
// but we only accept some of them!
|
||||||
|
num_equalto_topkth_group = want_neg_inf_num - last_neg_inf_num;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user