# FastDeploy/test/layers/test_append_attention.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import unittest

import numpy as np
import paddle

paddle.seed(10)
class RopeEmbedding:
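    """Reference rotary position embedding (RoPE) helpers used by the naive attention path."""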
def __init__(self, use_neox_rotary_style=False):
self.use_neox_rotary_style = use_neox_rotary_style
self.base = 10000
def get_neox_style_position_embedding(self, position_ids, head_dim):
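        """Build the NeoX-style RoPE table.

        Returns a float32 tensor of shape [2, batch, max_seq_len, 1, head_dim],
        where index 0 holds cos and index 1 holds sin; each frequency is
        duplicated across the two halves of head_dim.
        """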
bsz, max_seq_len = position_ids.shape[:2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim),
dtype="float32")
inv_freq = self.base**(-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
inv_freq)
# shape: [B, S, 1, D]
emb = paddle.concat([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, 1, head_dim))
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
return rot_emb
def get_rotary_position_embedding(self, position_ids, head_dim):
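        """Build the GPT-style (interleaved) RoPE table.

        Returns a float32 tensor of shape [2, batch, max_seq_len, 1, head_dim // 2],
        where index 0 holds cos and index 1 holds sin.
        """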
bsz, max_seq_len = position_ids.shape[:2]
rot_emb = paddle.zeros(
(2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32"
)
inv_freq = self.base ** (
-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim
)
# shape: [B, S, D/2]
freqs = paddle.einsum(
"ij,k->ijk", position_ids.cast("float32"), inv_freq
)
# shape: [B, S, D/2]
emb = paddle.stack([freqs], axis=-1).reshape(
(bsz, max_seq_len, head_dim // 2)
)
# shape: [B, S, 1, D]
emb = paddle.unsqueeze(emb, 2)
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
return rot_emb
def _apply_rope(self, rotary_emb, q, k, v=None, causal=False):
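        """Reference RoPE application for the naive attention path.

        Splits `rotary_emb` into cos/sin, rotates q and k (half-split rotation
        for NeoX style, even/odd interleaved rotation otherwise) and returns
        the rotated query/key cast back to their input dtypes. `v` and
        `causal` are accepted but not used here.
        """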
# sin [sequence_length, embed_size_per_head//2]
# cos [sequence_length, embed_size_per_head//2]
# sin, cos = paddle.chunk(rp, 2, axis=-1)
seq, head_dim = q.shape[2], q.shape[3]
cos, sin = paddle.chunk(rotary_emb, 2, axis=0)
cos = paddle.squeeze(cos, axis=0).transpose(
[0, 2, 1, 3])[:, :, :seq, :]
sin = paddle.squeeze(sin, axis=0).transpose(
[0, 2, 1, 3])[:, :, :seq, :]
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
if self.use_neox_rotary_style:
sin_pos = sin
cos_pos = cos
            # NeoX style: rotate by splitting head_dim into first and second halves
rotate_half_q = paddle.reshape(
paddle.stack([-q[:, :, :, q.shape[-1]//2:], q[:, :, :, :q.shape[-1]//2]], axis=-1),
paddle.shape(q),
)
rotate_half_k = paddle.reshape(
paddle.stack([-k[:, :, :, k.shape[-1]//2:], k[:, :, :, :k.shape[-1]//2]], axis=-1),
paddle.shape(k),
)
else:
sin_pos = paddle.reshape(paddle.stack(
[sin, sin], axis=-1), [1, 1, seq, head_dim])
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
cos_pos = paddle.reshape(paddle.stack(
[cos, cos], axis=-1), [1, 1, seq, head_dim])
            # GPT style: rotate interleaved even/odd positions
rotate_half_q = paddle.reshape(
paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1),
paddle.shape(q),
)
rotate_half_k = paddle.reshape(
paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1),
paddle.shape(k),
)
query = paddle.add(
paddle.multiply(q, cos_pos), paddle.multiply(
rotate_half_q, sin_pos)
)
key = paddle.add(
paddle.multiply(k, cos_pos), paddle.multiply(
rotate_half_k, sin_pos)
)
return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype)
def create_attn_mask(
mask_type,
batch_size,
seq_lens,
pre_cache_length=0,
):
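    """Build a causal attention mask of shape [batch_size, 1, max_seq_len, max_seq_len].

    Positions on or below the diagonal get 0 and positions above it get -1e4,
    so masked logits are suppressed after softmax.
    """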
max_seq_len = max(seq_lens)
mask = paddle.zeros(
# [batch_size, 1, max_seq_len, max_seq_len + pre_cache_length],
[batch_size, 1, max_seq_len, max_seq_len],
dtype=mask_type,
)
mask[:, :, :, :pre_cache_length] = 1
for i in range(batch_size):
seq_len = seq_lens[i]
mask[i, 0, :seq_len, :seq_len] = (
paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type))
- 1
) * 1e4
return mask
def block_cache_to_naive_cache(
cache_k, cache_v, bsz, block_tables, cache_seq_len
):
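    """Gather the paged (block) KV cache into contiguous tensors.

    Uses `block_tables` to map each token position of each sequence to its
    physical block, returning dense caches of shape
    [bsz, num_head, cache_seq_len, dim_head].
    """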
_, num_head, blocksize, dim_head = cache_k.shape
out_cache_k = paddle.zeros(
shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype
)
out_cache_v = paddle.zeros(
shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype
)
for i in range(bsz):
for j in range(cache_seq_len):
out_cache_k[i, :, j, :] = cache_k[
block_tables[i, j // blocksize], :, j % blocksize, :
]
out_cache_v[i, :, j, :] = cache_v[
block_tables[i, j // blocksize], :, j % blocksize, :
]
return out_cache_k, out_cache_v
def naive_attention_impl(
query,
key,
value,
cache_k=None,
cache_v=None,
pre_cache_k=None,
pre_cache_v=None,
mask=None,
scale=1.0,
cache_k_dequant_scales=None,
cache_v_dequant_scales=None,
use_cachekv_int8="None",
):
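    """Naive reference attention used to validate the fused kernel.

    Replicates the KV heads to match the query heads (grouped-query attention),
    optionally prepends `cache_k`/`cache_v` along the sequence axis, then
    computes softmax(scale * Q K^T + mask) V. The pre-cache and int8 dequant
    arguments are accepted but not used in this test.
    """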
batch = query.shape[0]
heads = query.shape[1]
seq_len = query.shape[2]
head_dim = query.shape[3]
kv_head = key.shape[1]
key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
key = key.reshape([batch, heads, seq_len, head_dim])
if cache_k is not None:
cache_k = cache_k.reshape([batch, kv_head, 1, -1, head_dim])
cache_k = paddle.tile(cache_k, [1, 1, heads // kv_head, 1, 1])
cache_k = cache_k.reshape([batch, heads, -1, head_dim])
key = paddle.concat([cache_k, key], axis=2)
value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
value = value.reshape([batch, heads, seq_len, head_dim])
if cache_v is not None:
cache_v = cache_v.reshape([batch, kv_head, 1, -1, head_dim])
cache_v = paddle.tile(cache_v, [1, 1, heads // kv_head, 1, 1])
cache_v = cache_v.reshape([batch, heads, -1, head_dim])
value = paddle.concat([cache_v, value], axis=2)
qk_res = paddle.matmul(query, key, transpose_y=True)
attention = qk_res * scale
if mask is not None:
attention = attention + mask
softmax_result = paddle.nn.functional.softmax(attention, -1)
result = paddle.matmul(paddle.cast(
softmax_result, dtype=value.dtype), value)
return result
def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
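    """Compute padding bookkeeping for the packed token layout.

    Returns the per-token padding offsets, the per-batch cumulative offsets,
    and the `cu_seqlens_q` / `cu_seqlens_k` prefix sums over sequence lengths.
    """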
cum_offsets_now = paddle.cumsum(max_seq_len - seq_lens_this_time)
cum_offsets = paddle.zeros(shape=(bsz + 1), dtype="int32")
cum_offsets[1:] = cum_offsets_now
token_num = paddle.sum(seq_lens_this_time)
padding_offsets = paddle.zeros(shape=(token_num), dtype="int32")
cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
for i in range(bsz):
seq_len_now = seq_lens_this_time[i]
cum_offset = cum_offsets[i]
for j in range(seq_len_now):
padding_offsets[i * max_seq_len - cum_offset + j] = cum_offset
cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i + 1]
cu_seqlens_q[i + 1] = cum_seq_len
cu_seqlens_k[i + 1] = cum_seq_len
return padding_offsets, cum_offsets[:-1], cu_seqlens_q, cu_seqlens_k
def remove_padding(seq_lens, cu_seq_lens, inputs, token_num):
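    """Pack the valid tokens of each sequence into a [token_num, num_head * dim_head] tensor."""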
bsz, num_head, seq_len, dim_head = inputs.shape
output = paddle.zeros(
shape=[token_num, num_head * dim_head], dtype=inputs.dtype
)
inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1])
for i in range(bsz):
seq_len_now = seq_lens[i]
start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1]
output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :]
return output
def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head, place, dtype):
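    """Create random Q/K/V tensors plus their packed QKV concatenation.

    Returns q, k, v with shape [bs, num_head, seq_len, dim_head] and a
    token-major qkv tensor of shape
    [bs * seq_len, (q_num_head + 2 * kv_num_head) * dim_head].
    """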
query = np.random.random([bs, q_num_head, seq_len, dim_head])/10
q = paddle.to_tensor(
query, place=place, dtype=dtype, stop_gradient=False
)
key = np.random.random([bs, kv_num_head, seq_len, dim_head])/10
k = paddle.to_tensor(
key, place=place, dtype=dtype, stop_gradient=False
)
value = np.random.random([bs, kv_num_head, seq_len, dim_head])/10
v = paddle.to_tensor(
value, place=place, dtype=dtype, stop_gradient=False
)
token_num = bs*seq_len
qkv = paddle.concat(
[
q.transpose([0, 2, 1, 3]).reshape(
[token_num, q_num_head*dim_head]
),
k.transpose([0, 2, 1, 3]).reshape(
[token_num, kv_num_head*dim_head]
),
v.transpose([0, 2, 1, 3]).reshape(
[token_num, kv_num_head*dim_head]
),
],
axis=1,
).reshape([token_num, -1])
return q, k, v, qkv
def split_query_by_phase(query, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, q_dim, k_dim, v_dim):
"""
将 query 拆分为 encoder 和 decoder 的 Q/K/V。
"""
batch = seq_lens_encoder.shape[0]
max_seq = query.shape[0] // batch
    # Reshape query back to [batch, max_seq, q_dim + k_dim + v_dim]
total_dim = q_dim + k_dim + v_dim
query = paddle.reshape(query, [batch, max_seq, total_dim])
    # Boolean masks marking whether each batch entry is in the encoder or decoder phase
is_encoder = (seq_lens_encoder > 0).astype('bool').reshape([-1]) # [batch]
is_decoder = (seq_lens_decoder > 0).astype('bool').reshape([-1]) # [batch]
    # Per-phase output lists
enc_qs, enc_ks, enc_vs = [], [], []
dec_qs, dec_ks, dec_vs = [], [], []
for i in range(batch):
        real_len = int(seq_lens_this_time[i])  # valid length of this batch entry
cur_query = query[i, :real_len, :] # [seq_i, q+k+v]
q, k, v = paddle.split(cur_query, [q_dim, k_dim, v_dim], axis=-1)
if is_encoder[i]:
enc_qs.append(q)
enc_ks.append(k)
enc_vs.append(v)
elif is_decoder[i]:
dec_qs.append(q)
dec_ks.append(k)
dec_vs.append(v)
if enc_qs:
enc_q = paddle.concat(enc_qs, axis=0)
enc_k = paddle.concat(enc_ks, axis=0)
enc_v = paddle.concat(enc_vs, axis=0)
else:
enc_q = enc_k = enc_v = paddle.zeros([0, q_dim], dtype=query.dtype)
if dec_qs:
dec_q = paddle.concat(dec_qs, axis=0)
dec_k = paddle.concat(dec_ks, axis=0)
dec_v = paddle.concat(dec_vs, axis=0)
else:
dec_q = dec_k = dec_v = paddle.zeros([0, q_dim], dtype=query.dtype)
return (enc_q, enc_k, enc_v), (dec_q, dec_k, dec_v)
class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
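    """Compare the fused append_attention kernel against a naive GQA + RoPE reference."""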
def setUp(self):
paddle.disable_static()
self.name = "TestAppendGroupQueryAttnWithRope"
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.q_num_head = 12
self.kv_num_head = 2
self.seq_len = 64
self.max_dec_len = 64
self.dim_head = 128
self.q_hid_dim = self.q_num_head * self.dim_head
self.kv_hid_dim = self.kv_num_head * self.dim_head
self.blocksize = 64
self.use_neox_rotary_style = False
self.max_seq_len = self.seq_len + self.max_dec_len
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = 'float16'
self.init_tensor()
def init_tensor(self):
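        """Allocate the paged KV cache, block tables, and sequence-length tensors
        for the encoder (prefill) phase, and precompute the padding offsets."""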
self.block_num_per_seq = (
self.seq_len + self.max_dec_len + self.blocksize - 1
) // self.blocksize
self.rope = RopeEmbedding(self.use_neox_rotary_style)
self.max_block_num = self.block_num_per_seq * self.batch_size
self.free_list = list(range(self.max_block_num - 1, -1, -1))
self.seq_lens_enc = [
self.seq_len,
] * self.batch_size
self.seq_lens_dec = [
0,
] * self.batch_size
self.max_enc_len_this_time = max(self.seq_lens_enc)
self.max_dec_len_this_time = max(self.seq_lens_dec)
self.seq_lens_encoder = paddle.to_tensor(
self.seq_lens_enc,
"int32",
)
self.seq_lens_decoder = paddle.to_tensor(
self.seq_lens_dec,
"int32",
)
self.max_enc_len_this_time = paddle.to_tensor(
[self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor(
[self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_lens_this_time = self.seq_lens_encoder
self.cache_shape = (
self.max_block_num,
self.kv_num_head,
self.blocksize,
self.dim_head,
)
self.scale = 1.0 / np.sqrt(self.dim_head)
self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.block_tables = paddle.zeros(
shape=(self.batch_size, self.block_num_per_seq), dtype="int32"
)
for i in range(self.batch_size):
need_block_num = (
self.seq_len + self.max_dec_len + self.blocksize - 1
) // self.blocksize
for j in range(need_block_num):
self.block_tables[i, j] = self.free_list.pop()
(
self.padding_offset,
self.cum_offset,
self.cu_seqlens_q,
self.cu_seqlens_k,
) = get_padding_offset(
self.batch_size, self.seq_len, self.seq_lens_this_time
)
self.token_num = self.padding_offset.shape[0]
def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
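        """Run the fused append_attention kernel and compare it with the naive reference.

        Applies RoPE to the reference Q/K, computes the naive attention output,
        then times and runs append_attention on the packed QKV, asserting that
        both outputs match within rtol=atol=1e-2.
        """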
paddle.disable_static()
self.token_num = self.seq_len*self.batch_size
q, k, v, qkv = get_qkv_and_qkv_concat_tensor(
self.batch_size,
self.q_num_head,
self.kv_num_head,
self.seq_len,
self.dim_head,
self.place,
self.dtype
)
q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
out_ = naive_attention_impl(
q, k, v, naive_cache_k, naive_cache_v, None, None, attn_mask, self.scale
)
out_ = remove_padding(
self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num
)
speculate_max_draft_token_num = 1
from fastdeploy.model_executor.layers.attention.ops import append_attention
from fastdeploy.model_executor.layers.attention.ops import get_block_shape_and_split_kv_block
(
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
max_len_kv,
set_max_lengths,
) = get_block_shape_and_split_kv_block(
self.seq_lens_encoder,
self.seq_lens_decoder,
self.seq_lens_this_time,
self.cum_offset,
64,
12,
(self.q_num_head + 2*self.kv_num_head) // self.kv_num_head,
self.blocksize,
speculate_max_draft_token_num+1,
)
        # Run WARM_UP untimed iterations, then time the remaining RUN_TIME iterations
WARM_UP = 1
RUN_TIME = 2
for i in range(WARM_UP+RUN_TIME):
if i == WARM_UP:
paddle.device.synchronize()
start_time = time.time()
out = append_attention(
qkv,
self.cache_k,
self.cache_v,
self.seq_lens_encoder,
self.seq_lens_decoder,
self.seq_lens_this_time,
self.padding_offset,
self.cum_offset,
self.block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
set_max_lengths,
max_len_kv,
self.rope_emb, # rope_emb
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
None, # cache_k_quant_scales
None, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # linear_shift
None, # linear_smooth
None, # kv_signal_data
"fp16",
"none", # cache_quant_type
self.use_neox_rotary_style,
False,
self.max_seq_len,
0.0, # quant_min_bound
0.0, # quant_max_bound
-1, # out_linear_in_scale
64, # encoder_block_shape_q
16, # decoder_block_shape_q
32768, # max_partition_size
32768, # encoder_max_partition_size
speculate_max_draft_token_num+1, # speculate_max_draft_token_num
True, # causal
False, # speculate_decoder
)[0]
paddle.device.synchronize()
end_time = time.time()
print(
"[append-attn ut] cost_time:{}ms".format(
(end_time - start_time) / RUN_TIME * 1000
)
)
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
self.cache_k,
self.cache_v,
self.batch_size,
self.block_tables,
self.seq_len,
)
np.testing.assert_allclose(
out.numpy(),
out_.numpy(),
rtol=1e-02,
atol=1e-02,
)
def test_all(self):
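        """Run the encoder (prefill) pass with a causal mask, then switch the
        sequence-length bookkeeping to the decoder phase (one new token per
        sequence) and compare against the naive implementation again."""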
tmp_position_ids = paddle.arange(
self.seq_len + self.max_dec_len
).reshape((1, -1))
        # append_attention is given the full max sequence length (prefill + decode)
if self.use_neox_rotary_style:
self.rope_emb = self.rope.get_neox_style_position_embedding(tmp_position_ids, self.dim_head)
else:
self.rope_emb = self.rope.get_rotary_position_embedding(
tmp_position_ids, self.dim_head
)
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
)
# encoder
self.seq_lens_this_time = self.seq_lens_encoder
self.cmp_append_attention(attn_mask=self.attention_mask)
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
self.cache_k,
self.cache_v,
self.batch_size,
self.block_tables,
self.seq_len,
)
# decoder
self.seq_lens_decoder[:] = self.seq_lens_encoder
self.seq_lens_encoder[:] = 0
self.seq_lens_this_time[:] = 1
self.seq_lens_enc = [
0,
] * self.batch_size
self.seq_lens_dec = [
self.seq_len,
] * self.batch_size
self.max_enc_len_this_time = max(self.seq_lens_enc)
self.max_dec_len_this_time = max(self.seq_lens_dec)
self.max_enc_len_this_time = paddle.to_tensor(
[self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor(
[self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_len = 1
(
self.padding_offset,
self.cum_offset,
self.cu_seqlens_q,
self.cu_seqlens_k,
) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
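    """Same comparison as above, but with NeoX-style (half-split) rotary embeddings."""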
def setUp(self):
paddle.disable_static()
self.name = "TestAppendGroupQueryAttnWithRope"
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.q_num_head = 12
self.kv_num_head = 2
self.seq_len = 64
self.max_dec_len = 64
self.dim_head = 128
self.q_hid_dim = self.q_num_head * self.dim_head
self.kv_hid_dim = self.kv_num_head * self.dim_head
self.blocksize = 64
self.use_neox_rotary_style = True
self.max_seq_len = self.seq_len + self.max_dec_len
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = 'float16'
self.init_tensor()
if __name__ == '__main__':
unittest.main()