[Inference, rename] remove padding_offsets from atten; use batch_id_per_token (#2880)

* remove padding_offsets from atten
This commit is contained in:
周周周
2025-07-17 18:41:31 +08:00
committed by GitHub
parent d49f8fb30a
commit ddb10ac509
50 changed files with 311 additions and 288 deletions

View File

@@ -46,7 +46,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
@@ -165,7 +165,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_this_time,
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
lambda_batch_ids,
@@ -202,7 +202,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
kv_batch_ids,
@@ -274,7 +274,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
@@ -297,7 +297,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
@@ -322,7 +322,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
qkv, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
@@ -346,7 +346,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
qkv_out, // [token_num, num_heads, head_dim]
seq_lens_decoder,
seq_lens_encoder,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
rotary_embs,
@@ -403,7 +403,7 @@ std::vector<paddle::Tensor> AppendAttention(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& padding_offsets,
const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& cu_seqlens_q,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
@@ -473,7 +473,7 @@ std::vector<paddle::Tensor> AppendAttention(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,
@@ -550,7 +550,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
const std::vector<int64_t>& seq_lens_encoder_shape,
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& padding_offsets_shape,
const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
@@ -610,7 +610,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const paddle::DataType& seq_lens_encoder_dtype,
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& padding_offsets_dtype,
const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
@@ -688,7 +688,7 @@ PD_BUILD_STATIC_OP(append_attention)
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"padding_offsets",
"batch_id_per_token",
"cu_seqlens_q",
"block_tables",
"encoder_batch_ids",