mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-11-01 12:22:53 +08:00
[Optimization] Optimize split_q_block kernel (#4367)
This commit is contained in:
@@ -197,25 +197,55 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
|
||||
const int *__restrict__ seq_lens_encoder,
|
||||
int *__restrict__ batch_ids,
|
||||
int *__restrict__ tile_ids_per_batch,
|
||||
int *__restrict__ num_blocks_x, const int bsz,
|
||||
int *__restrict__ num_blocks_x,
|
||||
const int bsz,
|
||||
const int num_rows_per_block,
|
||||
const int group_size) {
|
||||
if (threadIdx.x == 0) {
|
||||
int gridx = 0;
|
||||
int index = 0;
|
||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
||||
// one block one warp
|
||||
const int lane_id = threadIdx.x % warpSize;
|
||||
int prev_offset = 0;
|
||||
|
||||
// loop on warp tile:[base, base+32)
|
||||
for (int base = 0; base < bsz; base += warpSize) {
|
||||
const int bid = base + lane_id;
|
||||
|
||||
// calculate loop_times for bid
|
||||
int loop_times = 0;
|
||||
if (bid < bsz) {
|
||||
int seq_len = seq_lens_q[bid];
|
||||
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
|
||||
seq_len = 0;
|
||||
}
|
||||
const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
|
||||
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
|
||||
batch_ids[index] = bid;
|
||||
tile_ids_per_batch[index++] = tile_id;
|
||||
}
|
||||
gridx += loop_times;
|
||||
loop_times = div_up(seq_len * group_size, num_rows_per_block);
|
||||
}
|
||||
*num_blocks_x = gridx;
|
||||
|
||||
// prefix sum for each lane, get the start offset in this tile
|
||||
// inclusive scan
|
||||
int x = loop_times;
|
||||
for (int offset = 1; offset < warpSize; offset <<= 1) {
|
||||
int y = __shfl_up_sync(0xffffffff, x, offset);
|
||||
if (lane_id >= offset) x += y;
|
||||
}
|
||||
// exclusive prefix sum
|
||||
int bid_offset = x - loop_times;
|
||||
int tile_sum = __shfl_sync(0xffffffff, x, warpSize - 1);
|
||||
|
||||
// write batch_ids and tile_ids_per_batch
|
||||
if (bid < bsz && loop_times > 0) {
|
||||
int write_base = prev_offset + bid_offset;
|
||||
for (int t = 0; t < loop_times; ++t) {
|
||||
int pos = write_base + t;
|
||||
batch_ids[pos] = bid;
|
||||
tile_ids_per_batch[pos] = t;
|
||||
}
|
||||
}
|
||||
|
||||
// for next warp tile
|
||||
prev_offset += tile_sum;
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
*num_blocks_x = prev_offset;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user