[Attention] remove cum_offsets from atten, and use cu_seqlens_q (#2870)

Author: 周周周
Date: 2025-07-16 20:10:57 +08:00 (committed by GitHub)
Commit: aa76085d1f (parent: 42b80182e0)
47 changed files with 237 additions and 260 deletions
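
The entire change rests on one identity: the removed code computes the first packed-token index of batch b as b * max_seq_len - cum_offsets[b] (padded position minus the padding accumulated before batch b), while the new code reads cu_seqlens_q[b], the prefix sum of per-batch query lengths familiar from varlen attention APIs. A minimal host-side sketch of that equivalence, with hypothetical lengths that are not taken from this PR:

#include <cassert>
#include <vector>

int main() {
  const int max_seq_len = 8;
  const std::vector<int> seq_lens_q = {3, 8, 1, 5};  // hypothetical batch

  std::vector<int> cum_offsets(seq_lens_q.size(), 0);       // old tensor
  std::vector<int> cu_seqlens_q(seq_lens_q.size() + 1, 0);  // new tensor
  int pad_acc = 0;
  for (size_t b = 0; b < seq_lens_q.size(); ++b) {
    cum_offsets[b] = pad_acc;  // padding removed before batch b
    pad_acc += max_seq_len - seq_lens_q[b];
    cu_seqlens_q[b + 1] = cu_seqlens_q[b] + seq_lens_q[b];
  }

  for (int b = 0; b < static_cast<int>(seq_lens_q.size()); ++b) {
    // The removed expression and the added lookup agree for every batch.
    assert(b * max_seq_len - cum_offsets[b] == cu_seqlens_q[b]);
  }
  return 0;
}

The diff below is then almost entirely mechanical: every cum_offsets parameter becomes cu_seqlens_q, and every padded-layout computation becomes a single array read.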

View File

@@ -47,7 +47,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -166,7 +166,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
lambda_batch_ids,
lambda_tile_ids_per_batch,
@@ -203,7 +203,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_encoder,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
kv_batch_ids,
kv_tile_ids_per_batch,
@@ -275,7 +275,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -298,7 +298,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -323,7 +323,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -347,7 +347,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
rotary_embs,
qkv_out_scales,
@@ -404,7 +404,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -462,7 +462,7 @@ std::vector<paddle::Tensor> AppendAttention(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = cum_offsets.dims()[0];
meta_data.batch_size = seq_lens_this_time.dims()[0];
auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
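
With cum_offsets gone, AppendAttention needs another per-batch tensor to report the batch size, and the hunk above picks seq_lens_this_time. A small sketch of the assumption (shapes are hypothetical, mirroring the [bsz] comments used elsewhere in this diff):

#include <cassert>
#include <vector>

int main() {
  // One entry per batch, so its length is the batch size the kernel needs.
  const std::vector<int> seq_lens_this_time = {3, 8, 1, 5};
  const int batch_size = static_cast<int>(seq_lens_this_time.size());
  // Its sum is the packed token count; under the usual prefix-sum convention
  // that is also the last entry of cu_seqlens_q (an assumption, not shown here).
  int token_num = 0;
  for (const int len : seq_lens_this_time) token_num += len;
  assert(batch_size == 4 && token_num == 17);
  return 0;
}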
@@ -474,7 +474,7 @@ std::vector<paddle::Tensor> AppendAttention(
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -551,7 +551,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& padding_offsets_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
@@ -611,7 +611,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& padding_offsets_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -689,7 +689,7 @@ PD_BUILD_STATIC_OP(append_attention)
"seq_lens_decoder",
"seq_lens_this_time",
"padding_offsets",
"cum_offsets",
"cu_seqlens_q",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",

View File

@@ -41,7 +41,7 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -114,8 +114,7 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
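
Inside the kernels the substitution is one-for-one, as in the hunk above. A standalone device-side sketch of the two forms, with the arguments standing in for the surrounding kernel's variables:

#include <cstdint>

// Removed form: reconstruct the packed offset from the padded layout.
__device__ uint32_t q_start_old(int batch_id, int max_seq_len,
                                const int* cum_offsets) {
  return batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
}

// Added form: read it straight from the query-length prefix sum.
__device__ uint32_t q_start_new(int batch_id, const int* cu_seqlens_q) {
  return cu_seqlens_q[batch_id];
}

Besides saving a multiply and an __ldg per lookup, the added form no longer involves max_seq_len in token addressing at all.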
@@ -405,7 +404,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -477,8 +476,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
q_head_idx * HEAD_DIM +
@@ -776,7 +774,7 @@ void MultiQueryAppendAttention(
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -882,7 +880,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -939,7 +937,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -974,7 +972,7 @@ void MultiQueryAppendAttention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1103,7 +1101,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1171,7 +1169,7 @@ void MultiQueryAppendAttention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1207,7 +1205,7 @@ void MultiQueryAppendAttention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1290,7 +1288,7 @@ void CascadeAppendAttentionC16Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -1353,7 +1351,7 @@ void CascadeAppendAttentionC16Kernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,

View File

@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -144,8 +144,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
const uint32_t kv_b_stride = HEAD_DIM / 2;
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -504,7 +503,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -601,8 +600,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM / 2;
const uint32_t kv_b_stride = HEAD_DIM / 2;
const uint32_t kv_d_stride = BLOCK_SIZE / 2;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
q_head_idx * HEAD_DIM +
@@ -963,7 +961,7 @@ void MultiQueryAppendC4Attention(
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1088,7 +1086,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1151,7 +1149,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1186,7 +1184,7 @@ void MultiQueryAppendC4Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1333,7 +1331,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1409,7 +1407,7 @@ void MultiQueryAppendC4Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1444,7 +1442,7 @@ void MultiQueryAppendC4Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1527,7 +1525,7 @@ void CascadeAppendAttentionC4Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -1594,7 +1592,7 @@ void CascadeAppendAttentionC4Kernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,

View File

@@ -46,7 +46,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -151,8 +151,7 @@ __global__ void multi_query_append_attention_c8_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t kv_d_stride = BLOCK_SIZE;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block =
(tile_id * NUM_WARPS + wid) * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
@@ -473,7 +472,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids_per_batch,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -575,8 +574,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t kv_b_stride = HEAD_DIM;
const uint32_t kv_d_stride = BLOCK_SIZE;
const uint32_t q_start_seq_id =
batch_id * max_seq_len - __ldg(&cum_offsets[batch_id]);
const uint32_t q_start_seq_id = cu_seqlens_q[batch_id];
const uint32_t q_base_seq_id_this_block = tile_id * num_frags_x * 16;
const uint32_t q_offset = q_start_seq_id * q_ori_n_stride +
q_head_idx * HEAD_DIM +
@@ -900,7 +898,7 @@ void MultiQueryAppendC8Attention(
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1054,7 +1052,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1111,7 +1109,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1146,7 +1144,7 @@ void MultiQueryAppendC8Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1317,7 +1315,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1387,7 +1385,7 @@ void MultiQueryAppendC8Attention(
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -1417,7 +1415,7 @@ void MultiQueryAppendC8Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1500,7 +1498,7 @@ void CascadeAppendAttentionC8Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -1565,7 +1563,7 @@ void CascadeAppendAttentionC8Kernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,

View File

@@ -2111,7 +2111,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT *__restrict__ out,
@@ -2127,7 +2127,7 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int bid = blockIdx.x, hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) return;
int seq_len_kv = seq_lens_kv[bid];

View File

@@ -41,7 +41,7 @@ void CascadeAppendAttentionC16Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -86,7 +86,7 @@ void CascadeAppendAttentionC8Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -131,7 +131,7 @@ void CascadeAppendAttentionC4Kernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -176,7 +176,7 @@ void CascadeAppendAttentionKernel(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -212,7 +212,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
@@ -247,7 +247,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
@@ -282,7 +282,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,
@@ -317,7 +317,7 @@ void CascadeAppendAttentionKernel(
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_table,
batch_ids,
tile_ids_per_batch,

View File

@@ -35,7 +35,7 @@ __global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi
const T * __restrict__ multi_d, // [bsz, num_chunks, num_heads]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cum_offsets,
const int * __restrict__ cu_seqlens_q,
const T * __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT * __restrict__ out, // [token_num, num_heads, head_dim]
@@ -59,7 +59,7 @@ __global__ void merge_varlen_multi_chunks_v2_kernel(const T * __restrict__ multi
__shared__ T smem[bdy * HEAD_DIM];
__shared__ T md_smem[bdy * 2];
const int start_token_ids = qid * max_seq_len - __ldg(&cum_offsets[qid]);
const int start_token_ids = cu_seqlens_q[qid];
using LoadT = AlignedVector<T, vec_size>;
LoadT load_vec;
LoadT res_vec;
@@ -134,7 +134,7 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke
const T * __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int * __restrict__ seq_lens_q,
const int * __restrict__ seq_lens_kv,
const int * __restrict__ cum_offsets,
const int * __restrict__ cu_seqlens_q,
const int * __restrict__ block_table, // [bsz, block_num_per_seq]
const int max_seq_len,
const int max_dec_len,
@@ -171,8 +171,8 @@ __global__ void multi_query_decode_attention_kernel(T * __restrict__ q, // [toke
}
kv_len += q_len;
const uint32_t num_chunk_this_seq = div_up(kv_len, chunk_size);
const uint32_t q_start_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const uint32_t q_write_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const uint32_t q_start_idx = cu_seqlens_q[bid];
const uint32_t q_write_idx = cu_seqlens_q[bid];
if (chunk_id >= num_chunk_this_seq) {
return;
}
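
Note that q_start_idx and q_write_idx above now read the same prefix entry instead of recomputing the padded-layout arithmetic twice. The decode path is also easy to reason about: DecodeMLAAttentionKernel further down annotates seq_lens_q with "q_seq_len is 1", so with every batch entry active the prefix sum degenerates to cu_seqlens_q[b] == b. A sketch under that assumption, with hypothetical values:

#include <cassert>
#include <vector>

int main() {
  const std::vector<int> seq_lens_q = {1, 1, 1, 1};  // pure decode step
  std::vector<int> cu_seqlens_q(seq_lens_q.size() + 1, 0);
  for (size_t b = 0; b < seq_lens_q.size(); ++b)
    cu_seqlens_q[b + 1] = cu_seqlens_q[b] + seq_lens_q[b];
  for (int b = 0; b < 4; ++b)
    assert(cu_seqlens_q[b] == b);  // one query token per sequence
  return 0;
}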
@@ -318,7 +318,7 @@ void MultiQueryDecoderAttention(
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const int max_seq_len,
const int max_dec_len,
@@ -393,7 +393,7 @@ void MultiQueryDecoderAttention(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -430,7 +430,7 @@ void MultiQueryDecoderAttention(
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_dec_len,
@@ -456,7 +456,7 @@ void MultiQueryDecoderAttention(
reinterpret_cast<NV_TYPE*>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(shift_bias_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(smooth_weight_ptr)),
reinterpret_cast<NV_TYPE*>(const_cast<T*>(out->data<T>())),
@@ -484,7 +484,7 @@ void DecodeMLAAttentionKernel(
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
@@ -513,7 +513,7 @@ void DecodeMLAAttentionKernel(
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cum_offsets,
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cu_seqlens_q,
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
}
@@ -528,7 +528,7 @@ template void DecodeMLAAttentionKernel<paddle::bfloat16>(
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,
@@ -549,7 +549,7 @@ template void DecodeMLAAttentionKernel<paddle::float16>(
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,

View File

@@ -29,7 +29,7 @@ __global__ void append_decode_cache_T_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -65,7 +65,7 @@ __global__ void append_decode_cache_T_rope_kernel(
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -135,7 +135,7 @@ __global__ void append_decode_cache_T_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -177,7 +177,7 @@ __global__ void append_decode_cache_T_rope_kernel(
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -255,7 +255,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -293,7 +293,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -367,7 +367,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -409,7 +409,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
@@ -499,7 +499,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -523,7 +523,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -746,7 +746,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -775,7 +775,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -1048,7 +1048,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1073,7 +1073,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -1347,7 +1347,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1377,7 +1377,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -1740,7 +1740,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1766,7 +1766,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -2035,7 +2035,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -2066,7 +2066,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -2363,7 +2363,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -2389,7 +2389,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
@@ -2733,7 +2733,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -2764,7 +2764,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int half_block_size = block_size / 2;
const int start_token_idx = bid * max_seq_len - __ldg(&cum_offsets[bid]);
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;

View File

@@ -22,7 +22,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -58,7 +58,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -80,7 +80,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -103,7 +103,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -126,7 +126,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -150,7 +150,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -183,7 +183,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -208,7 +208,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -233,7 +233,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -258,7 +258,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -283,7 +283,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -318,7 +318,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -345,7 +345,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -372,7 +372,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -399,7 +399,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -425,7 +425,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -472,7 +472,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -504,7 +504,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -537,7 +537,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -571,7 +571,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -604,7 +604,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -651,7 +651,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -678,7 +678,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -704,7 +704,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -730,7 +730,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -24,7 +24,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -879,7 +879,7 @@ __global__ void append_write_cache_kv_c8_qkv(
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ padding_offsets,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
@@ -911,8 +911,7 @@ __global__ void append_write_cache_kv_c8_qkv(
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
const uint32_t start_token_idx =
batch_id * max_seq_len - cum_offsets[batch_id];
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
@@ -1119,7 +1118,7 @@ __global__ void append_write_cache_kv_c4_qkv(
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ padding_offsets,
const int *__restrict__ cum_offsets,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
@@ -1148,8 +1147,7 @@ __global__ void append_write_cache_kv_c4_qkv(
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
const uint32_t start_token_idx =
batch_id * max_seq_len - cum_offsets[batch_id];
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
@@ -1750,7 +1748,7 @@ void CascadeAppendWriteCacheKVC8QKV(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1815,7 +1813,7 @@ void CascadeAppendWriteCacheKVC8QKV(
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
@@ -1838,7 +1836,7 @@ void CascadeAppendWriteCacheKVC4QKV(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
const paddle::Tensor &tile_ids_per_batch,
@@ -1885,7 +1883,7 @@ void CascadeAppendWriteCacheKVC4QKV(
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,

View File

@@ -26,7 +26,7 @@ void EncoderWriteCacheWithRopeKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,
@@ -143,7 +143,7 @@ void EncoderWriteCacheWithRopeKernel(
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
batch_ids,
tile_ids,
@@ -170,7 +170,7 @@ void EncoderWriteCacheWithRopeKernel(
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
batch_ids,
tile_ids,

View File

@@ -422,7 +422,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids,
@@ -450,7 +449,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const int token_num = qkv_dims[0];
const int max_blocks_per_seq = block_tables.dims()[1];
const int block_size = key_cache.dims()[2];
const int batch_size = cum_offsets.dims()[0];
const int batch_size = seq_lens_this_time.dims()[0];
const int kv_num_heads = key_cache_dims[1];
const int head_dim = key_cache_dims[3];
const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
@@ -463,7 +462,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
meta_data.q_num_heads = num_heads;
meta_data.max_blocks_per_seq = max_blocks_per_seq;
meta_data.block_size = block_size;
meta_data.batch_size = cum_offsets.dims()[0];
meta_data.batch_size = seq_lens_this_time.dims()[0];
phi::GPUContext* dev_ctx = static_cast<phi::GPUContext*>(phi::DeviceContextPool::Instance().Get(qkv.place()));
@@ -528,7 +527,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
kv_batch_ids,
kv_tile_ids,
@@ -595,7 +594,6 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
"seq_lens_encoder",
"seq_lens_decoder",
"padding_offsets",
"cum_offsets",
"block_tables",
"kv_batch_ids",
"kv_tile_ids_per_batch",

View File

@@ -23,7 +23,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
cudaStream_t& stream,
@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
@@ -74,7 +74,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len) {
@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cum_offsets.dims()[0];
meta_data.batch_size = cu_seqlens_q.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -100,7 +100,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
seq_lens,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
@@ -113,7 +113,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
seq_lens,
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
stream,
@@ -131,7 +131,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
const bool speculate_decoder,
@@ -165,7 +165,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
@@ -185,7 +185,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
max_seq_len,
@@ -206,7 +206,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
@@ -224,7 +224,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
meta_data.batch_size = cum_offsets.dims()[0];
meta_data.batch_size = cu_seqlens_q.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -233,7 +233,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
seq_lens,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
@@ -247,7 +247,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
seq_lens,
seq_lens_encoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
max_seq_len,
speculate_decoder,
@@ -266,7 +266,7 @@ PD_BUILD_OP(prefill_mla_write_cache)
"seq_lens",
"seq_lens_decoder",
"padding_offsets",
"cum_offsets",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})
@@ -281,7 +281,7 @@ PD_BUILD_OP(decode_mla_write_cache)
"seq_lens",
"seq_lens_encoder",
"padding_offsets",
"cum_offsets",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
.SetInplaceMap({{"kv_cache", "kv_cache_out"}})

View File

@@ -24,7 +24,7 @@ __global__ void decode_absorb_cache_kernel(
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
@@ -50,7 +50,7 @@ __global__ void decode_absorb_cache_kernel(
linear_index += step) {
const int ori_bi = linear_index / hidden_size;
const int bias = linear_index % hidden_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
if (seq_lens_encoder[ori_bi] > 0) return;
const int write_seq_id = seq_lens[ori_bi];
@@ -96,7 +96,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
@@ -124,7 +124,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
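
In the speculative-decode cache writer the batch start also anchors the write position: token_id - start_token_idx is the draft token's offset within its batch, and adding seq_lens gives the absolute cache slot. A worked example with hypothetical numbers:

#include <cassert>

int main() {
  const int cu_seqlens_q_b = 11;  // start_token_idx of batch b, new form
  const int seq_lens_b = 37;      // tokens already cached for batch b
  const int token_id = 13;        // packed index of the current draft token
  const int write_seq_id = seq_lens_b + token_id - cu_seqlens_q_b;
  assert(write_seq_id == 39);  // offset 2 within the batch -> slot 37 + 2
  return 0;
}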
@@ -143,7 +143,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
ori_bi,
seq_lens[ori_bi],
token_id,
cum_offsets[ori_bi]);
cu_seqlens_q[ori_bi]);
}
if (bias < nope_hidden_size) { // pe
const uint32_t inner_bias = bias;
@@ -179,7 +179,7 @@ __global__ void prefill_absorb_cache_kernel(
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_decoder, // [bsz]
const int max_seq_len,

View File

@@ -27,7 +27,7 @@ void DecodeMLAAttentionKernel(
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &cum_offsets,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
int max_dec_len,

View File

@@ -27,7 +27,7 @@ __global__ void append_clear_cache_int8_block(
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
@@ -44,7 +44,7 @@ __global__ void append_clear_cache_int8_block(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
if (seq_lens_encoder[bid] > 0) return;
@@ -101,7 +101,7 @@ __global__ void append_clear_cache_int4_block(
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
const int max_blocks_per_seq,
@@ -118,7 +118,7 @@ __global__ void append_clear_cache_int4_block(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
if (seq_lens_encoder[bid] > 0) return;
@@ -179,7 +179,7 @@ __global__ void append_speculate_cache_rope_kernel(
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
@@ -219,7 +219,7 @@ __global__ void append_speculate_cache_rope_kernel(
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
const int h_bias = bias % head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
@@ -235,7 +235,7 @@ __global__ void append_speculate_cache_rope_kernel(
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cum_offsets[ori_bi]);
cu_seqlens_q[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
@@ -312,7 +312,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
@@ -352,7 +352,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
const int start_token_idx = ori_bi * max_seq_len - cum_offsets[ori_bi];
const int start_token_idx = cu_seqlens_q[ori_bi];
const int write_seq_id =
seq_lens_decoder[ori_bi] + token_id - start_token_idx;
if (write_seq_id == 0) continue;
@@ -368,7 +368,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cum_offsets[ori_bi]);
cu_seqlens_q[ori_bi]);
}
const int block_offset = write_seq_id % block_size;
@@ -459,7 +459,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -487,7 +487,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -691,7 +691,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -719,7 +719,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
@@ -1069,7 +1069,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1100,7 +1100,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
@@ -1375,7 +1375,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ cum_offsets,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -1406,7 +1406,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int start_token_idx = bid * max_seq_len - cum_offsets[bid];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;

View File

@@ -23,7 +23,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -60,7 +60,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
@@ -83,7 +83,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
@@ -107,7 +107,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -137,7 +137,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
seq_lens,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens_encoder,
max_seq_len,
max_blocks_per_seq,
@@ -152,7 +152,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -176,7 +176,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -202,7 +202,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* cum_offsets,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
@@ -234,7 +234,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
seq_lens,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens_encoder,
max_seq_len,
max_blocks_per_seq,
@@ -249,7 +249,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -275,7 +275,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
qkv_out,
block_tables,
padding_offsets,
cum_offsets,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
@@ -302,7 +302,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -350,7 +350,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -377,7 +377,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -410,7 +410,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -443,7 +443,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
padding_offsets.data<int>(),
cum_offsets.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
@@ -489,7 +489,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -515,7 +515,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -540,7 +540,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,
@@ -567,7 +567,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -24,7 +24,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
const paddle::optional<paddle::Tensor>& qkv_out_scales,

View File

@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -38,7 +38,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -39,7 +39,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -86,7 +86,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -81,7 +81,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -83,7 +83,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -23,7 +23,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -22,7 +22,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids,

View File

@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
const paddle::Tensor &encoder_tile_ids_per_batch,
const paddle::Tensor &encoder_num_blocks,
@@ -94,7 +94,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cum_offsets,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &block_tables, const paddle::Tensor &kv_batch_ids,
const paddle::Tensor &kv_tile_ids, const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids,
@@ -331,7 +331,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len,
@@ -344,7 +344,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
const int max_seq_len);
@@ -370,7 +370,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,

View File

@@ -73,7 +73,6 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -181,7 +180,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -216,7 +214,6 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -50,7 +50,6 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -26,7 +26,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -99,7 +98,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
seq_lens_encoder,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
decoder_batch_ids,
decoder_tile_ids_per_batch,
@@ -128,7 +126,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
seq_lens_this_time, // q_seq_len is 1
seq_lens_decoder,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
max_input_length,
max_len_kv_data,
@@ -151,7 +149,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& cum_offsets,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -201,7 +198,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = key_cache.dims()[2];
meta_data.batch_size = cum_offsets.dims()[0];
meta_data.batch_size = seq_lens_this_time.dims()[0];
switch (query.dtype()) {
case paddle::DataType::BFLOAT16: {
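With cum_offsets gone from the inputs, the batch size has to come from another tensor with one entry per lane, and seq_lens_this_time is shaped [bsz] just as cum_offsets was. A hedged sketch of a guard that could sit next to that line (PD_CHECK is Paddle's custom-op assertion macro; the guard itself is not in this commit):

// Sketch only: seq_lens_this_time carries one entry per lane, so its
// leading dim is bsz, the same value cum_offsets.dims()[0] used to give.
PD_CHECK(seq_lens_this_time.dims()[0] > 0,
         "seq_lens_this_time is expected to be a non-empty [bsz] tensor");
meta_data.batch_size = seq_lens_this_time.dims()[0];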
@@ -215,7 +212,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
seq_lens_this_time,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -262,7 +258,6 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
seq_lens_this_time,
cu_seqlens_q,
padding_offsets,
cum_offsets,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -316,7 +311,6 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& padding_offsets_shape,
const std::vector<int64_t>& cum_offsets_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
@@ -371,7 +365,6 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& padding_offsets_dtype,
const paddle::DataType& cum_offsets_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -426,7 +419,6 @@ PD_BUILD_OP(multi_head_latent_attention)
"seq_lens_this_time",
"cu_seqlens_q",
"padding_offsets",
"cum_offsets",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",

View File

@@ -206,7 +206,7 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.padding_offset,
forward_meta.cum_offsets,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,

View File

@@ -33,7 +33,7 @@ def append_attention(
seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
padding_offsets: paddle.Tensor,
cum_offsets: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
block_tables: paddle.Tensor,
encoder_batch_ids: paddle.Tensor,
encoder_tile_ids_per_batch: paddle.Tensor,
@@ -87,7 +87,7 @@ def append_attention(
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
cum_offsets,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,