From 6d0cc0dd9c568d3a794e7f585b9972a75bc4b0cd Mon Sep 17 00:00:00 2001
From: Sunny-bot1 <68891411+Sunny-bot1@users.noreply.github.com>
Date: Wed, 15 Oct 2025 11:28:00 +0800
Subject: [PATCH] [Optimization] Optimize split_q_block kernel (#4367)

---
 .../get_block_shape_and_split_kv_block.cu | 54 ++++++++++++++-----
 1 file changed, 42 insertions(+), 12 deletions(-)

diff --git a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
index d17316fec..9451a521e 100644
--- a/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
+++ b/custom_ops/gpu_ops/append_attn/get_block_shape_and_split_kv_block.cu
@@ -197,25 +197,55 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
                               const int *__restrict__ seq_lens_encoder,
                               int *__restrict__ batch_ids,
                               int *__restrict__ tile_ids_per_batch,
-                              int *__restrict__ num_blocks_x, const int bsz,
+                              int *__restrict__ num_blocks_x,
+                              const int bsz,
                               const int num_rows_per_block,
                               const int group_size) {
-  if (threadIdx.x == 0) {
-    int gridx = 0;
-    int index = 0;
-    for (uint32_t bid = 0; bid < bsz; bid++) {
+  // one block one warp
+  const int lane_id = threadIdx.x % warpSize;
+  int prev_offset = 0;
+
+  // loop on warp tile:[base, base+32)
+  for (int base = 0; base < bsz; base += warpSize) {
+    const int bid = base + lane_id;
+
+    // calculate loop_times for bid
+    int loop_times = 0;
+    if (bid < bsz) {
       int seq_len = seq_lens_q[bid];
       if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
         seq_len = 0;
       }
-      const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
-      for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
-        batch_ids[index] = bid;
-        tile_ids_per_batch[index++] = tile_id;
-      }
-      gridx += loop_times;
+      loop_times = div_up(seq_len * group_size, num_rows_per_block);
     }
-    *num_blocks_x = gridx;
+
+    // prefix sum for each lane, get the start offset in this tile
+    // inclusive scan
+    int x = loop_times;
+    for (int offset = 1; offset < warpSize; offset <<= 1) {
+      int y = __shfl_up_sync(0xffffffff, x, offset);
+      if (lane_id >= offset) x += y;
+    }
+    // exclusive prefix sum
+    int bid_offset = x - loop_times;
+    int tile_sum = __shfl_sync(0xffffffff, x, warpSize - 1);
+
+    // write batch_ids and tile_ids_per_batch
+    if (bid < bsz && loop_times > 0) {
+      int write_base = prev_offset + bid_offset;
+      for (int t = 0; t < loop_times; ++t) {
+        int pos = write_base + t;
+        batch_ids[pos] = bid;
+        tile_ids_per_batch[pos] = t;
+      }
+    }
+
+    // for next warp tile
+    prev_offset += tile_sum;
+  }
+
+  if (threadIdx.x == 0) {
+    *num_blocks_x = prev_offset;
   }
 }
 