[Inference, rename] remove padding_offsets from atten use batch_id_per_token (#2880)

* remove padding_offsets from atten
This commit is contained in:
周周周
2025-07-17 18:41:31 +08:00
committed by GitHub
parent d49f8fb30a
commit ddb10ac509
50 changed files with 311 additions and 288 deletions

View File

@@ -85,8 +85,8 @@ class ForwardMeta():
# Accumulated offset
cum_offsets: Optional[paddle.Tensor] = None
# Offset tensor, used to restore the position of ids_remove_padding after padding removal to the original input_ids
padding_offset: Optional[paddle.Tensor] = None
# batch_id_per_token tensor, used to indicate which batch each token belongs to after padding removal from the original input_ids
batch_id_per_token: Optional[paddle.Tensor] = None
# Accumulated sequence length of query
cu_seqlens_q: Optional[paddle.Tensor] = None
# Accumulated sequence length of key

View File

@@ -216,7 +216,7 @@ class AppendAttentionBackend(AttentionBackend):
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.padding_offset,
forward_meta.batch_id_per_token,
forward_meta.cu_seqlens_q,
metadata.block_tables,
metadata.encoder_batch_ids,

View File

@@ -32,7 +32,7 @@ def append_attention(
seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
padding_offsets: paddle.Tensor,
batch_id_per_token: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
block_tables: paddle.Tensor,
encoder_batch_ids: paddle.Tensor,
@@ -86,7 +86,7 @@ def append_attention(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
batch_id_per_token,
cu_seqlens_q,
block_tables,
encoder_batch_ids,

View File

@@ -72,7 +72,7 @@ def pre_process(
Return:
ids_remove_padding:
cum_offsets:
padding_offset:
batch_id_per_token:
cu_seqlens_q:
cu_seqlens_k:
"""
@@ -85,7 +85,7 @@ def pre_process(
(
ids_remove_padding,
cum_offsets,
padding_offset,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = speculate_get_padding_offset(
@@ -115,12 +115,12 @@ def pre_process(
(
ids_remove_padding,
cum_offsets,
padding_offset,
batch_id_per_token,
cu_seqlens_q,
cu_seqlens_k,
) = get_padding_offset(input_ids, cum_offsets_now, token_num,
seq_lens_this_time)
return (ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q,
return (ids_remove_padding, cum_offsets, batch_id_per_token, cu_seqlens_q,
cu_seqlens_k, output_cum_offsets, output_padding_offset)