use shfl_xor_sync to reduce redundant shfl broadcast

This commit is contained in:
MARD1NO
2025-06-30 13:12:21 +08:00
parent 90a5b18742
commit ac5f860536

View File

@@ -50,13 +50,11 @@ __global__ void quant_per_token_per_block(const T *input,
max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
}
// get max value per warp
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread);
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread);
// broadcast max_value
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread);
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread);
max_value_thread = max(max_value_thread, epsilon);
float scale_to_store = max_value_thread / MAX_VALUE;
// quant