[Inference, rename] remove padding_offsets from attention; use batch_id_per_token (#2880)

* remove padding_offsets from attention
周周周
2025-07-17 18:41:31 +08:00
committed by GitHub
parent d49f8fb30a
commit ddb10ac509
50 changed files with 311 additions and 288 deletions
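
Before the per-file hunks, a note on what the rename actually changes. With padding_offsets, kernels recovered a token's batch index by lifting the compact token index into the padded [bsz, max_seq_len] layout and then dividing; batch_id_per_token is a direct per-token lookup, and the in-sequence position now comes from the cu_seqlens_q prefix sums. A minimal sketch of the two schemes (plain C++; the helper names are ours, only the array names come from the diff):

// Sketch only -- not the kernel code. Old scheme: tokens live conceptually in
// a padded [bsz, max_seq_len] buffer; padding_offsets[i] lifts compact index i
// to its padded position, and division recovers the batch id.
int batch_id_old(const int* padding_offsets, int token_idx, int max_seq_len) {
  const int ori_token_idx = token_idx + padding_offsets[token_idx];
  return ori_token_idx / max_seq_len;
}

// New scheme: one table lookup per token, no dependence on max_seq_len.
int batch_id_new(const int* batch_id_per_token, int token_idx) {
  return batch_id_per_token[token_idx];
}

// The local position inside the sequence moves from modular arithmetic to the
// query prefix sums, as the hunks below show:
//   old: local_seq_id = ori_token_idx % max_seq_len
//   new: local_seq_id = token_idx - cu_seqlens_q[bid]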

View File

@@ -46,7 +46,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
@@ -165,7 +165,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
lambda_batch_ids,
@@ -202,7 +202,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_batch_ids,
@@ -274,7 +274,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
@@ -297,7 +297,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
@@ -322,7 +322,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
@@ -346,7 +346,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
@@ -403,7 +403,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
@@ -473,7 +473,7 @@ std::vector<paddle::Tensor> AppendAttention(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
@@ -550,7 +550,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& padding_offsets_shape,
const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
@@ -610,7 +610,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& padding_offsets_dtype,
const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
@@ -688,7 +688,7 @@ PD_BUILD_STATIC_OP(append_attention)
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"padding_offsets",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables",
"encoder_batch_ids",

View File

@@ -773,7 +773,7 @@ void MultiQueryAppendAttention(
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
@@ -1007,7 +1007,8 @@ void MultiQueryAppendAttention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1240,7 +1241,8 @@ void MultiQueryAppendAttention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1287,7 +1289,7 @@ void CascadeAppendAttentionC16Kernel(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -1350,7 +1352,7 @@ void CascadeAppendAttentionC16Kernel(
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,

View File

@@ -960,7 +960,7 @@ void MultiQueryAppendC4Attention(
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
@@ -1219,7 +1219,8 @@ void MultiQueryAppendC4Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1477,7 +1478,8 @@ void MultiQueryAppendC4Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1524,7 +1526,7 @@ void CascadeAppendAttentionC4Kernel(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -1591,7 +1593,7 @@ void CascadeAppendAttentionC4Kernel(
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,

View File

@@ -897,7 +897,7 @@ void MultiQueryAppendC8Attention(
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
@@ -1179,7 +1179,8 @@ void MultiQueryAppendC8Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1450,7 +1451,8 @@ void MultiQueryAppendC8Attention(
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
@@ -1497,7 +1499,7 @@ void CascadeAppendAttentionC8Kernel(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -1562,7 +1564,7 @@ void CascadeAppendAttentionC8Kernel(
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,

View File

@@ -1852,7 +1852,7 @@ __global__ void merge_multi_chunks_kernel(
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_kv,
const int* __restrict__ padding_offsets,
const int* __restrict__ batch_id_per_token,
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
T* __restrict__ out,
@@ -1866,8 +1866,7 @@ __global__ void merge_multi_chunks_kernel(
const int head_dim) {
const int vid = threadIdx.x, hid = threadIdx.y;
const int qid = blockIdx.x;
const uint32_t ori_token_id = qid + padding_offsets[qid];
const uint32_t bid = ori_token_id / max_seq_len;
const uint32_t bid = batch_id_per_token[qid];
if (seq_lens_q[bid] <= 0 || seq_lens_kv[bid] <= 0) {
return;
}
@@ -2240,7 +2239,8 @@ __global__ void merge_multi_chunks_v2_kernel(
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ padding_offsets,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
OutT *__restrict__ out,
@@ -2259,9 +2259,8 @@ __global__ void merge_multi_chunks_v2_kernel(
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t ori_token_id = qid + padding_offsets[qid];
const uint32_t bid = ori_token_id / max_seq_len;
const uint32_t local_seq_id = ori_token_id % max_seq_len;
const uint32_t bid = batch_id_per_token[qid];
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) continue;
int seq_len_kv = seq_lens_kv[bid];
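
The diff does not show where batch_id_per_token is produced, but its contract is visible from the kernels above: entry i must hold the batch that owns compact token i, i.e. the bucket of i in cu_seqlens_q. A plausible host-side construction, purely for illustration (this helper is not part of the commit):

#include <vector>

// Hypothetical helper: expand the cumulative query lengths into a per-token
// batch-id table. cu_seqlens_q has bsz + 1 entries and cu_seqlens_q[0] == 0.
std::vector<int> build_batch_id_per_token(const std::vector<int>& cu_seqlens_q) {
  const int bsz = static_cast<int>(cu_seqlens_q.size()) - 1;
  std::vector<int> ids(cu_seqlens_q[bsz]);
  for (int b = 0; b < bsz; ++b)
    for (int t = cu_seqlens_q[b]; t < cu_seqlens_q[b + 1]; ++t)
      ids[t] = b;  // tokens in [cu_seqlens_q[b], cu_seqlens_q[b+1]) belong to batch b
  return ids;
}
// Example: cu_seqlens_q = {0, 3, 5}  ->  ids = {0, 0, 0, 1, 1}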

View File

@@ -40,7 +40,7 @@ void CascadeAppendAttentionC16Kernel(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -85,7 +85,7 @@ void CascadeAppendAttentionC8Kernel(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -130,7 +130,7 @@ void CascadeAppendAttentionC4Kernel(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -175,7 +175,7 @@ void CascadeAppendAttentionKernel(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -211,7 +211,7 @@ void CascadeAppendAttentionKernel(
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
@@ -246,7 +246,7 @@ void CascadeAppendAttentionKernel(
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
@@ -281,7 +281,7 @@ void CascadeAppendAttentionKernel(
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,
@@ -316,7 +316,7 @@ void CascadeAppendAttentionKernel(
seq_lens_q,
seq_lens_kv,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_table,
batch_ids,

View File

@@ -317,7 +317,7 @@ void MultiQueryDecoderAttention(
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q,
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const int max_seq_len,
@@ -483,7 +483,7 @@ void DecodeMLAAttentionKernel(
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
@@ -513,7 +513,7 @@ void DecodeMLAAttentionKernel(
{DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE,
{DISPATCH_DEAL_EACH_TIME(deal_each_time, DEAL_EACH_TIME,
{MultiQueryDecoderAttention<T, GROUP_SIZE, HEAD_DIM_QK, HEAD_DIM_V, BLOCK_SIZE, CAUSAL, 2, 16, DEAL_EACH_TIME>(
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, padding_offsets, cu_seqlens_q,
meta_data, stream, q, cache_k, cache_v, attn_mask, shift_bias, smooth_weight, seq_lens_q, seq_lens_kv, batch_id_per_token, cu_seqlens_q,
block_table, max_seq_len, max_dec_len, rope_scale, rope_theta, softmax_scale, in_scale, out);})})})})})});
}
@@ -527,7 +527,7 @@ template void DecodeMLAAttentionKernel<paddle::bfloat16>(
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,
@@ -548,7 +548,7 @@ template void DecodeMLAAttentionKernel<paddle::float16>(
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,

View File

@@ -28,7 +28,7 @@ __global__ void append_decode_cache_T_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -134,7 +134,7 @@ __global__ void append_decode_cache_T_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -254,7 +254,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -366,7 +366,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -498,7 +498,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -745,7 +745,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1047,7 +1047,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1346,7 +1346,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1739,7 +1739,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -2034,7 +2034,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -2362,7 +2362,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -2732,7 +2732,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]

View File

@@ -21,7 +21,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -57,7 +57,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -79,7 +79,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -102,7 +102,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -125,7 +125,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -149,7 +149,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -182,7 +182,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -207,7 +207,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -232,7 +232,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -257,7 +257,7 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -282,7 +282,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -317,7 +317,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -344,7 +344,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -371,7 +371,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -398,7 +398,7 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -424,7 +424,7 @@ void DecoderWriteCacheWithRoPEKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -471,7 +471,7 @@ void DecoderWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -503,7 +503,7 @@ void DecoderWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -536,7 +536,7 @@ void DecoderWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -570,7 +570,7 @@ void DecoderWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -603,7 +603,7 @@ void DecoderWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -650,7 +650,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -677,7 +677,7 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -703,7 +703,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -729,7 +729,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,

View File

@@ -23,7 +23,7 @@ void DecoderWriteCacheWithRoPEKernel(
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,

View File

@@ -23,7 +23,8 @@ __global__ void VariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales, // [3, num_head, dim_head]
@@ -52,8 +53,7 @@ __global__ void VariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
@@ -61,7 +61,7 @@ __global__ void VariableLengthRotaryKernel(
const int hi = qkv_bias / last_dim;
const int h_bias = qkv_bias % last_dim;
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int bias_idx = qkv_id * hidden_size + hi * last_dim + h_bias;
@@ -107,7 +107,8 @@ __global__ void VariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
@@ -130,8 +131,7 @@ __global__ void VariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
@@ -139,7 +139,7 @@ __global__ void VariableLengthRotaryKernel(
const int hi = qkv_bias / last_dim;
const int h_bias = qkv_bias % last_dim;
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t base_idx = token_idx * 3 * hidden_size +
@@ -167,7 +167,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales, // [3, num_head, dim_head]
@@ -199,8 +200,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
@@ -208,7 +208,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int hi = qkv_bias / half_lastdim;
const int h_bias = qkv_bias % half_lastdim;
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
const int bias_idx_left =
@@ -261,7 +261,8 @@ __global__ void NeoxVariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
@@ -285,8 +286,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int qkv_id = bias / hidden_size;
@@ -294,7 +294,7 @@ __global__ void NeoxVariableLengthRotaryKernel(
const int hi = qkv_bias / half_lastdim;
const int h_bias = qkv_bias % half_lastdim;
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
const int base_idx_left = token_idx * 3 * full_hidden_size +
@@ -327,7 +327,8 @@ __global__ void GQAVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales, // [3, q_num_head, dim_head]
@@ -357,14 +358,13 @@ __global__ void GQAVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t bias_idx = hi * last_dim + h_bias;
@@ -410,7 +410,8 @@ __global__ void GQAVariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
T *qkv_out,
@@ -434,14 +435,13 @@ __global__ void GQAVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t base_idx =
@@ -472,7 +472,8 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
const float *qkv_out_scales,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const T *qkv_biases,
@@ -504,15 +505,13 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const int *qkv,
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
int ori_seq_id;
ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t bias_idx = hi * last_dim + h_bias;
@@ -561,7 +560,8 @@ template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const T *qkv_biases,
@@ -590,15 +590,13 @@ __global__ void GQAVariableLengthRotaryQuantKVKernel(const T *qkv,
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
int ori_seq_id;
ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
const int64_t bias_idx = hi * last_dim + h_bias;
@@ -645,7 +643,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const int *qkv,
const float *cos_emb, // [1, 1, seq_len, dim_head / 2]
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales, // [3, q_num_head, dim_head]
@@ -676,14 +675,13 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / half_lastdim;
const int h_bias = bias % half_lastdim;
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
const int bias_idx_left = hi * last_dim + h_bias;
@@ -736,7 +734,8 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales,
@@ -761,14 +760,13 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / half_lastdim;
const int h_bias = bias % half_lastdim;
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int emb_idx = ori_seq_id * last_dim + h_bias;
const int base_idx_left =
@@ -805,7 +803,8 @@ __global__ void cache_kernel(
T *__restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size]
const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int *__restrict__ padding_offsets, // [num_tokens]
const int *__restrict__ batch_id_per_token, // [num_tokens]
const int *__restrict__ cu_seqlens_q, // [bsz]
const int *__restrict__ seq_lens, // [bsz]
const int *__restrict__ seq_lens_decoder, // [bsz]
const int max_seq_len,
@@ -831,11 +830,9 @@ __global__ void cache_kernel(
const uint32_t qkv_bias = bias % hidden_size;
const uint32_t hi = qkv_bias / head_size;
const uint32_t h_bias = qkv_bias % head_size;
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
const uint32_t ori_bi = ori_token_idx / max_seq_len;
const uint32_t ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id =
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int32_t *block_table_now = nullptr;
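
Why token_idx - cu_seqlens_q[ori_bi] is the right replacement for ori_token_idx % max_seq_len: in the padded layout, token t of batch b sat at padded index b * max_seq_len + t, so the modulo recovered t; in the compact layout, batch b's tokens start at cu_seqlens_q[b]. A toy check with made-up values:

#include <cassert>

int main() {
  // Toy setup: two sequences of lengths {3, 2}, padded rows of max_seq_len = 8.
  const int max_seq_len = 8;
  const int cu_seqlens_q[] = {0, 3, 5};
  // Compact token 4 is token 1 of batch 1; its padded index was 1 * 8 + 1 = 9,
  // so the old padding_offsets[4] would have been 9 - 4 = 5.
  const int token_idx = 4, ori_bi = 1, padding_offset = 5;
  const int ori_token_idx = token_idx + padding_offset;
  assert(ori_token_idx / max_seq_len == ori_bi);                            // old batch id
  assert(ori_token_idx % max_seq_len == token_idx - cu_seqlens_q[ori_bi]);  // same local position
  return 0;
}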
@@ -878,7 +875,7 @@ __global__ void append_write_cache_kv_c8_qkv(
const int *__restrict__ tile_ids,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ padding_offsets,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
@@ -1117,7 +1114,7 @@ __global__ void append_write_cache_kv_c4_qkv(
const int *__restrict__ tile_ids,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ padding_offsets,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
@@ -1405,7 +1402,8 @@ void rotary_qk_variable(
const float *qkv_out_scales, // [3, num_head, dim_head]
const T *qkv_bias,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
@@ -1437,7 +1435,8 @@ void rotary_qk_variable(
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1453,7 +1452,8 @@ void rotary_qk_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
@@ -1471,7 +1471,8 @@ void rotary_qk_variable(
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1487,7 +1488,8 @@ void rotary_qk_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
@@ -1506,7 +1508,8 @@ void gqa_rotary_qk_variable(
const float *qkv_out_scales, // [3, num_head, dim_head]
const T *qkv_bias,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
@@ -1541,7 +1544,8 @@ void gqa_rotary_qk_variable(
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1559,7 +1563,8 @@ void gqa_rotary_qk_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out,
@@ -1579,7 +1584,8 @@ void gqa_rotary_qk_variable(
reinterpret_cast<const int *>(qkv_input),
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1596,7 +1602,8 @@ void gqa_rotary_qk_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
@@ -1620,7 +1627,8 @@ void gqa_rotary_qk_quant_variable(
const T *cache_k_scales,
const T *cache_v_scales,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int token_num,
@@ -1652,7 +1660,8 @@ void gqa_rotary_qk_quant_variable(
cos_emb,
sin_emb,
qkv_out_scales,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_bias,
@@ -1671,7 +1680,8 @@ void gqa_rotary_qk_quant_variable(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_bias,
@@ -1697,7 +1707,8 @@ void CascadeAppendWriteCacheKVQKV(
&qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
// kv_num_heads, head_dim] if GQA)
const paddle::Tensor &block_table,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const int max_seq_len,
@@ -1723,7 +1734,8 @@ void CascadeAppendWriteCacheKVQKV(
reinterpret_cast<T *>(key_cache_out->data<T>()),
reinterpret_cast<T *>(value_cache_out->data<T>()),
block_table.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
max_seq_len,
@@ -1747,7 +1759,7 @@ void CascadeAppendWriteCacheKVC8QKV(
const paddle::Tensor &cache_v_scale, // [num_kv_heads, head_dim]
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
@@ -1812,7 +1824,7 @@ void CascadeAppendWriteCacheKVC8QKV(
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
@@ -1835,7 +1847,7 @@ void CascadeAppendWriteCacheKVC4QKV(
const paddle::Tensor &cache_v_zp, // [num_kv_heads, head_dim]
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
const paddle::Tensor &batch_ids,
@@ -1882,7 +1894,7 @@ void CascadeAppendWriteCacheKVC4QKV(
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,

View File

@@ -25,7 +25,7 @@ void EncoderWriteCacheWithRopeKernel(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
@@ -63,7 +63,8 @@ void EncoderWriteCacheWithRopeKernel(
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
@@ -82,7 +83,8 @@ void EncoderWriteCacheWithRopeKernel(
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
@@ -103,7 +105,8 @@ void EncoderWriteCacheWithRopeKernel(
cache_k_scale ? cache_k_scale.get().data<T>() : nullptr,
cache_v_scale ? cache_v_scale.get().data<T>() : nullptr,
rotary_embs.get().data<float>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
token_num,
@@ -123,7 +126,8 @@ void EncoderWriteCacheWithRopeKernel(
CascadeAppendWriteCacheKVQKV<T>(meta_data,
*qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
max_seq_len,
@@ -142,7 +146,7 @@ void EncoderWriteCacheWithRopeKernel(
cache_v_scale.get(),
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
batch_ids,
@@ -169,7 +173,7 @@ void EncoderWriteCacheWithRopeKernel(
cache_v_zp.get(),
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
batch_ids,

View File

@@ -25,7 +25,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *padding_offsets,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const int *cu_seqlens_k,
@@ -52,14 +53,13 @@ __global__ void GQAVariableLengthRotarySplitKernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_token_idx = token_idx + padding_offsets[token_idx];
const int ori_bi = ori_token_idx / seq_len;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
const int ori_seq_id = ori_token_idx % seq_len + seq_lens_decoder[ori_bi];
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
@@ -108,9 +108,10 @@ void gqa_rotary_qk_split_variable(
T *v,
const T *qkv_input,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const int *padding_offsets,
const int *batch_id_per_token,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int *cu_seqlens_q,
const int *cu_seqlens_k,
const int token_num,
const int num_heads,
@@ -133,7 +134,8 @@ void gqa_rotary_qk_split_variable(
qkv_input,
cos_emb,
sin_emb,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_k,
@@ -421,7 +423,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& kv_batch_ids,
const paddle::Tensor& kv_tile_ids,
@@ -492,9 +494,10 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
@@ -509,7 +512,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
meta_data,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
max_seq_len,
@@ -526,7 +530,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
cache_v_quant_scales.get(),
seq_lens_this_time,
seq_lens_decoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_batch_ids,
@@ -593,7 +597,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
"seq_lens_this_time",
"seq_lens_encoder",
"seq_lens_decoder",
"padding_offsets",
"batch_id_per_token",
"block_tables",
"kv_batch_ids",
"kv_tile_ids_per_batch",

View File

@@ -22,7 +22,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
@@ -53,7 +53,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCache(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_decoder.data<int>(),
@@ -73,7 +73,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
@@ -99,7 +99,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
kv_pe,
seq_lens,
seq_lens_decoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
@@ -112,7 +112,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
kv_pe,
seq_lens,
seq_lens_decoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
@@ -130,7 +130,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
const paddle::Tensor& kv_pe,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const int max_seq_len,
@@ -164,7 +164,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCache(
reinterpret_cast<DataType_*>(const_cast<data_t*>(kv_pe.data<data_t>())),
reinterpret_cast<DataType_*>(kv_cache->data<data_t>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -205,7 +205,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
@@ -232,7 +232,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
kv_pe,
seq_lens,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
@@ -246,7 +246,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
kv_pe,
seq_lens,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
max_seq_len,
@@ -265,7 +265,7 @@ PD_BUILD_OP(prefill_mla_write_cache)
"kv_cache",
"seq_lens",
"seq_lens_decoder",
"padding_offsets",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})
@@ -280,7 +280,7 @@ PD_BUILD_OP(decode_mla_write_cache)
"kv_cache",
"seq_lens",
"seq_lens_encoder",
"padding_offsets",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables"})
.Outputs({"kv_cache_out"})

View File

@@ -95,7 +95,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -121,7 +121,7 @@ __global__ void speculate_decode_absorb_cache_kernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / hidden_size;
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int start_token_idx = cu_seqlens_q[ori_bi];
@@ -178,7 +178,7 @@ __global__ void prefill_absorb_cache_kernel(
T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size,
// nope_size]
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_decoder, // [bsz]
@@ -204,11 +204,9 @@ __global__ void prefill_absorb_cache_kernel(
linear_index += step) {
const uint32_t token_idx = linear_index / hidden_size;
const uint32_t bias = linear_index % hidden_size;
const uint32_t ori_token_idx = token_idx + padding_offsets[token_idx];
const uint32_t ori_bi = ori_token_idx / max_seq_len;
const uint32_t ori_bi = batch_id_per_token[token_idx];
if (seq_lens[ori_bi] == 0) continue;
const uint32_t ori_seq_id =
ori_token_idx % max_seq_len + seq_lens_decoder[ori_bi];
const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int* block_table_now = nullptr;
block_table_now = block_tables + ori_bi * max_blocks_per_seq;

View File

@@ -26,7 +26,7 @@ void DecodeMLAAttentionKernel(
const paddle::optional<paddle::Tensor>& smooth_weight,
const paddle::Tensor &seq_lens_q, // q_seq_len is 1
const paddle::Tensor &seq_lens_kv,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_table,
int max_seq_len,

View File

@@ -26,7 +26,7 @@ __global__ void append_clear_cache_int8_block(
// block_size, head_size // 2]
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
@@ -41,8 +41,8 @@ __global__ void append_clear_cache_int8_block(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -100,7 +100,7 @@ __global__ void append_clear_cache_int4_block(
// block_size, head_size // 2]
const int* __restrict__ seq_lens,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_encoder, // [bsz]
const int max_seq_len,
@@ -115,8 +115,8 @@ __global__ void append_clear_cache_int4_block(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -178,7 +178,7 @@ __global__ void append_speculate_cache_rope_kernel(
// head_size // 2]
T* __restrict__ q_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -214,7 +214,7 @@ __global__ void append_speculate_cache_rope_kernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / hidden_size;
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens_decoder[ori_bi] == 0) continue;
const int bias = linear_index % hidden_size;
const int hi = bias / head_size; // q + k + v
@@ -311,7 +311,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens_decoder, // [bsz]
const float* __restrict__ cos_emb,
@@ -347,7 +347,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
linear_index < elem_cnt;
linear_index += step) {
const int token_id = linear_index / half_hidden_size;
const int ori_bi = (token_id + padding_offsets[token_id]) / max_seq_len;
const int ori_bi = batch_id_per_token[token_id];
if (seq_lens_decoder[ori_bi] == 0) continue;
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
@@ -458,7 +458,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -484,8 +484,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -690,7 +690,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -716,8 +716,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
const int wid = tid / 32;
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -1068,7 +1068,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1097,8 +1097,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -1374,7 +1374,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ padding_offsets, // [num_tokens]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
@@ -1403,8 +1403,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
const int lane_id = tid % 32;
const int token_id = blockIdx.x;
const int ori_token_id = token_id + padding_offsets[token_id];
const int bid = ori_token_id / max_seq_len;
const int bid = batch_id_per_token[token_id];
const int start_token_idx = cu_seqlens_q[bid];
const int head_idx = blockIdx.y * NUM_WARPS + wid;

View File

@@ -22,7 +22,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -59,7 +59,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
cos_emb,
@@ -82,7 +82,7 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
cos_emb,
@@ -106,7 +106,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -136,7 +136,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
seq_lens,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
max_seq_len,
@@ -151,7 +151,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -175,7 +175,7 @@ void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -201,7 +201,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* padding_offsets,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
@@ -233,7 +233,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
seq_lens,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
max_seq_len,
@@ -248,7 +248,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -274,7 +274,7 @@ void append_speculate_cache_int4_rope(const QKV_TYPE* qkv,
value_cache,
qkv_out,
block_tables,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
@@ -301,7 +301,7 @@ void SpeculateWriteCacheWithRoPEKernel(
const paddle::Tensor& qkv,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -349,7 +349,7 @@ void SpeculateWriteCacheWithRoPEKernel(
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -376,7 +376,7 @@ void SpeculateWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -409,7 +409,7 @@ void SpeculateWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -442,7 +442,7 @@ void SpeculateWriteCacheWithRoPEKernel(
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
padding_offsets.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
@@ -488,7 +488,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -514,7 +514,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -539,7 +539,7 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,
@@ -566,7 +566,7 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,

View File

@@ -23,7 +23,7 @@ void SpeculateWriteCacheWithRoPEKernel(
// gqa_group_size, head_dim] if GQA)
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::optional<paddle::Tensor>& rotary_embs,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::bfloat16
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, paddle::float8_e
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, paddle::float8_e4
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC16Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::bfloat16>
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::bfloat16, int8_t>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -37,7 +37,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC4Kernel<paddle::float16, int8_t>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -38,7 +38,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -85,7 +85,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -80,7 +80,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -82,7 +82,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -81,7 +81,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -36,7 +36,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,
@@ -81,7 +81,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const paddle::Tensor& seq_lens_q,
const paddle::Tensor& seq_lens_kv,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_table,
const paddle::Tensor& batch_ids,

View File

@@ -22,7 +22,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,

View File

@@ -21,7 +21,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,

View File

@@ -21,7 +21,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,

View File

@@ -21,7 +21,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,

View File

@@ -54,7 +54,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &padding_offsets, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &block_tables, const paddle::Tensor &encoder_batch_ids,
const paddle::Tensor &encoder_tile_ids_per_batch,
const paddle::Tensor &encoder_num_blocks,
@@ -94,7 +94,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &padding_offsets,
const paddle::Tensor &batch_id_per_token,
const paddle::Tensor &block_tables, const paddle::Tensor &kv_batch_ids,
const paddle::Tensor &kv_tile_ids, const paddle::Tensor &kv_num_blocks,
const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids,
@@ -330,7 +330,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
@@ -343,7 +343,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
const paddle::Tensor& kv_cache,
const paddle::Tensor& seq_lens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const std::string& cache_quant_type_str,
@@ -369,7 +369,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,

View File

@@ -46,7 +46,7 @@ __global__ void GetPaddingOffsetKernel(int *padding_offset,
const int ti = threadIdx.x;
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
padding_offset[bi * max_seq_len - cum_offset + i] = bi;
}
if (ti == 0) {
cum_offsets_out[bi] = cum_offset;
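
The hunk above is the producer side of the rename: the write index bi * max_seq_len - cum_offset + i is the padded index of batch bi's i-th token minus the padding accumulated before batch bi, i.e. the packed token index, so the one-line change stores the batch id bi in exactly the slots of batch bi's tokens instead of cum_offset. A toy CPU-side C++ sketch of what the buffer ends up holding (toy sizes and harness are illustrative assumptions, not from the patch):

#include <cstdio>
#include <vector>

int main() {
  const int max_seq_len = 8;              // toy value
  std::vector<int> seq_lens = {3, 1, 4};  // toy per-request token counts
  std::vector<int> out(3 + 1 + 4, -1);    // token_num packed slots, as in the kernel

  int cum_offset = 0;  // plays the role of cum_offsets[bi - 1] in the kernel
  for (int bi = 0; bi < (int)seq_lens.size(); ++bi) {
    for (int i = 0; i < seq_lens[bi]; ++i) {
      // padded index bi * max_seq_len + i, minus the padding before batch bi,
      // is the packed token index; the one-line change writes bi there.
      out[bi * max_seq_len - cum_offset + i] = bi;
    }
    cum_offset += max_seq_len - seq_lens[bi];
  }
  for (int v : out) std::printf("%d ", v);  // prints: 0 0 0 1 2 2 2 2
  std::puts("");
  return 0;
}

SpeculateGetPaddingOffsetKernel below makes the identical change for the speculative-decoding path.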

View File

@@ -53,7 +53,7 @@ __global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset,
const int ti = threadIdx.x;
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
padding_offset[bi * max_seq_len - cum_offset + i] = bi;
}
if (ti == 0) {
cum_offsets_out[bi] = cum_offset;

View File

@@ -85,8 +85,8 @@ class ForwardMeta():
# Accumulated offset
cum_offsets: Optional[paddle.Tensor] = None
# Offset tensor, used to restore the position of ids_remove_padding after padding removal to the original input_ids
padding_offset: Optional[paddle.Tensor] = None
# batch_id_per_token tensor, used to indicate which batch each token belongs to after padding is removed from the original input_ids
batch_id_per_token: Optional[paddle.Tensor] = None
# Accumulated sequence length of query
cu_seqlens_q: Optional[paddle.Tensor] = None
# Accumulated sequence length of key

View File

@@ -216,7 +216,7 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.padding_offset,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,

View File

@@ -32,7 +32,7 @@ def append_attention(
seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
padding_offsets: paddle.Tensor,
batch_id_per_token: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
block_tables: paddle.Tensor,
encoder_batch_ids: paddle.Tensor,
@@ -86,7 +86,7 @@ def append_attention(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,

View File

@@ -72,7 +72,7 @@ def pre_process(
Return:
ids_remove_padding:
cum_offsets:
padding_offset:
batch_id_per_token:
cu_seqlens_q:
cu_seqlens_k:
"""
@@ -85,7 +85,7 @@ def pre_process(
(
ids_remove_padding,
cum_offsets,
padding_offset,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
@@ -115,12 +115,12 @@ def pre_process(
(
ids_remove_padding,
cum_offsets,
padding_offset,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(input_ids, cum_offsets_now, token_num,
seq_lens_this_time)
return (ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q,
return (ids_remove_padding, cum_offsets, batch_id_per_token, cu_seqlens_q,
cu_seqlens_k, output_cum_offsets, output_padding_offset)

View File

@@ -272,8 +272,8 @@ class MTPProposer(Proposer):
self.main_model_inputs["ids_remove_padding"])
self.model_inputs["cum_offsets"] = paddle.clone(
self.main_model_inputs["cum_offsets"])
self.model_inputs["padding_offset"] = paddle.clone(
self.main_model_inputs["padding_offset"])
self.model_inputs["batch_id_per_token"] = paddle.clone(
self.main_model_inputs["batch_id_per_token"])
self.model_inputs["cu_seqlens_q"] = paddle.clone(
self.main_model_inputs["cu_seqlens_q"])
self.model_inputs["cu_seqlens_k"] = paddle.clone(
@@ -447,7 +447,7 @@ class MTPProposer(Proposer):
seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
cum_offsets=self.model_inputs["cum_offsets"],
padding_offset=self.model_inputs["padding_offset"],
batch_id_per_token=self.model_inputs["batch_id_per_token"],
cu_seqlens_q=self.model_inputs["cu_seqlens_q"],
cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
block_tables=self.model_inputs["block_tables"],
@@ -542,7 +542,7 @@ class MTPProposer(Proposer):
(
ids_remove_padding,
cum_offsets,
padding_offset,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
output_cum_offsets,
@@ -560,8 +560,8 @@ class MTPProposer(Proposer):
self.model_inputs["ids_remove_padding"].copy_(
ids_remove_padding, False)
self.model_inputs["cum_offsets"].copy_(cum_offsets, False)
self.model_inputs["padding_offset"].copy_(
padding_offset, False)
self.model_inputs["batch_id_per_token"].copy_(
batch_id_per_token, False)
self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
# for speculative decoding

View File

@@ -559,7 +559,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1],
0,
dtype='int32')
self.share_inputs["padding_offset"] = paddle.full([max_num_seqs, 1],
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1],
0,
dtype='int32')
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1],
@@ -670,7 +670,7 @@ class GPUModelRunner(ModelRunnerBase):
(
ids_remove_padding,
cum_offsets,
padding_offset,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
output_cum_offsets,
@@ -685,7 +685,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding,
False)
self.share_inputs["cum_offsets"].copy_(cum_offsets, False)
self.share_inputs["padding_offset"].copy_(padding_offset, False)
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -762,7 +762,7 @@ class GPUModelRunner(ModelRunnerBase):
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
cum_offsets=self.share_inputs["cum_offsets"],
padding_offset=self.share_inputs["padding_offset"],
batch_id_per_token=self.share_inputs["batch_id_per_token"],
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
block_tables=self.share_inputs["block_tables"],