From ac5f86053614085ed9103c86f2bc129fc545aa2c Mon Sep 17 00:00:00 2001 From: MARD1NO <359521840@qq.com> Date: Mon, 30 Jun 2025 13:12:21 +0800 Subject: [PATCH] use shfl_xor_sync to reduce redundant shfl broadcast --- custom_ops/gpu_ops/per_token_quant_fp8.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/custom_ops/gpu_ops/per_token_quant_fp8.cu b/custom_ops/gpu_ops/per_token_quant_fp8.cu index 9a16d4d36..f195403a5 100644 --- a/custom_ops/gpu_ops/per_token_quant_fp8.cu +++ b/custom_ops/gpu_ops/per_token_quant_fp8.cu @@ -50,13 +50,11 @@ __global__ void quant_per_token_per_block(const T *input, max_value_thread = max(abs(load_vec_float[vid]), max_value_thread); } // get max value per warp - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread); - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread); - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread); - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread); - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread); - // broadcast max_value - max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread); max_value_thread = max(max_value_thread, epsilon); float scale_to_store = max_value_thread / MAX_VALUE; // quant