diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index 7758c3b64..a3a52051b 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -493,7 +493,7 @@ __global__ void topk_with_k2_kernel(T* output, template __global__ void group_idx_and_topk_idx_kernel( - const T* scores, + T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices, @@ -620,6 +620,13 @@ __global__ void group_idx_and_topk_idx_kernel( } __syncthreads(); + // Note(ZKK): a little trick. + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; i < num_experts; i += WARP_SIZE) { + scores[i] = 0; + } + } + __syncwarp(); if (case_id < num_tokens) { if (if_proceed_next_topk) { @@ -631,6 +638,7 @@ __global__ void group_idx_and_topk_idx_kernel( } else { value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; } + scores[s_topk_idx[i]] = value; topk_indices[i] = s_topk_idx[i]; topk_values[i] = cuda_cast(value); }