mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
support qk norm (#3145)
This commit is contained in:
@@ -17,6 +17,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle.incubate.nn.functional import fused_rms_norm
|
||||
|
||||
paddle.seed(10)
|
||||
|
||||
@@ -157,6 +158,8 @@ def naive_attention_impl(
|
||||
cache_k_dequant_scales=None,
|
||||
cache_v_dequant_scales=None,
|
||||
use_cachekv_int8="None",
|
||||
q_norm_weight=None,
|
||||
k_norm_weight=None,
|
||||
):
|
||||
batch = query.shape[0]
|
||||
heads = query.shape[1]
|
||||
@@ -244,6 +247,27 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
|
||||
return q, k, v, qkv
|
||||
|
||||
|
||||
def apply_qk_norm(head_dim, dtype, q, k):
|
||||
q_norm_weight = np.random.random([head_dim]) / 10
|
||||
k_norm_weight = np.random.random([head_dim]) / 10
|
||||
q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
|
||||
k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
|
||||
print("q:", q.shape)
|
||||
print("k:", k.shape)
|
||||
bs, q_num_head, seq_len, dim_head = q.shape
|
||||
_, kv_num_head, _, _ = k.shape
|
||||
|
||||
q = q.reshape([-1, head_dim])
|
||||
k = k.reshape([-1, head_dim])
|
||||
print("q:", q)
|
||||
q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
|
||||
print("q after norm:", q)
|
||||
k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
|
||||
q = q.reshape([-1, q_num_head, seq_len, dim_head])
|
||||
k = k.reshape([-1, kv_num_head, seq_len, dim_head])
|
||||
return q, k, q_norm_weight_tensor, k_norm_weight_tensor
|
||||
|
||||
|
||||
def split_query_by_phase(
|
||||
query,
|
||||
seq_lens_encoder,
|
||||
@@ -324,6 +348,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
||||
self.softmax_scale = self.dim_head**-0.5
|
||||
self.rope_theta = 10000
|
||||
self.dtype = "float16"
|
||||
self.use_qk_norm = True
|
||||
self.init_tensor()
|
||||
|
||||
def init_tensor(self):
|
||||
@@ -394,6 +419,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
||||
)
|
||||
|
||||
q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
|
||||
if self.use_qk_norm:
|
||||
q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k)
|
||||
else:
|
||||
q_norm_weight = None
|
||||
k_norm_weight = None
|
||||
out_ = naive_attention_impl(
|
||||
q,
|
||||
k,
|
||||
@@ -476,6 +506,9 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
||||
None, # linear_shift
|
||||
None, # linear_smooth
|
||||
None, # kv_signal_data
|
||||
q_norm_weight, # q_norm_weight
|
||||
k_norm_weight, # k_norm_weight
|
||||
1e-6,
|
||||
"fp16",
|
||||
"none", # cache_quant_type
|
||||
self.use_neox_rotary_style,
|
||||
@@ -580,6 +613,7 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
|
||||
self.softmax_scale = self.dim_head**-0.5
|
||||
self.rope_theta = 10000
|
||||
self.dtype = "float16"
|
||||
self.use_qk_norm = False
|
||||
self.init_tensor()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user