fix mask_offset in append_attn (#3745)

* fix mask_offset in append_attn

* fix test
Author: lizhenyun01
Date: 2025-08-31 15:03:16 +08:00
Committed by: GitHub
Parent: 753772ace8
Commit: bed09ae8f8
8 changed files with 46 additions and 44 deletions


@@ -63,7 +63,6 @@ class AppendAttentionMetadata(AttentionMetadata):
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
mask_offset: Optional[paddle.Tensor] = None
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
@@ -142,7 +141,6 @@ class AppendAttentionBackend(AttentionBackend):
metadata.block_tables = forward_meta.block_tables
metadata.rotary_embs = forward_meta.rotary_embs
metadata.attn_mask = forward_meta.attn_mask
metadata.mask_offset = forward_meta.attn_mask_offsets
metadata.pre_caches_length = forward_meta.pre_caches_length
(
metadata.encoder_batch_ids,
@@ -303,7 +301,7 @@ class AppendAttentionBackend(AttentionBackend):
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.mask_offset,
forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
@@ -358,7 +356,7 @@ class AppendAttentionBackend(AttentionBackend):
getattr(layer, "cache_v_zp", None),
layer.linear_shift,
layer.linear_smooth,
metadata.mask_offset,
forward_meta.attn_mask_offsets,
metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),