# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import time
import unittest

import numpy as np
import paddle
from paddle.incubate.nn.functional import fused_rms_norm

paddle.seed(10)
np.random.seed(10)


class RopeEmbedding:
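    """Reference rotary position embedding (RoPE) helper: builds cos/sin tables and applies them to Q/K."""
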
    def __init__(self, use_neox_rotary_style=False):
        self.use_neox_rotary_style = use_neox_rotary_style
        self.base = 10000

    def get_neox_style_position_embedding(self, position_ids, head_dim):
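        """Build a [2, B, S, 1, D] cos/sin table, repeating the frequencies across both halves of the head dim."""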
        bsz, max_seq_len = position_ids.shape[:2]
        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32")
        inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)

        # shape: [B, S, D/2]
        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
        # shape: [B, S, 1, D]
        emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, head_dim))

        rot_emb[0] = paddle.cos(emb)
        rot_emb[1] = paddle.sin(emb)
        return rot_emb

    def get_rotary_position_embedding(self, position_ids, head_dim):
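        """Build a [2, B, S, 1, D/2] cos/sin table with one entry per rotary frequency."""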
        bsz, max_seq_len = position_ids.shape[:2]
        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32")
        inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)

        # shape: [B, S, D/2]
        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
        # shape: [B, S, D/2]
        emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2))
        # shape: [B, S, 1, D/2]
        emb = paddle.unsqueeze(emb, 2)

        rot_emb[0] = paddle.cos(emb)
        rot_emb[1] = paddle.sin(emb)
        return rot_emb

    def _apply_rope(self, rotary_emb, q, k, v=None, causal=False):
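        """Apply the rope table to q and k using either the NeoX (half-rotation) or GPT (interleaved) scheme."""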
        # sin [sequence_length, embed_size_per_head//2]
        # cos [sequence_length, embed_size_per_head//2]
        # sin, cos = paddle.chunk(rp, 2, axis=-1)
        seq, head_dim = q.shape[2], q.shape[3]
        cos, sin = paddle.chunk(rotary_emb, 2, axis=0)
        cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
        sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
        # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]

        if self.use_neox_rotary_style:
            sin_pos = sin
            cos_pos = cos
            # NeoX style: rotate the first and second halves of the head dim as blocks
            rotate_half_q = paddle.reshape(
                paddle.stack(
                    [
                        -q[:, :, :, q.shape[-1] // 2 :],
                        q[:, :, :, : q.shape[-1] // 2],
                    ],
                    axis=-1,
                ),
                paddle.shape(q),
            )
            rotate_half_k = paddle.reshape(
                paddle.stack(
                    [
                        -k[:, :, :, k.shape[-1] // 2 :],
                        k[:, :, :, : k.shape[-1] // 2],
                    ],
                    axis=-1,
                ),
                paddle.shape(k),
            )
        else:
            sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim])
            # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
            cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim])
            # GPT style: rotate interleaved even/odd pairs of the head dim
            rotate_half_q = paddle.reshape(
                paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1),
                paddle.shape(q),
            )
            rotate_half_k = paddle.reshape(
                paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1),
                paddle.shape(k),
            )

        query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos))

        key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos))

        return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype)


def create_attn_mask(
    mask_type,
    batch_size,
    seq_lens,
    pre_cache_length=0,
):
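    """Build a causal attention mask of shape [B, 1, max_seq_len, max_seq_len] with -1e4 in masked positions."""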
    max_seq_len = max(seq_lens)
    mask = paddle.zeros(
        # [batch_size, 1, max_seq_len, max_seq_len + pre_cache_length],
        [batch_size, 1, max_seq_len, max_seq_len],
        dtype=mask_type,
    )
    mask[:, :, :, :pre_cache_length] = 1
    for i in range(batch_size):
        seq_len = seq_lens[i]
        mask[i, 0, :seq_len, :seq_len] = (
            paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type)) - 1
        ) * 1e4
    return mask


def block_cache_to_naive_cache(cache_k, cache_v, bsz, block_tables, cache_seq_len):
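    """Gather the paged (block) KV cache into contiguous [bsz, num_head, cache_seq_len, dim_head] tensors."""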
    _, num_head, blocksize, dim_head = cache_k.shape
    out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
    out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
    for i in range(bsz):
        for j in range(cache_seq_len):
            out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
            out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
    return out_cache_k, out_cache_v


def naive_attention_impl(
    query,
    key,
    value,
    cache_k=None,
    cache_v=None,
    pre_cache_k=None,
    pre_cache_v=None,
    mask=None,
    scale=1.0,
    cache_k_dequant_scales=None,
    cache_v_dequant_scales=None,
    use_cachekv_int8="None",
    q_norm_weight=None,
    k_norm_weight=None,
):
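    """Reference GQA attention: broadcast KV heads, prepend the cache, then softmax(QK^T * scale + mask) @ V."""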
    batch = query.shape[0]
    heads = query.shape[1]
    seq_len = query.shape[2]
    head_dim = query.shape[3]
    kv_head = key.shape[1]

    key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
    key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
    key = key.reshape([batch, heads, seq_len, head_dim])

    if cache_k is not None:
        cache_k = cache_k.reshape([batch, kv_head, 1, -1, head_dim])
        cache_k = paddle.tile(cache_k, [1, 1, heads // kv_head, 1, 1])
        cache_k = cache_k.reshape([batch, heads, -1, head_dim])
        key = paddle.concat([cache_k, key], axis=2)

    value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
    value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
    value = value.reshape([batch, heads, seq_len, head_dim])

    if cache_v is not None:
        cache_v = cache_v.reshape([batch, kv_head, 1, -1, head_dim])
        cache_v = paddle.tile(cache_v, [1, 1, heads // kv_head, 1, 1])
        cache_v = cache_v.reshape([batch, heads, -1, head_dim])
        value = paddle.concat([cache_v, value], axis=2)

    qk_res = paddle.matmul(query, key, transpose_y=True)
    attention = qk_res * scale
    if mask is not None:
        attention = attention + mask
    softmax_result = paddle.nn.functional.softmax(attention, -1)
    result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value)
    return result


def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
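    """Compute per-token padding offsets and cumulative sequence lengths for the packed (unpadded) token layout."""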
    cum_offsets_now = paddle.cumsum(max_seq_len - seq_lens_this_time, dtype="int32")
    cum_offsets = paddle.zeros(shape=(bsz + 1), dtype="int32")
    cum_offsets[1:] = cum_offsets_now
    token_num = paddle.sum(seq_lens_this_time)
    padding_offsets = paddle.zeros(shape=(token_num), dtype="int32")
    cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
    cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
    for i in range(bsz):
        seq_len_now = seq_lens_this_time[i]
        cum_offset = cum_offsets[i]
        for j in range(seq_len_now):
            padding_offsets[i * max_seq_len - cum_offset + j] = cum_offset
        cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i + 1]
        cu_seqlens_q[i + 1] = cum_seq_len
        cu_seqlens_k[i + 1] = cum_seq_len
    return padding_offsets, cum_offsets[:-1], cu_seqlens_q, cu_seqlens_k


def remove_padding(seq_lens, cu_seq_lens, inputs, token_num):
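    """Pack [bsz, num_head, seq_len, dim_head] outputs into [token_num, num_head * dim_head] without padding."""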
    bsz, num_head, seq_len, dim_head = inputs.shape
    output = paddle.zeros(shape=[token_num, num_head * dim_head], dtype=inputs.dtype)
    inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1])
    for i in range(bsz):
        seq_len_now = seq_lens[i]
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :]
    return output


def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head, place, dtype):
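    """Create random Q/K/V tensors and the packed per-token QKV input consumed by append_attention."""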
    query = np.random.random([bs, q_num_head, seq_len, dim_head]) / 10
    q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False)
    key = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10
    k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False)
    value = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10
    v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False)
    token_num = bs * seq_len

    qkv = paddle.concat(
        [
            q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * dim_head]),
            k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
            v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
        ],
        axis=1,
    ).reshape([token_num, -1])
    return q, k, v, qkv


def apply_qk_norm(head_dim, dtype, q, k):
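    """RMS-normalize q and k per head with random weights; returns the normalized tensors and the weight tensors."""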
    q_norm_weight = np.random.random([head_dim]) / 10
    k_norm_weight = np.random.random([head_dim]) / 10
    q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
    k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
    print("q:", q.shape)
    print("k:", k.shape)
    bs, q_num_head, seq_len, dim_head = q.shape
    _, kv_num_head, _, _ = k.shape

    q = q.reshape([-1, head_dim])
    k = k.reshape([-1, head_dim])
    print("q:", q)
    q = fused_rms_norm(q.astype("float32"), q_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
    print("q after norm:", q)
    k = fused_rms_norm(k.astype("float32"), k_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
    q = q.reshape([-1, q_num_head, seq_len, dim_head])
    k = k.reshape([-1, kv_num_head, seq_len, dim_head])
    return q, k, q_norm_weight_tensor, k_norm_weight_tensor


def split_query_by_phase(
    query,
    seq_lens_encoder,
    seq_lens_decoder,
    seq_lens_this_time,
    q_dim,
    k_dim,
    v_dim,
):
    """
    Split the packed query into encoder-phase and decoder-phase Q/K/V tensors.
    """

    batch = seq_lens_encoder.shape[0]
    max_seq = query.shape[0] // batch

    # Restore query to [batch, seq, dim]
    total_dim = q_dim + k_dim + v_dim
    query = paddle.reshape(query, [batch, max_seq, total_dim])

    # Per-batch masks marking whether each entry is in the encoder or decoder phase
    is_encoder = (seq_lens_encoder > 0).astype("bool").reshape([-1])  # [batch]
    is_decoder = (seq_lens_decoder > 0).astype("bool").reshape([-1])  # [batch]

    # Output buckets
    enc_qs, enc_ks, enc_vs = [], [], []
    dec_qs, dec_ks, dec_vs = [], [], []

    for i in range(batch):
        real_len = int(seq_lens_this_time[i])  # valid length of the current batch entry
        cur_query = query[i, :real_len, :]  # [seq_i, q+k+v]

        q, k, v = paddle.split(cur_query, [q_dim, k_dim, v_dim], axis=-1)

        if is_encoder[i]:
            enc_qs.append(q)
            enc_ks.append(k)
            enc_vs.append(v)
        elif is_decoder[i]:
            dec_qs.append(q)
            dec_ks.append(k)
            dec_vs.append(v)

    if enc_qs:
        enc_q = paddle.concat(enc_qs, axis=0)
        enc_k = paddle.concat(enc_ks, axis=0)
        enc_v = paddle.concat(enc_vs, axis=0)
    else:
        enc_q = enc_k = enc_v = paddle.zeros([0, q_dim], dtype=query.dtype)

    if dec_qs:
        dec_q = paddle.concat(dec_qs, axis=0)
        dec_k = paddle.concat(dec_ks, axis=0)
        dec_v = paddle.concat(dec_vs, axis=0)
    else:
        dec_q = dec_k = dec_v = paddle.zeros([0, q_dim], dtype=query.dtype)

    return (enc_q, enc_k, enc_v), (dec_q, dec_k, dec_v)


class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
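    """Checks append_attention against naive_attention_impl for GQA with RoPE, in both prefill and decode phases."""
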
    def setUp(self):
        paddle.disable_static()
        self.name = "TestAppendGroupQueryAttnWithRope"
        self.place = paddle.CUDAPlace(0)
        self.batch_size = 1
        self.q_num_head = 16
        self.kv_num_head = 2
        self.seq_len = 64
        self.max_dec_len = 64
        self.dim_head = 128
        self.q_hid_dim = self.q_num_head * self.dim_head
        self.kv_hid_dim = self.kv_num_head * self.dim_head
        self.blocksize = 64
        self.use_neox_rotary_style = False
        # max_seq_len = self.seq_len + self.max_dec_len
        self.max_seq_len = self.seq_len + self.max_dec_len
        self.softmax_scale = self.dim_head**-0.5
        self.rope_theta = 10000
        self.dtype = "bfloat16"
        self.use_qk_norm = True
        self.use_mask_offset = False
        self.use_dynamic_quant = False
        self.init_tensor()

    def init_tensor(self):
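        """Allocate block tables, KV caches, and the scheduling metadata used by the attention kernels."""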
        self.block_num_per_seq = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
        self.rope = RopeEmbedding(self.use_neox_rotary_style)
        self.max_block_num = self.block_num_per_seq * self.batch_size
        self.free_list = list(range(self.max_block_num - 1, -1, -1))

        self.seq_lens_enc = [
            self.seq_len,
        ] * self.batch_size
        self.seq_lens_dec = [
            0,
        ] * self.batch_size
        self.max_enc_len_this_time = max(self.seq_lens_enc)
        self.max_dec_len_this_time = max(self.seq_lens_dec)
        self.seq_lens_encoder = paddle.to_tensor(
            self.seq_lens_enc,
            "int32",
        )
        self.seq_lens_decoder = paddle.to_tensor(
            self.seq_lens_dec,
            "int32",
        )
        self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
        self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
        self.seq_lens_this_time = copy.deepcopy(self.seq_lens_encoder)
        decode_max_tile_size = 1024 * self.batch_size * np.ceil((2 * 10) / 12)
        self.decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
        self.decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
        self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
        self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
        self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")
        self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()

        self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
        self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
        self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
        self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
        self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
        self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
        self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()

        self.cache_shape = (
            self.max_block_num,
            self.kv_num_head,
            self.blocksize,
            self.dim_head,
        )

        self.scale = 1.0 / np.sqrt(self.dim_head)
        if self.use_dynamic_quant:
            self.cache_scale_shape = (
                self.max_block_num,
                self.kv_num_head,
                self.blocksize,
            )
            self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8")
            self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8")
            self.cache_k_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
            self.cache_v_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
            self.key_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
            self.value_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
        else:
            self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
            self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
            self.key_cache_scale = None
            self.value_cache_scale = None
        self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32")
        for i in range(self.batch_size):
            need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
            for j in range(need_block_num):
                self.block_tables[i, j] = self.free_list.pop()
        (
            self.padding_offset,
            self.cum_offset,
            self.cu_seqlens_q,
            self.cu_seqlens_k,
        ) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
        self.token_num = self.padding_offset.shape[0]
        self.mask_offset = None
        if self.use_mask_offset:
            self.mask_offset = paddle.full(self.batch_size * self.seq_len * 2, 0, "int32")
            for i in range(self.batch_size):
                for j in range(self.seq_len):
                    self.mask_offset[i * self.seq_len * 2 + j * 2] = 0
                    self.mask_offset[i * self.seq_len * 2 + j * 2 + 1] = j + 1

    def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
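        """Run the naive reference and the fused append_attention kernel and compare their outputs."""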
        paddle.disable_static()
        print("use_dynamic_quant: ", self.use_dynamic_quant)
        self.token_num = self.seq_len * self.batch_size
        q, k, v, qkv = get_qkv_and_qkv_concat_tensor(
            self.batch_size,
            self.q_num_head,
            self.kv_num_head,
            self.seq_len,
            self.dim_head,
            self.place,
            self.dtype,
        )

        q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
        if self.use_qk_norm:
            q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k)
        else:
            q_norm_weight = None
            k_norm_weight = None
        out_ = naive_attention_impl(
            q,
            k,
            v,
            naive_cache_k,
            naive_cache_v,
            None,
            None,
            attn_mask,
            self.scale,
        )
        out_ = remove_padding(self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num)
        speculate_max_draft_token_num = 1
        from fastdeploy.model_executor.layers.attention.ops import (
            append_attention,
            get_block_shape_and_split_kv_block,
        )

        get_block_shape_and_split_kv_block(
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.seq_lens_this_time,
            self.decoder_batch_ids,
            self.decoder_tile_ids_per_batch,
            self.decoder_num_blocks_cpu,
            self.decoder_num_blocks_device,
            self.decoder_chunk_size_device,
            self.max_len_tensor_cpu,
            self.encoder_batch_ids,
            self.encoder_tile_ids_per_batch,
            self.encoder_num_blocks_x_cpu,
            self.kv_batch_ids,
            self.kv_tile_ids_per_batch,
            self.kv_num_blocks_x_cpu,
            self.max_len_kv_cpu,
            64,
            12,
            (self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
            self.blocksize,
            speculate_max_draft_token_num + 1,
        )
        if self.use_dynamic_quant:
            cache_quant_type = "block_wise_fp8"
        else:
            cache_quant_type = "none"

        if self.use_dynamic_quant:
            # Also run an unquantized pass into cache_k_T / cache_v_T so the decode
            # phase has a high-precision reference cache to compare against.
            qkv_copy = copy.deepcopy(qkv)
            append_attention(
                qkv_copy,
                self.cache_k_T,
                self.cache_v_T,
                self.seq_lens_encoder,
                self.seq_lens_decoder,
                self.seq_lens_this_time,
                self.padding_offset,
                self.cum_offset,
                self.block_tables,
                self.encoder_batch_ids,
                self.encoder_tile_ids_per_batch,
                self.encoder_num_blocks_x_cpu,
                self.kv_batch_ids,
                self.kv_tile_ids_per_batch,
                self.kv_num_blocks_x_cpu,
                self.decoder_batch_ids,
                self.decoder_tile_ids_per_batch,
                self.decoder_num_blocks_cpu,
                self.max_len_tensor_cpu,
                self.max_len_kv_cpu,
                self.rope_emb,  # rope_emb
                None,  # attn_mask
                None,  # qkv_bias
                None,  # qkv_out_scales
                None,  # cache_k_quant_scales
                None,  # cache_v_quant_scales
                None,  # cache_k_dequant_scales
                None,  # cache_v_dequant_scales
                None,  # cache_k_zp
                None,  # cache_v_zp
                None,  # linear_shift
                None,  # linear_smooth
                self.mask_offset,  # mask_offset
                None,  # kv_signal_data
                q_norm_weight,  # q_norm_weight
                k_norm_weight,  # k_norm_weight
                1e-6,
                "fp16",
                "none",
                self.use_neox_rotary_style,
                False,
                self.max_seq_len,
                0.0,  # quant_min_bound
                0.0,  # quant_max_bound
                -1,  # out_linear_in_scale
                64,  # encoder_block_shape_q
                16,  # decoder_block_shape_q
                32768,  # max_partition_size
                32768,  # encoder_max_partition_size
                speculate_max_draft_token_num + 1,  # speculate_max_draft_token_num
                True,  # causal
                False,  # speculate_decoder
            )

        # Warm up
        WARM_UP = 1
        RUN_TIME = 2
        for i in range(WARM_UP + RUN_TIME):
            if i == WARM_UP:
                paddle.device.synchronize()
                start_time = time.time()
            out = append_attention(
                qkv,
                self.cache_k,
                self.cache_v,
                self.seq_lens_encoder,
                self.seq_lens_decoder,
                self.seq_lens_this_time,
                self.padding_offset,
                self.cum_offset,
                self.block_tables,
                self.encoder_batch_ids,
                self.encoder_tile_ids_per_batch,
                self.encoder_num_blocks_x_cpu,
                self.kv_batch_ids,
                self.kv_tile_ids_per_batch,
                self.kv_num_blocks_x_cpu,
                self.decoder_batch_ids,
                self.decoder_tile_ids_per_batch,
                self.decoder_num_blocks_cpu,
                self.max_len_tensor_cpu,
                self.max_len_kv_cpu,
                self.rope_emb,  # rope_emb
                None,  # attn_mask
                None,  # qkv_bias
                None,  # qkv_out_scales
                self.key_cache_scale,  # cache_k_quant_scales
                self.value_cache_scale,  # cache_v_quant_scales
                None,  # cache_k_dequant_scales
                None,  # cache_v_dequant_scales
                None,  # cache_k_zp
                None,  # cache_v_zp
                None,  # linear_shift
                None,  # linear_smooth
                self.mask_offset,  # mask_offset
                None,  # kv_signal_data
                q_norm_weight,  # q_norm_weight
                k_norm_weight,  # k_norm_weight
                1e-6,
                "fp16",
                cache_quant_type,
                self.use_neox_rotary_style,
                False,
                self.max_seq_len,
                0.0,  # quant_min_bound
                0.0,  # quant_max_bound
                -1,  # out_linear_in_scale
                64,  # encoder_block_shape_q
                16,  # decoder_block_shape_q
                32768,  # max_partition_size
                32768,  # encoder_max_partition_size
                speculate_max_draft_token_num + 1,  # speculate_max_draft_token_num
                True,  # causal
                False,  # speculate_decoder
            )
        paddle.device.synchronize()
        end_time = time.time()
        print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms")
        np.testing.assert_allclose(
            out.numpy(),
            out_.numpy(),
            rtol=1e-02,
            atol=1e-02,
        )

    def test_all(self):
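        """Run the prefill (encoder) phase, then reuse the filled KV cache to check the decode phase."""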
        tmp_position_ids = paddle.arange(self.seq_len + self.max_dec_len).reshape((1, -1))
        # append_attention is given the full max sequence length (prefill + decode)
        if self.use_neox_rotary_style:
            self.rope_emb = self.rope.get_neox_style_position_embedding(tmp_position_ids, self.dim_head)
        else:
            self.rope_emb = self.rope.get_rotary_position_embedding(tmp_position_ids, self.dim_head)
        self.attention_mask = create_attn_mask(
            self.dtype,
            self.batch_size,
            [
                self.seq_len,
            ]
            * self.batch_size,
        )
        # encoder
        # self.seq_lens_encoder,self.seq_lens_decoder,self.max_enc_len_this_time,self.max_dec_len_this_time=get_encoder_decoder_len(self.batch_size,self.seq_len)
        self.seq_lens_this_time = copy.deepcopy(self.seq_lens_encoder)
        if self.use_mask_offset:
            print("encoder mask_offset: ", self.mask_offset)
        self.cmp_append_attention(attn_mask=self.attention_mask)
        if self.use_dynamic_quant:
            naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
                self.cache_k_T,
                self.cache_v_T,
                self.batch_size,
                self.block_tables,
                self.seq_len,
            )
        else:
            naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
                self.cache_k,
                self.cache_v,
                self.batch_size,
                self.block_tables,
                self.seq_len,
            )
        # decoder
        self.seq_lens_decoder[:] = self.seq_lens_encoder
        self.seq_lens_encoder[:] = 0
        self.seq_lens_this_time[:] = 1
        self.seq_lens_enc = [
            0,
        ] * self.batch_size
        self.seq_lens_dec = [
            self.seq_len,
        ] * self.batch_size
        self.max_enc_len_this_time = max(self.seq_lens_enc)
        self.max_dec_len_this_time = max(self.seq_lens_dec)
        self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
        self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())

        self.seq_len = 1
        (
            self.padding_offset,
            self.cum_offset,
            self.cu_seqlens_q,
            self.cu_seqlens_k,
        ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
        if self.use_mask_offset:
            self.mask_offset = paddle.full(self.batch_size * 2, 0, "int32")
            for i in range(self.batch_size):
                self.mask_offset[i * 2] = 0
                self.mask_offset[i * 2 + 1] = self.seq_lens_dec[i] + 1
            print("decoder mask_offset: ", self.mask_offset)
        self.cmp_append_attention(naive_cache_k, naive_cache_v, None)


class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
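    """Same comparison with NeoX-style rotary embedding, float16, and mask offsets enabled."""
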
    def setUp(self):
        paddle.disable_static()
        self.name = "TestAppendGroupQueryAttnWithNeoXRope"
        self.place = paddle.CUDAPlace(0)
        self.batch_size = 1
        self.q_num_head = 16
        self.kv_num_head = 2
        self.seq_len = 64
        self.max_dec_len = 64
        self.dim_head = 128
        self.q_hid_dim = self.q_num_head * self.dim_head
        self.kv_hid_dim = self.kv_num_head * self.dim_head
        self.blocksize = 64
        self.use_neox_rotary_style = True
        # max_seq_len = self.seq_len + self.max_dec_len
        self.max_seq_len = self.seq_len + self.max_dec_len
        self.softmax_scale = self.dim_head**-0.5
        self.rope_theta = 10000
        self.dtype = "float16"
        self.use_qk_norm = False
        self.use_mask_offset = True
        self.use_dynamic_quant = False
        self.init_tensor()


class TestAppendGroupQueryAttnWithRopeDyCfp8(TestAppendGroupQueryAttnWithRope):
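    """Same comparison with dynamic block-wise FP8 KV cache quantization enabled."""
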
    def setUp(self):
        paddle.disable_static()
        self.name = "TestAppendGroupQueryAttnWithRopeDyCfp8"
        self.place = paddle.CUDAPlace(0)
        self.batch_size = 1
        self.q_num_head = 16
        self.kv_num_head = 2
        self.seq_len = 64
        self.max_dec_len = 64
        self.dim_head = 128
        self.q_hid_dim = self.q_num_head * self.dim_head
        self.kv_hid_dim = self.kv_num_head * self.dim_head
        self.blocksize = 64
        self.use_neox_rotary_style = False
        # max_seq_len = self.seq_len + self.max_dec_len
        self.max_seq_len = self.seq_len + self.max_dec_len
        self.softmax_scale = self.dim_head**-0.5
        self.rope_theta = 10000
        self.dtype = "bfloat16"
        self.use_qk_norm = True
        self.use_mask_offset = False
        self.use_dynamic_quant = True
        self.init_tensor()


if __name__ == "__main__":
    unittest.main()