mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-16 13:41:30 +08:00
qk norm for speculate decode C16 (#3637)
This commit is contained in:
@@ -5,12 +5,16 @@ import unittest
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle.incubate.nn.functional import fused_rms_norm
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
append_attention,
|
||||
get_block_shape_and_split_kv_block,
|
||||
)
|
||||
|
||||
np.random.seed(0)
|
||||
paddle.seed(0)
|
||||
|
||||
|
||||
class TestTreeMask(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -27,6 +31,7 @@ class TestTreeMask(unittest.TestCase):
|
||||
self.head_dim = 128
|
||||
self.num_q_head = 20
|
||||
self.num_kv_head = 4
|
||||
self.use_qknorm = True
|
||||
self.dtype = "bfloat16"
|
||||
|
||||
self.rope_3d = False
|
||||
@@ -91,12 +96,20 @@ class TestTreeMask(unittest.TestCase):
|
||||
cu_seqlens_k[i + 1] = cum_seq_len_k
|
||||
return paddle.to_tensor(batch_id_per_token, dtype="int32"), cu_seqlens_q, cu_seqlens_k
|
||||
|
||||
def ref_attention(self, q, k, v, mask):
|
||||
def ref_attention(self, q, k, v, mask, use_qknorm=False):
|
||||
if use_qknorm:
|
||||
q = q.reshape([-1, self.head_dim])
|
||||
q = fused_rms_norm(q.astype("float32"), self.q_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
|
||||
q = q.reshape([self.bsz, -1, self.num_q_head, self.head_dim])
|
||||
q = q.transpose([0, 2, 1, 3])
|
||||
if len(k) > 1:
|
||||
k = paddle.concat(k, axis=1)
|
||||
else:
|
||||
k = k[0]
|
||||
if use_qknorm:
|
||||
k = k.reshape([-1, self.head_dim])
|
||||
k = fused_rms_norm(k.astype("float32"), self.k_norm_weight_tensor, None, 1e-6)[0].astype(self.dtype)
|
||||
k = k.reshape([self.bsz, -1, self.num_kv_head, self.head_dim])
|
||||
k = k.transpose([0, 2, 1, 3])
|
||||
if len(v) > 1:
|
||||
v = paddle.concat(v, axis=1)
|
||||
@@ -127,7 +140,7 @@ class TestTreeMask(unittest.TestCase):
|
||||
.reshape([-1, self.num_q_head, self.head_dim])
|
||||
)
|
||||
|
||||
def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None):
|
||||
def run_append_c16_attention(self, q_len, kv_len, prefill=False, attn_mask=None, use_qknorm=False):
|
||||
if prefill:
|
||||
seq_lens_enc = [
|
||||
q_len,
|
||||
@@ -187,6 +200,10 @@ class TestTreeMask(unittest.TestCase):
|
||||
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
|
||||
decoder_num_blocks = paddle.full([1], 0, dtype="int32").pin_memory()
|
||||
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()
|
||||
q_norm_weight = np.ones([self.head_dim])
|
||||
k_norm_weight = np.ones([self.head_dim])
|
||||
self.q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype="float32")
|
||||
self.k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype="float32")
|
||||
paddle.device.synchronize()
|
||||
(
|
||||
encoder_batch_ids,
|
||||
@@ -237,20 +254,20 @@ class TestTreeMask(unittest.TestCase):
|
||||
max_len_kv,
|
||||
rotary_embs,
|
||||
attn_mask,
|
||||
None,
|
||||
None,
|
||||
None, # qkv_bias
|
||||
None, # qkv_out_scales
|
||||
cache_k_scale,
|
||||
cache_v_scale,
|
||||
cache_k_out_scale,
|
||||
cache_v_out_scale,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # cache_k_zp
|
||||
None, # cache_v_zp
|
||||
None, # linear_shift
|
||||
None, # linear_smooth
|
||||
None, # mask_offset
|
||||
None, # kv_signal_data
|
||||
self.q_norm_weight_tensor if use_qknorm else None, # q_norm_weight
|
||||
self.k_norm_weight_tensor if use_qknorm else None, # k_norm_weight
|
||||
1e-6,
|
||||
"bf16",
|
||||
"none",
|
||||
@@ -271,7 +288,7 @@ class TestTreeMask(unittest.TestCase):
|
||||
paddle.device.synchronize()
|
||||
e_time = time.time()
|
||||
print(f"mean infer time: {np.mean((e_time - s_time) * 1000 / self.run_time):.2f}")
|
||||
return out[0].reshape([token_num, self.num_q_head, self.head_dim])
|
||||
return out.reshape([token_num, self.num_q_head, self.head_dim])
|
||||
|
||||
def test_naive_speculative_decoding(self):
|
||||
prefill_len = 8192
|
||||
@@ -279,10 +296,10 @@ class TestTreeMask(unittest.TestCase):
|
||||
total_len = prefill_len + dec_len_q
|
||||
mask = paddle.tril(paddle.ones((self.bsz, dec_len_q, total_len), dtype="float32"), diagonal=prefill_len)
|
||||
mask = paddle.where(mask == 1, paddle.zeros_like(mask), paddle.full_like(mask, fill_value=float("-inf")))
|
||||
self.run_append_c16_attention(prefill_len, 0, True)
|
||||
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False)
|
||||
self.run_append_c16_attention(prefill_len, 0, True, use_qknorm=self.use_qknorm)
|
||||
dec_out = self.run_append_c16_attention(dec_len_q, prefill_len, False, use_qknorm=self.use_qknorm)
|
||||
|
||||
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask)
|
||||
ref_out = self.ref_attention(self.CURRENT_Q[0], self.TOTAL_K, self.TOTAL_V, mask, use_qknorm=self.use_qknorm)
|
||||
np.testing.assert_allclose(
|
||||
ref_out.astype("float32").numpy(), dec_out.astype("float32").numpy(), rtol=1e-03, atol=5e-03
|
||||
)
|
||||
|
Reference in New Issue
Block a user