qk norm for speculate decode C16 (#3637)

Yuan Xiaolan
2025-09-03 14:53:56 +08:00
committed by GitHub
parent d22d3de256
commit fa58a9fa8f
6 changed files with 470 additions and 160 deletions


@@ -5,12 +5,16 @@ import unittest
 import numpy as np
 import paddle
 import paddle.nn.functional as F
+from paddle.incubate.nn.functional import fused_rms_norm

 from fastdeploy.model_executor.layers.attention.ops import (
     append_attention,
     get_block_shape_and_split_kv_block,
 )

+np.random.seed(0)
+paddle.seed(0)
+

 class TestTreeMask(unittest.TestCase):
     def setUp(self):
@@ -27,6 +31,7 @@ class TestTreeMask(unittest.TestCase):
         self.head_dim = 128
         self.num_q_head = 20
         self.num_kv_head = 4
+        self.use_qknorm = True
         self.dtype = "bfloat16"
         self.rope_3d = False
@@ -91,12 +96,20 @@ class TestTreeMask(unittest.TestCase):
             cu_seqlens_k[i + 1] = cum_seq_len_k
         return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k

-    def ref_attention(self, q, k, v, mask):
+    def ref_attention(self, q, k, v, mask, use_qknorm=False):
+        if use_qknorm:
+            q = q.reshape([-1, self.head_dim])
+            q = fused_rms_norm(q.astype("float32"), self.q_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
+            q = q.reshape([self.bsz, -1, self.num_q_head, self.head_dim])
         q = q.transpose([0, 2, 1, 3])
         if len(k) > 1:
             k = paddle.concat(k, axis=1)
         else:
             k = k[0]
+        if use_qknorm:
+            k = k.reshape([-1, self.head_dim])
+            k = fused_rms_norm(k.astype("float32"), self.k_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
+            k = k.reshape([self.bsz, -1, self.num_kv_head, self.head_dim])
         k = k.transpose([0, 2, 1, 3])
         if len(v) > 1:
             v = paddle.concat(v, axis=1)
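For orientation: the reference path above pushes every head_dim-sized q/k vector through fused_rms_norm (all-ones weights, eps = 1e-6) before the attention math. Assuming standard RMSNorm semantics for fused_rms_norm, the normalization amounts to the sketch below; rms_norm_ref is a hypothetical helper for illustration, not code from this commit.

```python
import paddle

def rms_norm_ref(x, weight, eps=1e-6):
    # x: [num_tokens, head_dim]; normalize each row over the last axis in
    # float32, then apply the per-channel weight (all ones in this test).
    x32 = x.astype("float32")
    inv_rms = paddle.rsqrt(x32.pow(2).mean(axis=-1, keepdim=True) + eps)
    return x32 * inv_rms * weight
```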
@@ -127,7 +140,7 @@ class TestTreeMask(unittest.TestCase):
             .reshape([-1, self.num_q_head, self.head_dim])
         )

-    def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None):
+    def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False):
         if prefill:
             seq_lens_enc = [
                 q_len,
@@ -187,6 +200,10 @@ class TestTreeMask(unittest.TestCase):
         decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
         decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
         max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
+        q_norm_weight = np.ones([self.head_dim])
+        k_norm_weight = np.ones([self.head_dim])
+        self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
+        self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
         paddle.device.synchronize()
         (
             encoder_batch_ids,
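Both norm weights are all ones here, so the learned per-channel scale is a no-op and any kernel/reference mismatch comes from the normalization itself. A stricter variant could use non-uniform weights to also exercise the scaling path; a hypothetical drop-in replacement for the four added lines above, not part of this commit:

```python
# Hypothetical: random positive weights instead of ones, so the per-channel
# scale inside the kernel's qk norm is exercised as well as the normalization.
q_norm_weight = np.random.uniform(0.5, 1.5, [self.head_dim]).astype("float32")
k_norm_weight = np.random.uniform(0.5, 1.5, [self.head_dim]).astype("float32")
self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
```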
@@ -237,20 +254,20 @@ class TestTreeMask(unittest.TestCase):
             max_len_kv,
             rotary_embs,
             attn_mask,
-            None,
-            None,
+            None,  # qkv_bias
+            None,  # qkv_out_scales
             cache_k_scale,
             cache_v_scale,
             cache_k_out_scale,
             cache_v_out_scale,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
+            None,  # cache_k_zp
+            None,  # cache_v_zp
+            None,  # linear_shift
+            None,  # linear_smooth
+            None,  # mask_offset
+            None,  # kv_signal_data
+            self.q_norm_weight_tensor if use_qknorm else None,  # q_norm_weight
+            self.k_norm_weight_tensor if use_qknorm else None,  # k_norm_weight
             1e-6,
             "bf16",
             "none",
@@ -271,7 +288,7 @@ class TestTreeMask(unittest.TestCase):
         paddle.device.synchronize()
         e_time = time.time()
         print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}")
-        return out[0].reshape([token_num, self.num_q_head, self.head_dim])
+        return out.reshape([token_num, self.num_q_head, self.head_dim])

     def test_naive_speculative_decoding(self):
         prefill_len = 8192
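The return value also changes from out[0] to out, which reads as append_attention now handing back the output tensor directly instead of a single-element list. If test code needs to run against both behaviors, a small guard works; a hypothetical helper, not part of this commit:

```python
def unwrap_output(out):
    # Accept either a bare Tensor or a one-element list/tuple, so the same
    # test code runs across versions with either return convention.
    return out[0] if isinstance(out, (list, tuple)) else out
```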
@@ -279,10 +296,10 @@ class TestTreeMask(unittest.TestCase):
         total_len = prefill_len + dec_len_q
         mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
         mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
-        self.run_append_c16_attention(prefill_len, 0, True)
-        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False)
+        self.run_append_c16_attention(prefill_len, 0, True, use_qknorm=self.use_qknorm)
+        dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, use_qknorm=self.use_qknorm)

-        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask)
+        ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask, use_qknorm=self.use_qknorm)
         np.testing.assert_allclose(
             ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
         )
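The attention math inside ref_attention is not shown in these hunks. For orientation only, a grouped-query reference consistent with this test's shapes (20 query heads sharing 4 KV heads, head_dim 128, additive mask with 0 / -inf entries) might look roughly like the sketch below; gqa_ref_attention is a hypothetical stand-in, not code from the commit.

```python
import paddle
import paddle.nn.functional as F

def gqa_ref_attention(q, k, v, mask, head_dim=128):
    # q: [bsz, num_q_head, q_len, head_dim]; k, v: [bsz, num_kv_head, kv_len, head_dim]
    group = q.shape[1] // k.shape[1]
    k = paddle.repeat_interleave(k, group, axis=1)  # expand KV heads to match query heads
    v = paddle.repeat_interleave(v, group, axis=1)
    scores = paddle.matmul(q.astype("float32"), k.astype("float32"), transpose_y=True)
    scores = scores / (head_dim**0.5) + mask.unsqueeze(1)  # additive mask: 0 keep, -inf drop
    probs = F.softmax(scores, axis=-1)
    return paddle.matmul(probs, v.astype("float32"))  # [bsz, num_q_head, q_len, head_dim]
```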