diff --git a/custom_ops/gpu_ops/per_token_quant_fp8.cu b/custom_ops/gpu_ops/per_token_quant_fp8.cu index 9a16d4d36..f195403a5 100644 --- a/custom_ops/gpu_ops/per_token_quant_fp8.cu +++ b/custom_ops/gpu_ops/per_token_quant_fp8.cu @@ -50,13 +50,11 @@ __global__ void quant_per_token_per_block(const T *input, max_value_thread = max(abs(load_vec_float[vid]), max_value_thread); } // get max value per warp - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread); - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread); - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread); - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread); - max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread); - // broadcast max_value - max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread); + max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread); max_value_thread = max(max_value_thread, epsilon); float scale_to_store = max_value_thread / MAX_VALUE; // quant