[Optimization] Optimize split_q_block kernel (#4367)

This commit is contained in:
Sunny-bot1
2025-10-15 11:28:00 +08:00
committed by GitHub
parent c4f866c457
commit 6d0cc0dd9c

View File

@@ -197,25 +197,55 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
int *__restrict__ num_blocks_x, const int bsz,
int *__restrict__ num_blocks_x,
const int bsz,
const int num_rows_per_block,
const int group_size) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
// one block one warp
const int lane_id = threadIdx.x % warpSize;
int prev_offset = 0;
// loop on warp tile[base, base+32)
for (int base = 0; base < bsz; base += warpSize) {
const int bid = base + lane_id;
// calculate loop_times for bid
int loop_times = 0;
if (bid < bsz) {
int seq_len = seq_lens_q[bid];
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
seq_len = 0;
}
const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
gridx += loop_times;
loop_times = div_up(seq_len * group_size, num_rows_per_block);
}
*num_blocks_x = gridx;
// prefix sum for each lane, get the start offset in this tile
// inclusive scan
int x = loop_times;
for (int offset = 1; offset < warpSize; offset <<= 1) {
int y = __shfl_up_sync(0xffffffff, x, offset);
if (lane_id >= offset) x += y;
}
// exclusive prefix sum
int bid_offset = x - loop_times;
int tile_sum = __shfl_sync(0xffffffff, x, warpSize - 1);
// write batch_ids and tile_ids_per_batch
if (bid < bsz && loop_times > 0) {
int write_base = prev_offset + bid_offset;
for (int t = 0; t < loop_times; ++t) {
int pos = write_base + t;
batch_ids[pos] = bid;
tile_ids_per_batch[pos] = t;
}
}
// for next warp tile
prev_offset += tile_sum;
}
if (threadIdx.x == 0) {
*num_blocks_x = prev_offset;
}
}