mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
use shfl_xor_sync to reduce redundant shfl broadcast
This commit is contained in:
@@ -50,13 +50,11 @@ __global__ void quant_per_token_per_block(const T *input,
|
||||
max_value_thread = max(abs(load_vec_float[vid]), max_value_thread);
|
||||
}
|
||||
// get max value per warp
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 16), max_value_thread);
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 8), max_value_thread);
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 4), max_value_thread);
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 2), max_value_thread);
|
||||
max_value_thread = max(__shfl_down_sync(0xffffffff, max_value_thread, 1), max_value_thread);
|
||||
// broadcast max_value
|
||||
max_value_thread = __shfl_sync(0xFFFFFFFF, max_value_thread, 0);
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 16), max_value_thread);
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 8), max_value_thread);
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 4), max_value_thread);
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 2), max_value_thread);
|
||||
max_value_thread = max(__shfl_xor_sync(0xffffffff, max_value_thread, 1), max_value_thread);
|
||||
max_value_thread = max(max_value_thread, epsilon);
|
||||
float scale_to_store = max_value_thread / MAX_VALUE;
|
||||
// quant
|
||||
|
Reference in New Issue
Block a user