This commit is contained in:
周周周
2025-10-31 21:25:11 +08:00
committed by GitHub
parent 27746026c1
commit 10358bf1a0

View File

@@ -493,7 +493,7 @@ __global__ void topk_with_k2_kernel(T* output,
template <typename T, typename IdxT>
__global__ void group_idx_and_topk_idx_kernel(
const T* scores,
T* scores,
T const* group_scores,
T* topk_values,
IdxT* topk_indices,
@@ -620,6 +620,13 @@ __global__ void group_idx_and_topk_idx_kernel(
}
__syncthreads();
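// Note(ZKK): a little trick: zero scores[0..num_experts) here; the selected
// top-k entries are written back into scores (scores[s_topk_idx[i]] = value) below.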
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
scores[i] = 0;
}
}
__syncwarp();
if (case_id < num_tokens) {
if (if_proceed_next_topk) {
@@ -631,6 +638,7 @@ __global__ void group_idx_and_topk_idx_kernel(
} else {
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
}
scores[s_topk_idx[i]] = value;
topk_indices[i] = s_topk_idx[i];
topk_values[i] = cuda_cast<T, float>(value);
}