mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
[Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)
This commit is contained in:
@@ -24,7 +24,7 @@ __global__ void decode_absorb_cache_kernel(
|
||||
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
@@ -50,7 +50,7 @@ __global__ void decode_absorb_cache_kernel(
|
||||
linear_index += step) {
|
||||
const int ori_bi = linear_index / hidden_size;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||
const int write_seq_id = seq_lens[ori_bi];
|
||||
|
||||
@@ -96,7 +96,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
@@ -124,7 +124,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
|
||||
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
|
||||
if (seq_lens[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
const int write_seq_id =
|
||||
seq_lens[ori_bi] + token_id - start_token_idx;
|
||||
if (write_seq_id == 0) continue;
|
||||
@@ -143,7 +143,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
|
||||
ori_bi,
|
||||
seq_lens[ori_bi],
|
||||
token_id,
|
||||
cum_offsets[ori_bi]);
|
||||
cu_seqlens_q[ori_bi]);
|
||||
}
|
||||
if (bias < nope_hidden_size) { // pe
|
||||
const uint32_t inner_bias = bias;
|
||||
@@ -179,7 +179,7 @@ __global__ void prefill_absorb_cache_kernel(
|
||||
// nope_size]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ padding_offsets,
|
||||
const int* __restrict__ cum_offsets,
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_decoder, // [bsz]
|
||||
const int max_seq_len,
|
||||
|
||||
Reference in New Issue
Block a user