support qk norm (#3145)

This commit is contained in:
Yuan Xiaolan
2025-08-05 16:46:14 +08:00
committed by GitHub
parent 4a10e29804
commit 7ce00e597c
17 changed files with 791 additions and 201 deletions

View File

@@ -17,6 +17,7 @@ import unittest
import numpy as np
import paddle
from paddle.incubate.nn.functional import fused_rms_norm
paddle.seed(10)
@@ -157,6 +158,8 @@ def naive_attention_impl(
cache_k_dequant_scales=None,
cache_v_dequant_scales=None,
use_cachekv_int8="None",
q_norm_weight=None,
k_norm_weight=None,
):
batch = query.shape[0]
heads = query.shape[1]
@@ -244,6 +247,27 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
return q, k, v, qkv
def apply_qk_norm(head_dim, dtype, q, k):
q_norm_weight = np.random.random([head_dim]) / 10
k_norm_weight = np.random.random([head_dim]) / 10
q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
print("q:", q.shape)
print("k:", k.shape)
bs, q_num_head, seq_len, dim_head = q.shape
_, kv_num_head, _, _ = k.shape
q = q.reshape([-1, head_dim])
k = k.reshape([-1, head_dim])
print("q:", q)
q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
print("q after norm:", q)
k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
q = q.reshape([-1, q_num_head, seq_len, dim_head])
k = k.reshape([-1, kv_num_head, seq_len, dim_head])
return q, k, q_norm_weight_tensor, k_norm_weight_tensor
def split_query_by_phase(
query,
seq_lens_encoder,
@@ -324,6 +348,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = "float16"
self.use_qk_norm = True
self.init_tensor()
def init_tensor(self):
@@ -394,6 +419,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
)
q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
if self.use_qk_norm:
q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k)
else:
q_norm_weight = None
k_norm_weight = None
out_ = naive_attention_impl(
q,
k,
@@ -476,6 +506,9 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
None, # linear_shift
None, # linear_smooth
None, # kv_signal_data
q_norm_weight, # q_norm_weight
k_norm_weight, # k_norm_weight
1e-6,
"fp16",
"none", # cache_quant_type
self.use_neox_rotary_style,
@@ -580,6 +613,7 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = "float16"
self.use_qk_norm = False
self.init_tensor()