make append_attn support mask_offset (#3138)

* make append_attn support mask_offset

* add unittest
Author: lzy
Date: 2025-08-14 18:40:55 +08:00
Committed by: GitHub
Parent: 6031f9a5f5
Commit: 1e06b9fa6d
10 changed files with 88 additions and 20 deletions
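The new argument is exercised in the test below in two layouts: a per-token offset during the prefill (encoder) pass and a per-request decoded length during the decode pass. The following standalone sketch mirrors those two constructions; it is not part of the commit, and batch_size, seq_len, and seq_lens_decoder are placeholder values rather than values taken from the test.

```python
import paddle

# Sketch of the two mask_offset layouts the test builds (dygraph mode assumed).
batch_size, seq_len = 2, 8
seq_lens_decoder = paddle.full([batch_size], seq_len, dtype="int32")  # assumed decode lengths

# Prefill: one entry per token; token j of every sequence gets offset j.
prefill_mask_offset = paddle.full([seq_len * batch_size], 0, "int32")
for i in range(batch_size):
    for j in range(seq_len):
        prefill_mask_offset[i * seq_len + j] = j

# Decode: one entry per request, equal to its already-decoded length.
decode_mask_offset = paddle.full([batch_size], 0, "int32")
for i in range(batch_size):
    decode_mask_offset[i] = seq_lens_decoder[i]
```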


@@ -349,6 +349,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         self.rope_theta = 10000
         self.dtype = "float16"
         self.use_qk_norm = True
+        self.use_mask_offset = False
         self.init_tensor()

     def init_tensor(self):
@@ -404,6 +405,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.cu_seqlens_k,
) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
self.token_num = self.padding_offset.shape[0]
self.mask_offset = None
if self.use_mask_offset:
self.mask_offset = paddle.full(self.seq_len * self.batch_size, 0, "int32")
for i in range(self.batch_size):
for j in range(self.seq_len):
self.mask_offset[i * self.seq_len + j] = j
def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
paddle.disable_static()
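The nested loop added above fills mask_offset with 0, 1, ..., seq_len - 1 once per sequence. For reference, a loop-free sketch that produces the same values (dygraph mode assumed; seq_len and batch_size stand in for self.seq_len and self.batch_size):

```python
import paddle

seq_len, batch_size = 8, 2
# Same contents as the nested loop: position i * seq_len + j holds j.
mask_offset = paddle.tile(
    paddle.arange(seq_len, dtype="int32"),  # [0, 1, ..., seq_len - 1]
    repeat_times=[batch_size],              # repeated once per sequence in the batch
)
```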
@@ -505,6 +512,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             None,  # cache_v_zp
             None,  # linear_shift
             None,  # linear_smooth
+            self.mask_offset,  # mask_offset
             None,  # kv_signal_data
             q_norm_weight,  # q_norm_weight
             k_norm_weight,  # k_norm_weight
@@ -560,6 +568,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         # encoder
         # self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len)
         self.seq_lens_this_time = self.seq_lens_encoder
+        if self.use_mask_offset:
+            print("encoder mask_offset: ", self.mask_offset)
         self.cmp_append_attention(attn_mask=self.attention_mask)
         naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
             self.cache_k,
@@ -590,6 +600,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             self.cu_seqlens_q,
             self.cu_seqlens_k,
         ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
+        if self.use_mask_offset:
+            self.mask_offset = paddle.full(self.batch_size, 0, "int32")
+            for i in range(self.batch_size):
+                self.mask_offset[i] = self.seq_lens_dec[i]
+            print("decoder mask_offset: ", self.mask_offset)
         self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
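The decode-side construction above simply copies each request's decoded length into mask_offset. A loop-free sketch of the same idea (seq_lens_dec is assumed here to be a plain list or 1-D tensor of length batch_size):

```python
import paddle

seq_lens_dec = [3, 5]  # assumed per-request decoded lengths
# One mask_offset entry per request, equal to its decoded length so far.
mask_offset = paddle.to_tensor(seq_lens_dec, dtype="int32")
```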
@@ -614,6 +629,7 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
         self.rope_theta = 10000
         self.dtype = "float16"
         self.use_qk_norm = False
+        self.use_mask_offset = True
         self.init_tensor()
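The NeoX-rope variant above flips use_mask_offset to True so the mask_offset path is covered by the existing checks. A hypothetical further variant could opt in the same way; the super().setUp() pattern below is an assumption about the base class, not code from the commit:

```python
class TestAppendGroupQueryAttnWithRopeMaskOffset(TestAppendGroupQueryAttnWithRope):
    # Hypothetical subclass: reuse the base configuration but enable the
    # per-token mask_offset path before the input tensors are built.
    def setUp(self):
        super().setUp()               # assumed to populate the base config
        self.use_mask_offset = True   # switch on the new code path
        self.init_tensor()            # rebuild inputs, now including mask_offset
```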