mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Metax] modify warpSize to WARP_SIZE (#5442)
This commit is contained in:
@@ -34,16 +34,16 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
|
||||
int cum_seq_len = 0;
|
||||
|
||||
// compute sum of seq_lens[0,1,2,...,bi]
|
||||
for (int i = lane_id; i < bi + 1; i += warpSize) {
|
||||
for (int i = lane_id; i < bi + 1; i += WARP_SIZE) {
|
||||
cum_seq_len += seq_lens[i];
|
||||
}
|
||||
|
||||
for (int offset = 1; offset < warpSize; offset <<= 1) {
|
||||
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
|
||||
const int tmp = __shfl_up_sync(0xffffffff, cum_seq_len, offset);
|
||||
if (lane_id >= offset) cum_seq_len += tmp;
|
||||
}
|
||||
|
||||
cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, warpSize - 1);
|
||||
cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, WARP_SIZE - 1);
|
||||
|
||||
if (tid == 0) {
|
||||
cu_seqlens_q[bi + 1] = cum_seq_len;
|
||||
|
||||
Reference in New Issue
Block a user