polish code with new pre-commit rule (#2923)

This commit is contained in:
Zero Rains
2025-07-19 23:19:27 +08:00
committed by GitHub
parent b8676d71a8
commit 25698d56d1
424 changed files with 14307 additions and 13518 deletions

View File

@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy as np
import time
import unittest
import numpy as np
import paddle
paddle.seed(10)
@@ -25,19 +25,16 @@ class RopeEmbedding:
def __init__(self, use_neox_rotary_style=False):
self.use_neox_rotary_style = use_neox_rotary_style
self.base = 10000
def get_neox_style_position_embedding(self, position_ids, head_dim):
bsz, max_seq_len = position_ids.shape[:2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim),
dtype="float32")
inv_freq = self.base**(-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32")
inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
# shape: [B, S, D/2]
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"),
inv_freq)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, 1, D]
emb = paddle.concat([freqs, freqs], axis=-1).reshape(
(bsz, max_seq_len, 1, head_dim))
emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, head_dim))
rot_emb[0] = paddle.cos(emb)
rot_emb[1] = paddle.sin(emb)
@@ -45,21 +42,13 @@ class RopeEmbedding:
def get_rotary_position_embedding(self, position_ids, head_dim):
bsz, max_seq_len = position_ids.shape[:2]
rot_emb = paddle.zeros(
(2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32"
)
inv_freq = self.base ** (
-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim
)
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32")
inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)
# shape: [B, S, D/2]
freqs = paddle.einsum(
"ij,k->ijk", position_ids.cast("float32"), inv_freq
)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, D/2]
emb = paddle.stack([freqs], axis=-1).reshape(
(bsz, max_seq_len, head_dim // 2)
)
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2))
# shape: [B, S, 1, D]
emb = paddle.unsqueeze(emb, 2)
@@ -73,31 +62,39 @@ class RopeEmbedding:
# sin, cos = paddle.chunk(rp, 2, axis=-1)
seq, head_dim = q.shape[2], q.shape[3]
cos, sin = paddle.chunk(rotary_emb, 2, axis=0)
cos = paddle.squeeze(cos, axis=0).transpose(
[0, 2, 1, 3])[:, :, :seq, :]
sin = paddle.squeeze(sin, axis=0).transpose(
[0, 2, 1, 3])[:, :, :seq, :]
cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :]
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
if self.use_neox_rotary_style:
sin_pos = sin
cos_pos = cos
# NeoX Stype前后半部分分块旋转
rotate_half_q = paddle.reshape(
paddle.stack([-q[:, :, :, q.shape[-1]//2:], q[:, :, :, :q.shape[-1]//2]], axis=-1),
paddle.stack(
[
-q[:, :, :, q.shape[-1] // 2 :],
q[:, :, :, : q.shape[-1] // 2],
],
axis=-1,
),
paddle.shape(q),
)
rotate_half_k = paddle.reshape(
paddle.stack([-k[:, :, :, k.shape[-1]//2:], k[:, :, :, :k.shape[-1]//2]], axis=-1),
paddle.stack(
[
-k[:, :, :, k.shape[-1] // 2 :],
k[:, :, :, : k.shape[-1] // 2],
],
axis=-1,
),
paddle.shape(k),
)
else:
# import pdb;pdb.set_trace()
sin_pos = paddle.reshape(paddle.stack(
[sin, sin], axis=-1), [1, 1, seq, head_dim])
sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim])
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
cos_pos = paddle.reshape(paddle.stack(
[cos, cos], axis=-1), [1, 1, seq, head_dim])
cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim])
# GPT Stype奇偶位置分块旋转
rotate_half_q = paddle.reshape(
paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1),
@@ -108,15 +105,9 @@ class RopeEmbedding:
paddle.shape(k),
)
query = paddle.add(
paddle.multiply(q, cos_pos), paddle.multiply(
rotate_half_q, sin_pos)
)
query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos))
key = paddle.add(
paddle.multiply(k, cos_pos), paddle.multiply(
rotate_half_k, sin_pos)
)
key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos))
return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype)
@@ -137,30 +128,19 @@ def create_attn_mask(
for i in range(batch_size):
seq_len = seq_lens[i]
mask[i, 0, :seq_len, :seq_len] = (
paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type))
- 1
paddle.tril(paddle.ones(shape=(seq_len, seq_len), dtype=mask_type)) - 1
) * 1e4
return mask
def block_cache_to_naive_cache(
cache_k, cache_v, bsz, block_tables, cache_seq_len
):
def block_cache_to_naive_cache(cache_k, cache_v, bsz, block_tables, cache_seq_len):
_, num_head, blocksize, dim_head = cache_k.shape
out_cache_k = paddle.zeros(
shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype
)
out_cache_v = paddle.zeros(
shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype
)
out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype)
out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype)
for i in range(bsz):
for j in range(cache_seq_len):
out_cache_k[i, :, j, :] = cache_k[
block_tables[i, j // blocksize], :, j % blocksize, :
]
out_cache_v[i, :, j, :] = cache_v[
block_tables[i, j // blocksize], :, j % blocksize, :
]
out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :]
out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :]
return out_cache_k, out_cache_v
@@ -209,8 +189,7 @@ def naive_attention_impl(
if mask is not None:
attention = attention + mask
softmax_result = paddle.nn.functional.softmax(attention, -1)
result = paddle.matmul(paddle.cast(
softmax_result, dtype=value.dtype), value)
result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value)
return result
@@ -235,9 +214,7 @@ def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
def remove_padding(seq_lens, cu_seq_lens, inputs, token_num):
bsz, num_head, seq_len, dim_head = inputs.shape
output = paddle.zeros(
shape=[token_num, num_head * dim_head], dtype=inputs.dtype
)
output = paddle.zeros(shape=[token_num, num_head * dim_head], dtype=inputs.dtype)
inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1])
for i in range(bsz):
seq_len_now = seq_lens[i]
@@ -248,38 +225,34 @@ def remove_padding(seq_lens, cu_seq_lens, inputs, token_num):
def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head, place, dtype):
query = np.random.random([bs, q_num_head, seq_len, dim_head])/10
q = paddle.to_tensor(
query, place=place, dtype=dtype, stop_gradient=False
)
key = np.random.random([bs, kv_num_head, seq_len, dim_head])/10
k = paddle.to_tensor(
key, place=place, dtype=dtype, stop_gradient=False
)
value = np.random.random([bs, kv_num_head, seq_len, dim_head])/10
v = paddle.to_tensor(
value, place=place, dtype=dtype, stop_gradient=False
)
token_num = bs*seq_len
query = np.random.random([bs, q_num_head, seq_len, dim_head]) / 10
q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False)
key = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10
k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False)
value = np.random.random([bs, kv_num_head, seq_len, dim_head]) / 10
v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False)
token_num = bs * seq_len
qkv = paddle.concat(
[
q.transpose([0, 2, 1, 3]).reshape(
[token_num, q_num_head*dim_head]
),
k.transpose([0, 2, 1, 3]).reshape(
[token_num, kv_num_head*dim_head]
),
v.transpose([0, 2, 1, 3]).reshape(
[token_num, kv_num_head*dim_head]
),
q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * dim_head]),
k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * dim_head]),
],
axis=1,
).reshape([token_num, -1])
return q, k, v, qkv
def split_query_by_phase(query, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time, q_dim, k_dim, v_dim):
def split_query_by_phase(
query,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
q_dim,
k_dim,
v_dim,
):
"""
将 query 拆分为 encoder 和 decoder 的 Q/K/V。
"""
@@ -292,8 +265,8 @@ def split_query_by_phase(query, seq_lens_encoder, seq_lens_decoder, seq_lens_thi
query = paddle.reshape(query, [batch, max_seq, total_dim])
# 计算 mask表示该 batch 是否是 encoder/decoder
is_encoder = (seq_lens_encoder > 0).astype('bool').reshape([-1]) # [batch]
is_decoder = (seq_lens_decoder > 0).astype('bool').reshape([-1]) # [batch]
is_encoder = (seq_lens_encoder > 0).astype("bool").reshape([-1]) # [batch]
is_decoder = (seq_lens_decoder > 0).astype("bool").reshape([-1]) # [batch]
# 准备输出列表
enc_qs, enc_ks, enc_vs = [], [], []
@@ -330,8 +303,8 @@ def split_query_by_phase(query, seq_lens_encoder, seq_lens_decoder, seq_lens_thi
return (enc_q, enc_k, enc_v), (dec_q, dec_k, dec_v)
class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.name = "TestAppendGroupQueryAttnWithRope"
@@ -350,14 +323,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.max_seq_len = self.seq_len + self.max_dec_len
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = 'float16'
self.dtype = "float16"
self.init_tensor()
def init_tensor(self):
self.block_num_per_seq = (
self.seq_len + self.max_dec_len + self.blocksize - 1
) // self.blocksize
self.block_num_per_seq = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
self.rope = RopeEmbedding(self.use_neox_rotary_style)
self.max_block_num = self.block_num_per_seq * self.batch_size
self.free_list = list(range(self.max_block_num - 1, -1, -1))
@@ -378,10 +348,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.seq_lens_dec,
"int32",
)
self.max_enc_len_this_time = paddle.to_tensor(
[self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor(
[self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_lens_this_time = self.seq_lens_encoder
self.cache_shape = (
@@ -390,17 +358,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.blocksize,
self.dim_head,
)
self.scale = 1.0 / np.sqrt(self.dim_head)
self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.block_tables = paddle.zeros(
shape=(self.batch_size, self.block_num_per_seq), dtype="int32"
)
self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32")
for i in range(self.batch_size):
need_block_num = (
self.seq_len + self.max_dec_len + self.blocksize - 1
) // self.blocksize
need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
for j in range(need_block_num):
self.block_tables[i, j] = self.free_list.pop()
(
@@ -408,15 +372,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.cum_offset,
self.cu_seqlens_q,
self.cu_seqlens_k,
) = get_padding_offset(
self.batch_size, self.seq_len, self.seq_lens_this_time
)
) = get_padding_offset(self.batch_size, self.seq_len, self.seq_lens_this_time)
self.token_num = self.padding_offset.shape[0]
def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
paddle.disable_static()
self.token_num = self.seq_len*self.batch_size
self.token_num = self.seq_len * self.batch_size
q, k, v, qkv = get_qkv_and_qkv_concat_tensor(
self.batch_size,
self.q_num_head,
@@ -424,19 +385,27 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.seq_len,
self.dim_head,
self.place,
self.dtype
self.dtype,
)
q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
out_ = naive_attention_impl(
q, k, v, naive_cache_k, naive_cache_v, None, None, attn_mask, self.scale
)
out_ = remove_padding(
self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num
q,
k,
v,
naive_cache_k,
naive_cache_v,
None,
None,
attn_mask,
self.scale,
)
out_ = remove_padding(self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num)
speculate_max_draft_token_num = 1
from fastdeploy.model_executor.layers.attention.ops import append_attention
from fastdeploy.model_executor.layers.attention.ops import get_block_shape_and_split_kv_block
from fastdeploy.model_executor.layers.attention.ops import (
append_attention,
get_block_shape_and_split_kv_block,
)
(
encoder_batch_ids,
@@ -457,15 +426,15 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.cum_offset,
64,
12,
(self.q_num_head + 2*self.kv_num_head) // self.kv_num_head,
(self.q_num_head + 2 * self.kv_num_head) // self.kv_num_head,
self.blocksize,
speculate_max_draft_token_num+1,
speculate_max_draft_token_num + 1,
)
# Warm up
WARM_UP = 1
RUN_TIME = 2
for i in range(WARM_UP+RUN_TIME):
for i in range(WARM_UP + RUN_TIME):
if i == WARM_UP:
paddle.device.synchronize()
start_time = time.time()
@@ -515,17 +484,13 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
16, # decoder_block_shape_q
32768, # max_partition_size
32768, # encoder_max_partition_size
speculate_max_draft_token_num+1, # speculate_max_draft_token_num
speculate_max_draft_token_num + 1, # speculate_max_draft_token_num
True, # causal
False, # speculate_decoder
)[0]
paddle.device.synchronize()
end_time = time.time()
print(
"[append-attn ut] cost_time:{}ms".format(
(end_time - start_time) / RUN_TIME * 1000
)
)
print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms")
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
self.cache_k,
self.cache_v,
@@ -541,16 +506,12 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
)
def test_all(self):
tmp_position_ids = paddle.arange(
self.seq_len + self.max_dec_len
).reshape((1, -1))
tmp_position_ids = paddle.arange(self.seq_len + self.max_dec_len).reshape((1, -1))
# appendattn 传的是最大maxseq
if self.use_neox_rotary_style:
self.rope_emb = self.rope.get_neox_style_position_embedding(tmp_position_ids, self.dim_head)
else:
self.rope_emb = self.rope.get_rotary_position_embedding(
tmp_position_ids, self.dim_head
)
self.rope_emb = self.rope.get_rotary_position_embedding(tmp_position_ids, self.dim_head)
self.attention_mask = create_attn_mask(
self.dtype,
self.batch_size,
@@ -582,10 +543,8 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
] * self.batch_size
self.max_enc_len_this_time = max(self.seq_lens_enc)
self.max_dec_len_this_time = max(self.seq_lens_dec)
self.max_enc_len_this_time = paddle.to_tensor(
[self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor(
[self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.max_enc_len_this_time = paddle.to_tensor([self.max_enc_len_this_time], "int32", place=paddle.CPUPlace())
self.max_dec_len_this_time = paddle.to_tensor([self.max_dec_len_this_time], "int32", place=paddle.CPUPlace())
self.seq_len = 1
(
@@ -596,6 +555,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time)
self.cmp_append_attention(naive_cache_k, naive_cache_v, None)
class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
def setUp(self):
paddle.disable_static()
@@ -615,10 +575,9 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
self.max_seq_len = self.seq_len + self.max_dec_len
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = 'float16'
self.dtype = "float16"
self.init_tensor()
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()