mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-11-02 22:54:01 +08:00
[Optimization] Optimize split_q_block kernel (#4367)
This commit is contained in:
@@ -197,25 +197,55 @@ __global__ void split_q_block(const int *__restrict__ seq_lens_q,
|
|||||||
const int *__restrict__ seq_lens_encoder,
|
const int *__restrict__ seq_lens_encoder,
|
||||||
int *__restrict__ batch_ids,
|
int *__restrict__ batch_ids,
|
||||||
int *__restrict__ tile_ids_per_batch,
|
int *__restrict__ tile_ids_per_batch,
|
||||||
int *__restrict__ num_blocks_x, const int bsz,
|
int *__restrict__ num_blocks_x,
|
||||||
|
const int bsz,
|
||||||
const int num_rows_per_block,
|
const int num_rows_per_block,
|
||||||
const int group_size) {
|
const int group_size) {
|
||||||
if (threadIdx.x == 0) {
|
// one block one warp
|
||||||
int gridx = 0;
|
const int lane_id = threadIdx.x % warpSize;
|
||||||
int index = 0;
|
int prev_offset = 0;
|
||||||
for (uint32_t bid = 0; bid < bsz; bid++) {
|
|
||||||
|
// loop on warp tile:[base, base+32)
|
||||||
|
for (int base = 0; base < bsz; base += warpSize) {
|
||||||
|
const int bid = base + lane_id;
|
||||||
|
|
||||||
|
// calculate loop_times for bid
|
||||||
|
int loop_times = 0;
|
||||||
|
if (bid < bsz) {
|
||||||
int seq_len = seq_lens_q[bid];
|
int seq_len = seq_lens_q[bid];
|
||||||
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
|
if (seq_lens_encoder && seq_lens_encoder[bid] > 0) {
|
||||||
seq_len = 0;
|
seq_len = 0;
|
||||||
}
|
}
|
||||||
const int loop_times = div_up(seq_len * group_size, num_rows_per_block);
|
loop_times = div_up(seq_len * group_size, num_rows_per_block);
|
||||||
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
|
|
||||||
batch_ids[index] = bid;
|
|
||||||
tile_ids_per_batch[index++] = tile_id;
|
|
||||||
}
|
|
||||||
gridx += loop_times;
|
|
||||||
}
|
}
|
||||||
*num_blocks_x = gridx;
|
|
||||||
|
// prefix sum for each lane, get the start offset in this tile
|
||||||
|
// inclusive scan
|
||||||
|
int x = loop_times;
|
||||||
|
for (int offset = 1; offset < warpSize; offset <<= 1) {
|
||||||
|
int y = __shfl_up_sync(0xffffffff, x, offset);
|
||||||
|
if (lane_id >= offset) x += y;
|
||||||
|
}
|
||||||
|
// exclusive prefix sum
|
||||||
|
int bid_offset = x - loop_times;
|
||||||
|
int tile_sum = __shfl_sync(0xffffffff, x, warpSize - 1);
|
||||||
|
|
||||||
|
// write batch_ids and tile_ids_per_batch
|
||||||
|
if (bid < bsz && loop_times > 0) {
|
||||||
|
int write_base = prev_offset + bid_offset;
|
||||||
|
for (int t = 0; t < loop_times; ++t) {
|
||||||
|
int pos = write_base + t;
|
||||||
|
batch_ids[pos] = bid;
|
||||||
|
tile_ids_per_batch[pos] = t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// for next warp tile
|
||||||
|
prev_offset += tile_sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
*num_blocks_x = prev_offset;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user