From 10358bf1a0105ef9f5e3401e538bac996f949de5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com> Date: Fri, 31 Oct 2025 21:25:11 +0800 Subject: [PATCH] fix noaux (#4731) --- custom_ops/gpu_ops/noauxtc_kernel.h | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h index 7758c3b64..a3a52051b 100644 --- a/custom_ops/gpu_ops/noauxtc_kernel.h +++ b/custom_ops/gpu_ops/noauxtc_kernel.h @@ -493,7 +493,7 @@ __global__ void topk_with_k2_kernel(T* output, template __global__ void group_idx_and_topk_idx_kernel( - const T* scores, + T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices, @@ -620,6 +620,13 @@ __global__ void group_idx_and_topk_idx_kernel( } __syncthreads(); + // Note(ZKK): a little trick. + if (case_id < num_tokens && if_proceed_next_topk) { + for (int i = lane_id; i < num_experts; i += WARP_SIZE) { + scores[i] = 0; + } + } + __syncwarp(); if (case_id < num_tokens) { if (if_proceed_next_topk) { @@ -631,6 +638,7 @@ __global__ void group_idx_and_topk_idx_kernel( } else { value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; } + scores[s_topk_idx[i]] = value; topk_indices[i] = s_topk_idx[i]; topk_values[i] = cuda_cast(value); }