FastDeploy/tests/layers/test_plas_attention.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest

import numpy as np
import paddle

from fastdeploy import LLM, SamplingParams

try:
    from fastdeploy.model_executor.ops.gpu import (
        fused_block_mean_and_rope,
        get_cur_cu_seq_len_k,
        moba_encoder_attn,
        moba_mlp_einsum,
        moba_qk_gemm,
        moba_qk_sort_encoder,
    )
except Exception:
    # The custom GPU ops are unavailable (e.g. a CPU-only build); the tests return early in that case.
    fused_block_mean_and_rope = None
    get_cur_cu_seq_len_k = None
    moba_encoder_attn = None
    moba_mlp_einsum = None
    moba_qk_gemm = None
    moba_qk_sort_encoder = None


def naive_attn(q_input, k_input, v_input, mask):
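    """Dense reference attention: softmax(QK^T / sqrt(d) + mask) @ V per head, with GQA head mapping."""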
gqa_group_size = q_input.shape[2] // k_input.shape[2]
q_cur = q_input.transpose([0, 2, 1, 3])
k_cur = k_input.transpose([0, 2, 1, 3])
v_cur = v_input.transpose([0, 2, 1, 3])
out = paddle.zeros(q_cur.shape, dtype=q_input.dtype)
for bsz in range(0, q_cur.shape[0]):
for hi in range(0, q_cur.shape[1]):
qk = paddle.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3]))
qk += mask
qk_max = qk.max(axis=-1).unsqueeze(-1)
qk -= qk_max
qk = qk.exp()
exp_sum = qk.sum(axis=-1).unsqueeze(-1)
exp_sum_inv = 1.0 / exp_sum
out[bsz, hi] = (paddle.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
    return out


class TestPlasAttention(unittest.TestCase):
def setUp(self):
paddle.seed(0)
self.seq_len = int(8 * 1024)
self.num_heads = int(8)
self.num_kv_heads = int(1)
self.head_dim = int(128)
self.max_num_seqs = 1
self.plas_max_seq_length = int(128 * 1024)
self.plas_block_size = int(128)
self.plas_encoder_top_k_left = 2
self.plas_encoder_top_k_right = 3
self.plas_use_encoder_seq_limit = int(4 * 1024)
self.cache_k_block_means = paddle.zeros(
[
self.max_num_seqs,
self.plas_max_seq_length // self.plas_block_size,
self.num_kv_heads,
self.head_dim,
],
dtype="bfloat16",
)
self.attn_block_m = 128
self.tokens = self.seq_len * self.max_num_seqs
self.q_input = paddle.zeros(
[self.tokens + self.attn_block_m, self.num_heads, self.head_dim],
dtype="bfloat16",
)
self.k_input = paddle.zeros(
[self.tokens + self.attn_block_m, self.num_kv_heads, self.head_dim],
dtype="bfloat16",
)
self.v_input = paddle.zeros(
[self.tokens + self.attn_block_m, self.num_kv_heads, self.head_dim],
dtype="bfloat16",
)
self.rotary_embs = paddle.ones([2, self.seq_len, self.head_dim // 2], dtype="float32")
self.attn_gate_weight = paddle.randn(
[self.num_kv_heads, self.plas_block_size, self.head_dim], dtype="bfloat16"
)
self.gqa_group_size = self.num_heads // self.num_kv_heads
self.num_blocks = (self.seq_len + self.plas_block_size - 1) // self.plas_block_size
        self.sparse_step = 4

    def compare_split_qkv_rope(self, qkv_out):
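        """Check the fused kernel's QKV split and the per-block K means written to cache_k_block_means."""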
assert (qkv_out[:, 0 : self.num_heads, :] - self.q_input[0 : self.tokens]).abs().max() < 1e-3
assert (
qkv_out[:, self.num_heads : self.num_heads + self.num_kv_heads, :] - self.k_input[0 : self.tokens]
).abs().max() < 1e-3
assert (qkv_out[:, self.num_heads + self.num_kv_heads :, :] - self.v_input[0 : self.tokens]).abs().max() < 1e-3
for i in range(self.max_num_seqs):
k_padding = paddle.zeros(
[
(self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size,
self.num_kv_heads,
self.head_dim,
],
dtype="bfloat16",
)
k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
real_k_block_means = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim])
real_k_block_means = real_k_block_means.mean(axis=1)
compute_k_block_means = self.cache_k_block_means[i, 0 : real_k_block_means.shape[0]]
assert (compute_k_block_means - real_k_block_means).abs().max() < 0.003
print("[consistency]plas attention: split_qkv_rope matches.")
def compare_mlp_einsum(self, k_gate_weight):
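        """Check the gate projection against paddle.einsum("nbhd,hbd->nhd", ...) over zero-padded K blocks."""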
for i in range(self.max_num_seqs):
k_padding = paddle.zeros(
[
(self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size,
self.num_kv_heads,
self.head_dim,
],
dtype="bfloat16",
)
k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
k_padding = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim])
real_result = paddle.einsum("nbhd,hbd->nhd", k_padding, self.attn_gate_weight)
compute_result = k_gate_weight[i][0 : real_result.shape[0]]
assert (real_result - compute_result).abs().max() < 0.5
print("[consistency]plas attention: MLP einsum matches.")
def compare_qk_gemm(self, qk_gate_weight):
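        """Check the block gate scores against per-head matmuls of Q with the per-block K means."""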
for i in range(self.max_num_seqs):
q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len]
k_input_mean = self.cache_k_block_means[i][0 : self.num_blocks]
qk_gemm_out = paddle.zeros(
[
self.seq_len,
self.num_heads,
self.num_blocks,
],
dtype="bfloat16",
)
for j in range(self.num_heads):
qk_gemm_out[:, j, :] = paddle.matmul(
q_input[:, j, :], k_input_mean[:, j // self.gqa_group_size, :], transpose_y=True
)
            compute_result = qk_gate_weight[i * self.seq_len : (i + 1) * self.seq_len, :, 0 : self.num_blocks]
            assert (qk_gemm_out - compute_result).abs().max() < 1e-4
print("[consistency]plas attention: qk_gemm matches.")
def compare_qk_gate_topk(self, qk_gate_topk_idx):
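        """Check the selected block strides: 1 (dense) below the encoder seq limit, sparse_step above it."""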
limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size
for i in range(self.max_num_seqs):
qk_gate_topk_idx_batch = qk_gate_topk_idx[i * self.num_blocks : (i + 1) * self.num_blocks]
qk_gate_topk_idx_batch_no_sparse = qk_gate_topk_idx_batch[0 : limit_topk - 1]
assert (
qk_gate_topk_idx_batch_no_sparse
- paddle.ones(qk_gate_topk_idx_batch_no_sparse.shape, qk_gate_topk_idx_batch_no_sparse.dtype)
).abs().max() < 1e-6
for j in range(limit_topk, self.num_blocks):
qk_gate_topk_idx_batch_sparse = qk_gate_topk_idx_batch[j, :, 1 : (j + 1) // self.sparse_step]
assert (
qk_gate_topk_idx_batch_sparse
- paddle.ones(qk_gate_topk_idx_batch_sparse.shape, qk_gate_topk_idx_batch_sparse.dtype)
* self.sparse_step
).abs().max() < 1e-6
print("[consistency]plas attention: qk_gate_topk matches.")
def compare_attn(self, attn_out, qk_gate_topk_idx):
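        """Check the sparse attention output against naive_attn with the block-sparse mask from qk_gate_topk_idx."""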
x = (
paddle.tensor.triu(paddle.ones([self.plas_block_size, self.plas_block_size], dtype="bfloat16"), 1)
* -1000000
)
limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size
for i in range(self.max_num_seqs):
q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
k_input = self.k_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
v_input = self.v_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
mask = paddle.tensor.triu(paddle.ones([self.seq_len, self.seq_len], dtype="bfloat16"), 1) * -1000000
mask[self.plas_use_encoder_seq_limit - self.plas_block_size :] = -1000000
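            # Rebuild the block-sparse mask implied by qk_gate_topk_idx: the diagonal block stays causal,
            # and earlier blocks are unmasked by walking backwards with the per-row block strides.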
            for q_blk in range(limit_topk - 1, self.num_blocks):
                n_block = q_blk
                mask[
                    q_blk * self.plas_block_size : q_blk * self.plas_block_size + self.plas_block_size,
                    n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size,
                ] = x
                idx = 0
                n_block -= int(qk_gate_topk_idx[q_blk, 0, idx])
                idx += 1
                while n_block >= 0:
                    mask[
                        q_blk * self.plas_block_size : q_blk * self.plas_block_size + self.plas_block_size,
                        n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size,
                    ] = 0
                    n_block -= int(qk_gate_topk_idx[q_blk, 0, idx])
                    idx += 1
naive_attn_out = naive_attn(q_input, k_input, v_input, mask).squeeze(axis=0).transpose([1, 0, 2])
            assert (attn_out - naive_attn_out).abs().max() < 0.016

    def test_plas_attention(self):
        if get_cur_cu_seq_len_k is None:
            return
qkv_out = paddle.randn([self.tokens, self.num_heads + 2 * self.num_kv_heads, self.head_dim], dtype="bfloat16")
seq_len_encoder = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32")
seq_len_decoder = paddle.to_tensor([0] * self.max_num_seqs, dtype="int32")
cu_seq_q = paddle.arange(self.max_num_seqs + 1).astype("int32") * self.seq_len
cu_seq_k = paddle.arange(self.max_num_seqs + 1).astype("int32") * self.seq_len
seq_lens_this_time = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32")
cu_seq_q_pack, cu_seqlens_k, q_pack_tokens = get_cur_cu_seq_len_k(
seq_len_encoder,
seq_len_decoder,
seq_lens_this_time,
int(self.attn_block_m),
)
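        # Split the fused QKV, apply RoPE, and write per-block K means into cache_k_block_means.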
fused_block_mean_and_rope(
qkv_out,
self.cache_k_block_means,
self.q_input,
self.k_input,
self.v_input,
self.rotary_embs,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
None,
self.num_heads,
self.num_kv_heads,
self.head_dim,
self.plas_max_seq_length,
self.seq_len,
self.seq_len,
"none",
)
self.compare_split_qkv_rope(qkv_out)
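        # Gate MLP: einsum of each padded K block with attn_gate_weight (checked by compare_mlp_einsum).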
k_gate_weight = moba_mlp_einsum(
self.k_input,
self.attn_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_k,
self.seq_len,
self.num_kv_heads,
)
self.compare_mlp_einsum(k_gate_weight)
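        # Gate scores: Q against the per-block K means.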
qk_gate_weight = moba_qk_gemm(
self.q_input,
self.cache_k_block_means,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
self.seq_len,
self.seq_len,
self.num_heads,
self.num_kv_heads,
False,
self.max_num_seqs,
)
self.compare_qk_gemm(qk_gate_weight)
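        # Force every sparse_step-th block to a large score so the expected top-k pattern is deterministic.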
for i in range(0, self.num_blocks, self.sparse_step):
qk_gate_weight[:, :, i] = 100
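        # Select KV blocks per query row; compare_attn consumes the result as backward block strides.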
qk_gate_topk_idx = moba_qk_sort_encoder(
qk_gate_weight,
seq_len_encoder,
seq_len_decoder,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
q_pack_tokens,
self.seq_len,
self.seq_len,
self.num_heads,
self.num_kv_heads,
self.plas_encoder_top_k_left,
self.plas_encoder_top_k_right,
self.plas_use_encoder_seq_limit,
)
self.compare_qk_gate_topk(qk_gate_topk_idx)
attn_out = paddle.zeros([self.tokens, self.num_heads, self.head_dim], dtype="bfloat16")
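        # Block-sparse encoder attention, compared against the dense reference under the equivalent mask.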
moba_encoder_attn(
self.q_input,
self.k_input,
self.v_input,
qk_gate_topk_idx,
cu_seq_q,
cu_seq_k,
cu_seq_q_pack,
seq_len_encoder,
seq_len_decoder,
attn_out,
self.seq_len,
self.seq_len,
self.num_heads,
self.num_kv_heads,
self.head_dim,
self.plas_max_seq_length,
)
        self.compare_attn(attn_out, qk_gate_topk_idx)

    def test_server(self):
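        """Smoke test of the PLAS_ATTN backend through the offline LLM API (no-op when the custom ops are missing)."""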
if get_cur_cu_seq_len_k is None:
return
os.environ["FD_ATTENTION_BACKEND"] = "PLAS_ATTN"
base_path = os.getenv("MODEL_PATH")
if base_path:
model_path = os.path.join(base_path, "./ernie-4_5-21b-a3b-bf16-paddle")
else:
model_path = "./ernie-4_5-21b-a3b-bf16-paddle"
plas_attention_config = {
"plas_encoder_top_k_left": 50,
"plas_encoder_top_k_right": 60,
"plas_decoder_top_k_left": 100,
"plas_decoder_top_k_right": 120,
}
        # Load the model
llm = LLM(
model=model_path,
tensor_parallel_size=2,
max_model_len=131072,
engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT")),
cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT")),
max_num_seqs=32,
quantization="wint4",
enable_chunked_prefill=True,
max_num_batched_tokens=8192,
plas_attention_config=plas_attention_config,
)
prompts = ["Hello world!"]
sampling_params = SamplingParams(temperature=1.0, top_p=0.0, max_tokens=32)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
for output in outputs:
            print(output.outputs.text)


if __name__ == "__main__":
if paddle.is_compiled_with_cuda():
unittest.main()