[Metax] modify warpSize to WARP_SIZE (#5442)

This commit is contained in:
xiaozude
2025-12-09 17:44:02 +08:00
committed by GitHub
parent e397c4fba6
commit df67379bc3
4 changed files with 406 additions and 228 deletions

View File

@@ -34,16 +34,16 @@ __global__ void PrefixSumKernel(int64_t *ids_remove_padding,
int cum_seq_len = 0;
// compute sum of seq_lens[0,1,2,...,bi]
for (int i = lane_id; i < bi + 1; i += warpSize) {
for (int i = lane_id; i < bi + 1; i += WARP_SIZE) {
cum_seq_len += seq_lens[i];
}
for (int offset = 1; offset < warpSize; offset <<= 1) {
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
const int tmp = __shfl_up_sync(0xffffffff, cum_seq_len, offset);
if (lane_id >= offset) cum_seq_len += tmp;
}
cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, warpSize - 1);
cum_seq_len = __shfl_sync(0xffffffff, cum_seq_len, WARP_SIZE - 1);
if (tid == 0) {
cu_seqlens_q[bi + 1] = cum_seq_len;