FastDeploy/tests/operators/test_tree_mask.py
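"""Unit tests for FastDeploy's append_attention operator under speculative decoding.

Each test prefills a long context, then runs a short multi-token decode step and
compares the fused paged-KV kernel against a dense reference attention, using no
explicit mask, a causal boolean mask, and a tree-shaped mask over the draft tokens.
"""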
import math
import time
import unittest

import numpy as np
import paddle
import paddle.nn.functional as F

from fastdeploy.model_executor.layers.attention.ops import (
    append_attention,
    get_block_shape_and_split_kv_block,
)


class TestTreeMask(unittest.TestCase):
    def setUp(self):
        paddle.seed(0)
        self.max_seq_len = 32768
        self.encoder_max_partition_size = self.max_seq_len
        self.max_partition_size = self.max_seq_len
        self.max_dec_len = 1024
        self.bsz = 64
        self.run_time = 10
        self.warm_up = 2
        self.block_size = 64
        self.head_dim = 128
        self.num_q_head = 20
        self.num_kv_head = 4
        self.dtype = "bfloat16"
        self.rope_3d = False
        self.use_neox_rotary_style = False
        self.CURRENT_Q = [None]
        self.TOTAL_K = []
        self.TOTAL_V = []
        # Initialize cache and block tables
        block_num_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
        max_block_num = block_num_per_seq * self.bsz
        cache_shape = (
            max_block_num,
            self.num_kv_head,
            self.block_size,
            self.head_dim,
        )
        self.cache_k = paddle.zeros(shape=cache_shape).astype(self.dtype)
        self.cache_v = paddle.zeros(shape=cache_shape).astype(self.dtype)
        self.block_tables = paddle.zeros(shape=(self.bsz, block_num_per_seq), dtype="int32")
        free_list = list(range(max_block_num - 1, -1, -1))
        for i in range(self.bsz):
            need_block_num = (self.max_seq_len + self.block_size - 1) // self.block_size
            for j in range(need_block_num):
                block_id = free_list.pop()
                self.block_tables[i, j] = block_id

    def tearDown(self):
        self.CURRENT_Q = [None]
        self.TOTAL_K = []
        self.TOTAL_V = []

    def split_qkv(self, qkv, bsz, seq_len):
        qkv = qkv.reshape([bsz, seq_len, -1, self.head_dim])
        q = qkv[:, :, : self.num_q_head, :]
        self.CURRENT_Q[0] = q
        k = qkv[:, :, self.num_q_head : self.num_q_head + self.num_kv_head, :]
        self.TOTAL_K.append(k)
        v = qkv[:, :, self.num_q_head + self.num_kv_head :, :]
        self.TOTAL_V.append(v)

    def get_padding_offset(self, bsz, seq_lens_this_time, seq_lens_decoder):
        batch_id_per_token = []
        cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
        cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
        cum_seq_len_q = 0
        cum_seq_len_k = 0
        for i in range(bsz):
            seq_len_now = seq_lens_this_time[i]
            seq_len_dec_now = seq_lens_decoder[i]
            for _ in range(seq_len_now):
                batch_id_per_token.append(i)
            cum_seq_len_q += seq_len_now
            cum_seq_len_k += seq_len_now + seq_len_dec_now
            cu_seqlens_q[i + 1] = cum_seq_len_q
            cu_seqlens_k[i + 1] = cum_seq_len_k
        return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
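
    # Dense reference attention: concatenates every cached K/V chunk, computes scaled
    # dot-product attention in one matmul, and applies the mask additively
    # (0 keeps a position, -inf removes it before the softmax).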
    def ref_attention(self, q, k, v, mask):
        q = q.transpose([0, 2, 1, 3])
        if len(k) > 1:
            k = paddle.concat(k, axis=1)
        else:
            k = k[0]
        k = k.transpose([0, 2, 1, 3])
        if len(v) > 1:
            v = paddle.concat(v, axis=1)
        else:
            v = v[0]
        v = v.transpose([0, 2, 1, 3])
        total_len = k.shape[2]
        scores = (
            q.reshape([self.bsz, self.num_kv_head, -1, self.head_dim])
            @ k.transpose([0, 1, 3, 2])
            * (1.0 / math.sqrt(self.head_dim))
        )
        scores = scores.reshape([self.bsz, self.num_q_head, -1, total_len])
        if mask is not None:
            if mask.ndim == 2:
                mask = mask.unsqueeze(0).unsqueeze(0)
            elif mask.ndim == 3:
                mask = mask.unsqueeze(1)
            scores = paddle.add(scores, mask)
        weights = F.softmax(scores, axis=-1)
        o = weights.reshape([self.bsz, self.num_kv_head, -1, total_len]) @ v
        return (
            o.reshape([self.bsz, self.num_q_head, -1, self.head_dim])
            .transpose([0, 2, 1, 3])
            .reshape([-1, self.num_q_head, self.head_dim])
        )
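
    # Helper that drives the fused kernel: builds variable-length packed QKV for either a
    # prefill step (seq_lens_encoder = q_len) or a multi-token decode step
    # (seq_lens_decoder = kv_len), splits the paged-KV work with
    # get_block_shape_and_split_kv_block, then times append_attention over several runs.
    # The "c16" naming matches the 16-bit (bf16) KV cache used here with cache quant type "none".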
    def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None):
        if prefill:
            seq_lens_enc = [q_len] * self.bsz
        else:
            seq_lens_enc = [0] * self.bsz
        seq_lens_dec = [kv_len] * self.bsz
        seq_lens_cur = [q_len] * self.bsz
        token_num = sum(seq_lens_cur)
        decoder_step_token_num = 1 if prefill else q_len
        seq_lens_encoder = paddle.to_tensor(seq_lens_enc, "int32")
        seq_lens_this_time = paddle.to_tensor(seq_lens_cur, "int32")
        seq_lens_decoder = paddle.to_tensor(seq_lens_dec, "int32")
        batch_id_per_token, cu_seqlens_q, cu_seqlens_k = self.get_padding_offset(
            self.bsz, seq_lens_this_time, seq_lens_decoder
        )
        qkv_varlen_shape = [token_num, (self.num_q_head + 2 * self.num_kv_head) * self.head_dim]
        rotary_embs_shape = [
            2,
            1,
            self.max_seq_len,
            1,
            self.head_dim if self.use_neox_rotary_style else self.head_dim // 2,
        ]
        qkv = paddle.randn(shape=qkv_varlen_shape).astype(self.dtype)
        self.split_qkv(qkv, self.bsz, q_len)
        rotary_embs = paddle.randn(shape=rotary_embs_shape).astype("float32")
        # Neutral rotary embeddings make RoPE an identity, so the reference
        # (which applies no RoPE) still matches the kernel output.
        rotary_embs[0, :, :, :, :] = 1
        rotary_embs[1, :, :, :, :] = 0
        cache_k_scale = None
        cache_v_scale = None
        cache_k_out_scale = None
        cache_v_out_scale = None
        encoder_block_shape_q = 64
        decoder_block_shape_q = 16
        decode_max_tile_size = (
            self.bsz
            * (decoder_step_token_num * (self.num_q_head // self.num_kv_head) + decoder_block_shape_q - 1)
            / decoder_block_shape_q
        )
        decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
        decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
        decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
        max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
        paddle.device.synchronize()
        (
            encoder_batch_ids,
            encoder_tile_ids_per_batch,
            encoder_num_blocks,
            kv_batch_ids,
            kv_tile_ids_per_batch,
            kv_num_blocks,
            max_len_kv,
        ) = get_block_shape_and_split_kv_block(
            seq_lens_encoder,
            seq_lens_decoder,
            seq_lens_this_time,
            decoder_batch_ids,
            decoder_tile_ids_per_batch,
            decoder_num_blocks,
            max_len_tensor_cpu,
            encoder_block_shape_q,
            decoder_block_shape_q,
            self.num_q_head // self.num_kv_head,
            self.block_size,
            decoder_step_token_num,
        )
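
        # Warm-up iterations are excluded from timing; the remaining runs are averaged.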
        s_time = 0
        for i in range(self.run_time + self.warm_up):
            if i == self.warm_up:
                s_time = time.time()
            out = append_attention(
                qkv,
                self.cache_k,
                self.cache_v,
                seq_lens_encoder,
                seq_lens_decoder,
                seq_lens_this_time,
                batch_id_per_token,
                cu_seqlens_q,
                self.block_tables,
                encoder_batch_ids,
                encoder_tile_ids_per_batch,
                encoder_num_blocks,
                kv_batch_ids,
                kv_tile_ids_per_batch,
                kv_num_blocks,
                decoder_batch_ids,
                decoder_tile_ids_per_batch,
                decoder_num_blocks,
                max_len_tensor_cpu,
                max_len_kv,
                rotary_embs,
                attn_mask,
                None,
                None,
                cache_k_scale,
                cache_v_scale,
                cache_k_out_scale,
                cache_v_out_scale,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                1e-6,
                "bf16",
                "none",
                self.use_neox_rotary_style,
                self.rope_3d,
                self.max_seq_len,
                0.0,
                0.0,
                -1.0,
                encoder_block_shape_q,
                decoder_block_shape_q,
                self.max_partition_size,
                self.encoder_max_partition_size,
                decoder_step_token_num,
                True,
                decoder_step_token_num > 1,
            )
        paddle.device.synchronize()
        e_time = time.time()
        print(f"mean infer time: {(e_time - s_time) * 1000 / self.run_time:.2f} ms")
        return out[0].reshape([token_num, self.num_q_head, self.head_dim])
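
    # Mask conventions for the tests below: the dense reference takes an additive float
    # mask (-inf hides a position), while append_attention takes a boolean mask covering
    # only the draft-token block, with True marking positions to mask out.
    # test_naive_speculative_decoding passes no mask to the kernel and relies on its
    # built-in causal handling.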
    def test_naive_speculative_decoding(self):
        prefill_len = 8192
        dec_len_q = 5
        total_len = prefill_len + dec_len_q
        mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
        mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
        self.run_append_c16_attention(prefill_len, 0, True)
        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False)
        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask)
        np.testing.assert_allclose(
            ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
        )

    def test_mask(self):
        prefill_len = 8192
        dec_len_q = 5
        total_len = prefill_len + dec_len_q
        mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
        mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
        mask_append_attn = mask[:, :, prefill_len:]
        mask_append_attn = paddle.where(
            mask_append_attn == 1,
            paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
            paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
        )
        self.run_append_c16_attention(prefill_len, 0, True)
        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask_ref)
        np.testing.assert_allclose(
            ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
        )

    def test_tree_mask(self):
        prefill_len = 8192
        dec_len_q = 5
        total_len = prefill_len + dec_len_q
        mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
        # Carve a tree out of the causal mask: draft token 2 cannot see token 1,
        # token 3 cannot see token 2, and token 4 cannot see tokens 1 or 3.
        mask[:, 2, prefill_len + 1] = 0
        mask[:, 3, prefill_len + 2] = 0
        mask[:, 4, prefill_len + 1] = 0
        mask[:, 4, prefill_len + 3] = 0
        mask_ref = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
        mask_append_attn = mask[:, :, prefill_len:]
        mask_append_attn = paddle.where(
            mask_append_attn == 1,
            paddle.full_like(mask_append_attn, fill_value=False, dtype=bool),
            paddle.full_like(mask_append_attn, fill_value=True, dtype=bool),
        )
        self.run_append_c16_attention(prefill_len, 0, True)
        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, mask_append_attn)
        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask_ref)
        np.testing.assert_allclose(
            ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
        )


if __name__ == "__main__":
    unittest.main()