# FastDeploy/tests/layers/test_append_attention.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import time
import unittest
import numpy as np
import paddle
from paddle.incubate.nn.functional import fused_rms_norm
paddle.seed(10)
np.random.seed(10)
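# Reference rotary position embedding (RoPE) helper: builds cos/sin tables for either
# the interleaved (GPT) layout or the half-split NeoX layout, and applies
# q' = q * cos + rotate_half(q) * sin to Q and K.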
class RopeEmbedding:
def __init__(self, use_neox_rotary_style=False):
self.use_neox_rotary_style = use_neox_rotary_style
self.base = 10000
def get_neox_style_position_embedding(self, position_ids, head_dim):
bsz, max_seq_len = position_ids.shape[:2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32")
inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, 1, D]
emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, head_dim))
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
return rot_emb
def get_rotary_position_embedding(self, position_ids, head_dim):
bsz, max_seq_len = position_ids.shape[:2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32")
inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, D/2]
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2))
# shape: [B, S, 1, D]
emb = paddle.unsqueeze(emb, 2)
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
return rot_emb
def _apply_rope(self, rotary_emb, q, k, v=None, causal=False):
# sin [sequence_length, embed_size_per_head//2]
# cos [sequence_length, embed_size_per_head//2]
# sin, cos = paddle.chunk(rp, 2, axis=-1)
seq, head_dim = q.shape[2], q.shape[3]
cos, sin = paddle.chunk(rotary_emb, 2, axis=0)
cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
if self.use_neox_rotary_style:
sin_pos = sin
cos_pos = cos
# NeoX style: rotate the first and second halves of the head dimension as two blocks
rotate_half_q = paddle.reshape(
paddle.stack(
[
-q[:, :, :, q.shape[-1] // 2 :],
q[:, :, :, : q.shape[-1] // 2],
],
axis=-1,
),
paddle.shape(q),
)
rotate_half_k = paddle.reshape(
paddle.stack(
[
-k[:, :, :, k.shape[-1] // 2 :],
k[:, :, :, : k.shape[-1] // 2],
],
axis=-1,
),
paddle.shape(k),
)
else:
sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim])
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim])
# GPT style: rotate interleaved even/odd positions
rotate_half_q = paddle.reshape(
paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1),
paddle.shape(q),
)
rotate_half_k = paddle.reshape(
paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1),
paddle.shape(k),
)
query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos))
key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos))
return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype)
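# Build a causal attention mask: allowed positions get 0, masked positions get a
# large negative bias (-1e4) added to the attention logits before softmax.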
def create_attn_mask(
mask_type,
batch_size,
seq_lens,
pre_cache_length=0,
):
max_seq_len = max(seq_lens)
mask = paddle.zeros(
# [batch_size, 1, max_seq_len, max_seq_len + pre_cache_length],
[batch_size, 1, max_seq_len, max_seq_len],
dtype=mask_type,
)
mask[:, :, :, :pre_cache_length] = 1
for i in range(batch_size):
seq_len = seq_lens[i]
mask[i, 0, :seq_len, :seq_len] = (
paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type)) - 1
) * 1e4
return mask
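# Gather a paged (block-table based) KV cache into a contiguous
# [batch, num_head, cache_seq_len, dim_head] layout for the naive reference attention.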
def block_cache_to_naive_cache(cache_k, cache_v, bsz, block_tables, cache_seq_len):
_, num_head, blocksize, dim_head = cache_k.shape
out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
for i in range(bsz):
for j in range(cache_seq_len):
out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
return out_cache_k, out_cache_v
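# Naive reference attention: tiles the KV heads up to the number of query heads (GQA),
# optionally prepends cached K/V along the sequence axis, then computes
# softmax(Q @ K^T * scale + mask) @ V.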
def naive_attention_impl(
query,
key,
value,
cache_k=None,
cache_v=None,
pre_cache_k=None,
pre_cache_v=None,
mask=None,
scale=1.0,
cache_k_dequant_scales=None,
cache_v_dequant_scales=None,
use_cachekv_int8="None",
q_norm_weight=None,
k_norm_weight=None,
):
batch = query.shape[0]
heads = query.shape[1]
seq_len = query.shape[2]
head_dim = query.shape[3]
kv_head = key.shape[1]
key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
key = key.reshape([batch, heads, seq_len, head_dim])
if cache_k is not None:
cache_k = cache_k.reshape([batch, kv_head, 1, -1, head_dim])
cache_k = paddle.tile(cache_k, [1, 1, heads // kv_head, 1, 1])
cache_k = cache_k.reshape([batch, heads, -1, head_dim])
key = paddle.concat([cache_k, key], axis=2)
value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
value = value.reshape([batch, heads, seq_len, head_dim])
if cache_v is not None:
cache_v = cache_v.reshape([batch, kv_head, 1, -1, head_dim])
cache_v = paddle.tile(cache_v, [1, 1, heads // kv_head, 1, 1])
cache_v = cache_v.reshape([batch, heads, -1, head_dim])
value = paddle.concat([cache_v, value], axis=2)
qk_res = paddle.matmul(query, key, transpose_y=True)
attention = qk_res * scale
if mask is not None:
attention = attention + mask
softmax_result = paddle.nn.functional.softmax(attention, -1)
result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value)
return result
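# Compute padding offsets and cumulative sequence lengths (cu_seqlens_q/k) used to map
# the padded [batch, max_seq_len] layout to a packed per-token layout.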
def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(max_seq_len - seq_lens_this_time, dtype="int32")
cum_offsets = paddle.zeros(shape=(bsz + 1), dtype="int32")
cum_offsets[1:] = cum_offsets_now
token_num = paddle.sum(seq_lens_this_time)
padding_offsets = paddle.zeros(shape=(token_num), dtype="int32")
cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32")
cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32")
for i in range(bsz):
seq_len_now = seq_lens_this_time[i]
cum_offset = cum_offsets[i]
for j in range(seq_len_now):
padding_offsets[i * max_seq_len - cum_offset + j] = cum_offset
cum_seq_len = (i + 1) * max_seq_len - cum_offsets[i + 1]
cu_seqlens_q[i + 1] = cum_seq_len
cu_seqlens_k[i + 1] = cum_seq_len
return padding_offsets, cum_offsets[:-1], cu_seqlens_q, cu_seqlens_k
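# Pack the padded reference output into the [token_num, num_head * dim_head] layout
# used for comparison with the fused kernel output.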
def remove_padding(seq_lens, cu_seq_lens, inputs, token_num):
bsz, num_head, seq_len, dim_head = inputs.shape
output = paddle.zeros(shape=[token_num, num_head * dim_head], dtype=inputs.dtype)
inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1])
for i in range(bsz):
seq_len_now = seq_lens[i]
start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1]
output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :]
return output
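# Create random Q/K/V tensors together with their packed per-token QKV concatenation,
# which is the input layout consumed by append_attention.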
def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head, place, dtype):
query = np.random.random([bs, q_num_head, seq_len, dim_head]) / 10
q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False)
key = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10
k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False)
value = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10
v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False)
token_num = bs * seq_len
qkv = paddle.concat(
[
q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * dim_head]),
k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
],
axis=1,
).reshape([token_num, -1])
return q, k, v, qkv
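# Apply RMSNorm over the head dimension of Q and K with random weights; the weight
# tensors are returned so the same normalization can be passed to append_attention.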
def apply_qk_norm(head_dim, dtype, q, k):
q_norm_weight = np.random.random([head_dim]) / 10
k_norm_weight = np.random.random([head_dim]) / 10
q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
print("q:", q.shape)
print("k:", k.shape)
bs, q_num_head, seq_len, dim_head = q.shape
_, kv_num_head, _, _ = k.shape
q = q.reshape([-1, head_dim])
k = k.reshape([-1, head_dim])
print("q:", q)
q = fused_rms_norm(q.astype("float32"), q_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
print("q after norm:", q)
k = fused_rms_norm(k.astype("float32"), k_norm_weight_tensor, None, 1e-5)[0].astype(dtype)
q = q.reshape([-1, q_num_head, seq_len, dim_head])
k = k.reshape([-1, kv_num_head, seq_len, dim_head])
return q, k, q_norm_weight_tensor, k_norm_weight_tensor
def split_query_by_phase(
query,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
q_dim,
k_dim,
v_dim,
):
"""
Split the packed query tensor into encoder-phase and decoder-phase Q/K/V.
"""
batch = seq_lens_encoder.shape[0]
max_seq = query.shape[0] // batch
# Reshape the packed query back to [batch, seq, dim]
total_dim = q_dim + k_dim + v_dim
query = paddle.reshape(query, [batch, max_seq, total_dim])
# Boolean masks indicating whether each batch entry is in the encoder or decoder phase
is_encoder = (seq_lens_encoder > 0).astype("bool").reshape([-1]) # [batch]
is_decoder = (seq_lens_decoder > 0).astype("bool").reshape([-1]) # [batch]
# Prepare output lists
enc_qs, enc_ks, enc_vs = [], [], []
dec_qs, dec_ks, dec_vs = [], [], []
for i in range(batch):
real_len = int(seq_lens_this_time[i])  # valid length of the current batch entry
cur_query = query[i, :real_len, :] # [seq_i, q+k+v]
q, k, v = paddle.split(cur_query, [q_dim, k_dim, v_dim], axis=-1)
if is_encoder[i]:
enc_qs.append(q)
enc_ks.append(k)
enc_vs.append(v)
elif is_decoder[i]:
dec_qs.append(q)
dec_ks.append(k)
dec_vs.append(v)
if enc_qs:
enc_q = paddle.concat(enc_qs, axis=0)
enc_k = paddle.concat(enc_ks, axis=0)
enc_v = paddle.concat(enc_vs, axis=0)
else:
enc_q = enc_k = enc_v = paddle.zeros([0, q_dim], dtype=query.dtype)
if dec_qs:
dec_q = paddle.concat(dec_qs, axis=0)
dec_k = paddle.concat(dec_ks, axis=0)
dec_v = paddle.concat(dec_vs, axis=0)
else:
dec_q = dec_k = dec_v = paddle.zeros([0, q_dim], dtype=query.dtype)
return (enc_q, enc_k, enc_v), (dec_q, dec_k, dec_v)
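# End-to-end check of the fused append_attention kernel against the naive reference:
# a prefill (encoder) pass over seq_len tokens followed by a single-token decode step,
# using group-query attention (16 query heads, 2 KV heads), RoPE, optional QK RMSNorm,
# a paged KV cache, and optional dynamic FP8 cache quantization.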
class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.name = "TestAppendGroupQueryAttnWithRope"
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.q_num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.max_dec_len = 64
self.dim_head = 128
self.q_hid_dim = self.q_num_head * self.dim_head
self.kv_hid_dim = self.kv_num_head * self.dim_head
self.blocksize = 64
self.use_neox_rotary_style = False
# max_seq_len = self.seq_len + self.max_dec_len
self.max_seq_len = self.seq_len + self.max_dec_len
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = "bfloat16"
self.use_qk_norm = True
self.use_mask_offset = False
self.use_dynamic_quant = False
self.init_tensor()
def init_tensor(self):
self.block_num_per_seq = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
self.rope = RopeEmbedding(self.use_neox_rotary_style)
self.max_block_num = self.block_num_per_seq * self.batch_size
self.free_list = list(range(self.max_block_num - 1, -1, -1))
self.seq_lens_enc = [
self.seq_len,
] * self.batch_size
self.seq_lens_dec = [
0,
] * self.batch_size
self.max_enc_len_this_time = max(self.seq_lens_enc)
self.max_dec_len_this_time = max(self.seq_lens_dec)
self.seq_lens_encoder = paddle.to_tensor(
self.seq_lens_enc,
"int32",
)
self.seq_lens_decoder = paddle.to_tensor(
self.seq_lens_dec,
"int32",
)
self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_lens_this_time = self.seq_lens_encoder
decode_max_tile_size = 1024 * self.batch_size * np.ceil((2 * 10) / 12)
self.decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
self.decoder_num_blocks_device = paddle.full([1], 0, dtype="int32")
self.decoder_chunk_size_device = paddle.full([1], 64, dtype="int32")
self.max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
self.encoder_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.encoder_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.kv_batch_ids = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_tile_ids_per_batch = paddle.full([self.batch_size], 0, dtype="int32")
self.kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.max_len_kv_cpu = paddle.full([1], 0, dtype="int32").cpu()
self.cache_shape = (
self.max_block_num,
self.kv_num_head,
self.blocksize,
self.dim_head,
)
self.scale = 1.0 / np.sqrt(self.dim_head)
if self.use_dynamic_quant:
self.cache_scale_shape = (
self.max_block_num,
self.kv_num_head,
self.blocksize,
)
self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8")
self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8")
self.cache_k_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.cache_v_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.key_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
self.value_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
else:
self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.key_cache_scale = None
self.value_cache_scale = None
self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32")
for i in range(self.batch_size):
need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
for j in range(need_block_num):
self.block_tables[i, j] = self.free_list.pop()
(
self.padding_offset,
self.cum_offset,
self.cu_seqlens_q,
self.cu_seqlens_k,
) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
self.token_num = self.padding_offset.shape[0]
self.mask_offset = None
if self.use_mask_offset:
self.mask_offset = paddle.full(self.batch_size * self.seq_len * 2, 0, "int32")
for i in range(self.batch_size):
for j in range(self.seq_len):
self.mask_offset[i * self.seq_len * 2 + j * 2] = 0
self.mask_offset[i * self.seq_len * 2 + j * 2 + 1] = j + 1
def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
paddle.disable_static()
print("use_dynamic_quant: ", self.use_dynamic_quant)
self.token_num = self.seq_len * self.batch_size
q, k, v, qkv = get_qkv_and_qkv_concat_tensor(
self.batch_size,
self.q_num_head,
self.kv_num_head,
self.seq_len,
self.dim_head,
self.place,
self.dtype,
)
q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
if self.use_qk_norm:
q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k)
else:
q_norm_weight = None
k_norm_weight = None
out_ = naive_attention_impl(
q,
k,
v,
naive_cache_k,
naive_cache_v,
None,
None,
attn_mask,
self.scale,
)
out_ = remove_padding(self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num)
speculate_max_draft_token_num = 1
from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
get_block_shape_and_split_kv_block,
)
get_block_shape_and_split_kv_block(
self.seq_lens_encoder,
self.seq_lens_decoder,
self.seq_lens_this_time,
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.decoder_num_blocks_device,
self.decoder_chunk_size_device,
self.max_len_tensor_cpu,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,
self.encoder_num_blocks_x_cpu,
self.kv_batch_ids,
self.kv_tile_ids_per_batch,
self.kv_num_blocks_x_cpu,
self.max_len_kv_cpu,
64,
12,
(self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
self.blocksize,
speculate_max_draft_token_num + 1,
)
if self.use_dynamic_quant:
cache_quant_type = "block_wise_fp8"
else:
cache_quant_type = "none"
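# When dynamic quantization is enabled, first run append_attention against the
# unquantized caches (cache_k_T / cache_v_T) so the decode-phase naive reference
# in test_all can be built from an unquantized copy of the KV cache.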
if self.use_dynamic_quant:
qkv_copy = copy.deepcopy(qkv)
append_attention(
qkv_copy,
self.cache_k_T,
self.cache_v_T,
self.seq_lens_encoder,
self.seq_lens_decoder,
self.seq_lens_this_time,
self.padding_offset,
self.cum_offset,
self.block_tables,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,
self.encoder_num_blocks_x_cpu,
self.kv_batch_ids,
self.kv_tile_ids_per_batch,
self.kv_num_blocks_x_cpu,
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.max_len_tensor_cpu,
self.max_len_kv_cpu,
self.rope_emb, # rope_emb
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
None, # cache_k_quant_scales
None, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # linear_shift
None, # linear_smooth
self.mask_offset, # mask_offset
None, # kv_signal_data
q_norm_weight, # q_norm_weight
k_norm_weight, # k_norm_weight
1e-6,
"fp16",
"none",
self.use_neox_rotary_style,
False,
self.max_seq_len,
0.0, # quant_min_bound
0.0, # quant_max_bound
-1, # out_linear_in_scale
64, # encoder_block_shape_q
16, # decoder_block_shape_q
32768, # max_partition_size
32768, # encoder_max_partition_size
speculate_max_draft_token_num + 1, # speculate_max_draft_token_num
True, # causal
False, # speculate_decoder
)
# Warm up for WARM_UP iterations, then time the next RUN_TIME iterations
WARM_UP = 1
RUN_TIME = 2
for i in range(WARM_UP + RUN_TIME):
if i == WARM_UP:
paddle.device.synchronize()
start_time = time.time()
out = append_attention(
qkv,
self.cache_k,
self.cache_v,
self.seq_lens_encoder,
self.seq_lens_decoder,
self.seq_lens_this_time,
self.padding_offset,
self.cum_offset,
self.block_tables,
self.encoder_batch_ids,
self.encoder_tile_ids_per_batch,
self.encoder_num_blocks_x_cpu,
self.kv_batch_ids,
self.kv_tile_ids_per_batch,
self.kv_num_blocks_x_cpu,
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.max_len_tensor_cpu,
self.max_len_kv_cpu,
self.rope_emb, # rope_emb
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
self.key_cache_scale, # cache_k_quant_scales
self.value_cache_scale, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # linear_shift
None, # linear_smooth
self.mask_offset, # mask_offset
None, # kv_signal_data
q_norm_weight, # q_norm_weight
k_norm_weight, # k_norm_weight
1e-6,
"fp16",
cache_quant_type,
self.use_neox_rotary_style,
False,
self.max_seq_len,
0.0, # quant_min_bound
0.0, # quant_max_bound
-1, # out_linear_in_scale
64, # encoder_block_shape_q
16, # decoder_block_shape_q
32768, # max_partition_size
32768, # encoder_max_partition_size
speculate_max_draft_token_num + 1, # speculate_max_draft_token_num
True, # causal
False, # speculate_decoder
)
paddle.device.synchronize()
end_time = time.time()
print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms")
np.testing.assert_allclose(
out.numpy(),
out_.numpy(),
rtol=1e-02,
atol=1e-02,
)
def test_all(self):
tmp_position_ids = paddle.arange(self.seq_len + self.max_dec_len).reshape((1, -1))
# append_attention is given the maximum sequence length (seq_len + max_dec_len)
if self.use_neox_rotary_style:
self.rope_emb = self.rope.get_neox_style_position_embedding(tmp_position_ids, self.dim_head)
else:
self.rope_emb = self.rope.get_rotary_position_embedding(tmp_position_ids, self.dim_head)
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
[
self.seq_len,
]
* self.batch_size,
)
# encoder
self.seq_lens_this_time = self.seq_lens_encoder
if self.use_mask_offset:
print("encoder mask_offset: ", self.mask_offset)
self.cmp_append_attention(attn_mask=self.attention_mask)
if self.use_dynamic_quant:
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
self.cache_k_T,
self.cache_v_T,
self.batch_size,
self.block_tables,
self.seq_len,
)
else:
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
self.cache_k,
self.cache_v,
self.batch_size,
self.block_tables,
self.seq_len,
)
# decoder
self.seq_lens_decoder[:] = self.seq_lens_encoder
self.seq_lens_encoder[:] = 0
self.seq_lens_this_time[:] = 1
self.seq_lens_enc = [
0,
] * self.batch_size
self.seq_lens_dec = [
self.seq_len,
] * self.batch_size
self.max_enc_len_this_time = max(self.seq_lens_enc)
self.max_dec_len_this_time = max(self.seq_lens_dec)
self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_len = 1
(
self.padding_offset,
self.cum_offset,
self.cu_seqlens_q,
self.cu_seqlens_k,
) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
if self.use_mask_offset:
self.mask_offset = paddle.full(self.batch_size * 2, 0, "int32")
for i in range(self.batch_size):
self.mask_offset[i * 2] = 0
self.mask_offset[i * 2 + 1] = self.seq_lens_dec[i] + 1
print("decoder mask_offset: ", self.mask_offset)
self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
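# Variant: NeoX-style rotary embedding, float16, mask_offset enabled, QK norm disabled.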
class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
def setUp(self):
paddle.disable_static()
self.name = "TestAppendGroupQueryAttnWithNeoXRope"
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.q_num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.max_dec_len = 64
self.dim_head = 128
self.q_hid_dim = self.q_num_head * self.dim_head
self.kv_hid_dim = self.kv_num_head * self.dim_head
self.blocksize = 64
self.use_neox_rotary_style = True
# max_seq_len = self.seq_len + self.max_dec_len
self.max_seq_len = self.seq_len + self.max_dec_len
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = "float16"
self.use_qk_norm = False
self.use_mask_offset = True
self.use_dynamic_quant = False
self.init_tensor()
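# Variant: dynamic block-wise FP8 quantization of the KV cache enabled.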
class TestAppendGroupQueryAttnWithRopeDyCfp8(TestAppendGroupQueryAttnWithRope):
def setUp(self):
paddle.disable_static()
self.name = "TestAppendGroupQueryAttnWithRopeDyCfp8"
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.q_num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.max_dec_len = 64
self.dim_head = 128
self.q_hid_dim = self.q_num_head * self.dim_head
self.kv_hid_dim = self.kv_num_head * self.dim_head
self.blocksize = 64
self.use_neox_rotary_style = False
# max_seq_len = self.seq_len + self.max_dec_len
self.max_seq_len = self.seq_len + self.max_dec_len
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = "bfloat16"
self.use_qk_norm = True
self.use_mask_offset = False
self.use_dynamic_quant = True
self.init_tensor()
if __name__ == "__main__":
unittest.main()