[Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

[Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)
This commit is contained in:
周周周
2025-07-16 20:10:57 +08:00
committed by GitHub
parent 42b80182e0
commit aa76085d1f
47 changed files with 237 additions and 260 deletions

View File

@@ -2111,7 +2111,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT *__restrict__ out,
@@ -2127,7 +2127,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int bid = blockIdx.x, hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) return;
int seq_len_kv = seq_lens_kv[bid];