fix test_append_attention_with_output.py (#3831)

Co-authored-by: plusNew001 <95567040+plusNew001@users.noreply.github.com>
Author: lzy
Date: 2025-09-03 14:07:50 +08:00
Committed by: GitHub
Parent: 54b458fd98
Commit: 2527eb0e4e

@@ -407,10 +407,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
         self.token_num = self.padding_offset.shape[0]
         self.mask_offset = None
         if self.use_mask_offset:
-            self.mask_offset = paddle.full(self.seq_len * self.batch_size, 0, "int32")
+            self.mask_offset = paddle.full(self.batch_size * self.seq_len * 2, 0, "int32")
             for i in range(self.batch_size):
                 for j in range(self.seq_len):
-                    self.mask_offset[i * self.seq_len + j] = j
+                    self.mask_offset[i * self.seq_len * 2 + j * 2] = 0
+                    self.mask_offset[i * self.seq_len * 2 + j * 2 + 1] = j + 1

     def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
         paddle.disable_static()
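
The hunk above switches the prefill mask_offset from one entry per query token to an interleaved pair per token. A minimal NumPy sketch of the new layout follows; reading the pair as a [window_start, window_end) attention window is an assumption inferred from the diff, not something the test states, and build_prefill_mask_offset is a hypothetical helper name.

import numpy as np

def build_prefill_mask_offset(batch_size: int, seq_len: int) -> np.ndarray:
    # Assumed layout: two int32 entries per query token, [window_start, window_end).
    # Token j of each sequence attends to positions [0, j + 1), i.e. a causal window.
    mask_offset = np.zeros(batch_size * seq_len * 2, dtype=np.int32)
    for i in range(batch_size):
        for j in range(seq_len):
            mask_offset[i * seq_len * 2 + j * 2] = 0          # window start
            mask_offset[i * seq_len * 2 + j * 2 + 1] = j + 1  # window end (exclusive)
    return mask_offset

The flat 1-D buffer mirrors how the test allocates mask_offset with paddle.full; the same data can be viewed as per-token pairs via mask_offset.reshape(batch_size, seq_len, 2).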
@@ -603,9 +604,10 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
             self.cu_seqlens_k,
         ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
         if self.use_mask_offset:
-            self.mask_offset = paddle.full(self.batch_size, 0, "int32")
+            self.mask_offset = paddle.full(self.batch_size * 2, 0, "int32")
             for i in range(self.batch_size):
-                self.mask_offset[i] = self.seq_lens_dec[i]
+                self.mask_offset[i * 2] = 0
+                self.mask_offset[i * 2 + 1] = self.seq_lens_dec[i] + 1
             print("decoder mask_offset: ", self.mask_offset)
         self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
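
In the decode hunk each sequence contributes a single query token, so the buffer holds one pair per batch item. A matching sketch under the same assumed [window_start, window_end) reading (build_decode_mask_offset is a hypothetical helper; seq_lens_dec comes from the test's setup):

import numpy as np

def build_decode_mask_offset(seq_lens_dec) -> np.ndarray:
    # Assumed layout: one [window_start, window_end) pair per sequence, since
    # decode issues one query token that sees the cached prefix plus itself.
    batch_size = len(seq_lens_dec)
    mask_offset = np.zeros(batch_size * 2, dtype=np.int32)
    for i, dec_len in enumerate(seq_lens_dec):
        mask_offset[i * 2] = 0                # window start
        mask_offset[i * 2 + 1] = dec_len + 1  # window end (exclusive)
    return mask_offset

For example, build_decode_mask_offset([3, 5]) yields [0, 4, 0, 6]: the first sequence's new token attends to positions [0, 4), the second's to [0, 6).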