# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

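"""Consistency and smoke tests for the PLAS sparse attention kernels (renamed from MoBA).

The unit test drives the individual GPU ops (block-mean + RoPE fusion, gate MLP einsum,
QK gemm against block means, top-k block selection, and the sparse encoder attention)
and checks each stage against a plain Paddle reference. A second test launches a full
LLM instance with the PLAS_ATTN backend as an end-to-end smoke test.
"""
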
import os
import unittest

import numpy as np
import paddle

from fastdeploy import LLM, SamplingParams

try:
    from fastdeploy.model_executor.ops.gpu import (
        fused_block_mean_and_rope,
        get_cur_cu_seq_len_k,
        moba_encoder_attn,
        moba_mlp_einsum,
        moba_qk_gemm,
        moba_qk_sort_encoder,
    )
except ImportError:
    # The custom ops are only present in CUDA builds of FastDeploy; fall back to
    # None so the tests can detect the missing kernels and bail out early.
    fused_block_mean_and_rope = None
    get_cur_cu_seq_len_k = None
    moba_encoder_attn = None
    moba_mlp_einsum = None
    moba_qk_gemm = None
    moba_qk_sort_encoder = None


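# Dense reference attention used as ground truth: for every (batch, head) pair it
# computes softmax(Q @ K^T / sqrt(head_dim) + mask) @ V, sharing each KV head across
# gqa_group_size query heads (grouped-query attention). The sparse kernel output is
# expected to match this once the selected blocks are encoded into `mask`.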
def naive_attn(q_input, k_input, v_input, mask):
    gqa_group_size = q_input.shape[2] // k_input.shape[2]

    q_cur = q_input.transpose([0, 2, 1, 3])
    k_cur = k_input.transpose([0, 2, 1, 3])
    v_cur = v_input.transpose([0, 2, 1, 3])
    out = paddle.zeros(q_cur.shape, dtype=q_input.dtype)

    for bsz in range(0, q_cur.shape[0]):
        for hi in range(0, q_cur.shape[1]):
            qk = paddle.matmul(q_cur[bsz, hi], k_cur[bsz, hi // gqa_group_size].T) * (1.0 / np.sqrt(q_cur.shape[3]))
            qk += mask
            qk_max = qk.max(axis=-1).unsqueeze(-1)
            qk -= qk_max
            qk = qk.exp()

            exp_sum = qk.sum(axis=-1).unsqueeze(-1)
            exp_sum_inv = 1.0 / exp_sum

            out[bsz, hi] = (paddle.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
    return out


class TestPlasAttention(unittest.TestCase):
    def setUp(self):
        paddle.seed(0)
        self.seq_len = int(8 * 1024)
        self.num_heads = int(8)
        self.num_kv_heads = int(1)
        self.head_dim = int(128)
        self.max_num_seqs = 1
        self.plas_max_seq_length = int(128 * 1024)
        # Granularity (in tokens) of the KV blocks used for mean pooling and block selection.
        self.plas_block_size = int(128)
        self.plas_encoder_top_k_left = 2
        self.plas_encoder_top_k_right = 3
        # Up to this length the encoder path stays dense; the test only expects sparsity beyond it.
        self.plas_use_encoder_seq_limit = int(4 * 1024)
        # Per-block mean of K, filled by fused_block_mean_and_rope.
        self.cache_k_block_means = paddle.zeros(
            [
                self.max_num_seqs,
                self.plas_max_seq_length // self.plas_block_size,
                self.num_kv_heads,
                self.head_dim,
            ],
            dtype="bfloat16",
        )
        self.attn_block_m = 128
        self.tokens = self.seq_len * self.max_num_seqs
        self.q_input = paddle.zeros(
            [self.tokens + self.attn_block_m, self.num_heads, self.head_dim],
            dtype="bfloat16",
        )
        self.k_input = paddle.zeros(
            [self.tokens + self.attn_block_m, self.num_kv_heads, self.head_dim],
            dtype="bfloat16",
        )
        self.v_input = paddle.zeros(
            [self.tokens + self.attn_block_m, self.num_kv_heads, self.head_dim],
            dtype="bfloat16",
        )
        self.rotary_embs = paddle.ones([2, self.seq_len, self.head_dim // 2], dtype="float32")

        self.attn_gate_weight = paddle.randn(
            [self.num_kv_heads, self.plas_block_size, self.head_dim], dtype="bfloat16"
        )

        self.gqa_group_size = self.num_heads // self.num_kv_heads

        self.num_blocks = (self.seq_len + self.plas_block_size - 1) // self.plas_block_size

        self.sparse_step = 4

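    # Checks fused_block_mean_and_rope: the Q/K/V buffers written by the fused op should
    # match the corresponding slices of the packed qkv_out tensor, and the per-block
    # (plas_block_size tokens) means of K should land in cache_k_block_means.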
    def compare_split_qkv_rope(self, qkv_out):
        assert (qkv_out[:, 0 : self.num_heads, :] - self.q_input[0 : self.tokens]).abs().max() < 1e-3
        assert (
            qkv_out[:, self.num_heads : self.num_heads + self.num_kv_heads, :] - self.k_input[0 : self.tokens]
        ).abs().max() < 1e-3
        assert (qkv_out[:, self.num_heads + self.num_kv_heads :, :] - self.v_input[0 : self.tokens]).abs().max() < 1e-3

        for i in range(self.max_num_seqs):
            k_padding = paddle.zeros(
                [
                    (self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size,
                    self.num_kv_heads,
                    self.head_dim,
                ],
                dtype="bfloat16",
            )
            k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
            real_k_block_means = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim])
            real_k_block_means = real_k_block_means.mean(axis=1)
            compute_k_block_means = self.cache_k_block_means[i, 0 : real_k_block_means.shape[0]]
            assert (compute_k_block_means - real_k_block_means).abs().max() < 0.003

        print("[consistency]plas attention: split_qkv_rope matches.")

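    # Checks moba_mlp_einsum against paddle.einsum: per KV head, each plas_block_size
    # block of K is contracted with the learned gate weight ("nbhd,hbd->nhd"),
    # zero-padding the tail block.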
    def compare_mlp_einsum(self, k_gate_weight):
        for i in range(self.max_num_seqs):
            k_padding = paddle.zeros(
                [
                    (self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size,
                    self.num_kv_heads,
                    self.head_dim,
                ],
                dtype="bfloat16",
            )
            k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
            k_padding = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim])
            real_result = paddle.einsum("nbhd,hbd->nhd", k_padding, self.attn_gate_weight)
            compute_result = k_gate_weight[i][0 : real_result.shape[0]]

            assert (real_result - compute_result).abs().max() < 0.5

        print("[consistency]plas attention: MLP einsum matches.")

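    # Checks moba_qk_gemm: every query token is scored against the block-mean K
    # centroids, one matmul per query head, with KV heads shared across the GQA group.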
    def compare_qk_gemm(self, qk_gate_weight):
        for i in range(self.max_num_seqs):
            q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len]
            k_input_mean = self.cache_k_block_means[i][0 : self.num_blocks]

            qk_gemm_out = paddle.zeros(
                [
                    self.seq_len,
                    self.num_heads,
                    self.num_blocks,
                ],
                dtype="bfloat16",
            )

            for j in range(self.num_heads):
                qk_gemm_out[:, j, :] = paddle.matmul(
                    q_input[:, j, :], k_input_mean[:, j // self.gqa_group_size, :], transpose_y=True
                )

            compute_result = qk_gate_weight[i * self.seq_len : (i + 1) * self.seq_len, :, 0 : self.num_blocks]
            assert (qk_gemm_out - compute_result).abs().max() < 1e-4

        print("[consistency]plas attention: qk_gemm matches.")

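    # Checks moba_qk_sort_encoder. As exercised here, qk_gate_topk_idx stores, per query
    # block, the backward strides between selected KV blocks: inside the dense prefix
    # every stride is 1, and past the limit the strides equal sparse_step because the
    # test boosts every sparse_step-th block score to a large value beforehand.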
    def compare_qk_gate_topk(self, qk_gate_topk_idx):
        limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size
        for i in range(self.max_num_seqs):
            qk_gate_topk_idx_batch = qk_gate_topk_idx[i * self.num_blocks : (i + 1) * self.num_blocks]
            qk_gate_topk_idx_batch_no_sparse = qk_gate_topk_idx_batch[0 : limit_topk - 1]

            assert (
                qk_gate_topk_idx_batch_no_sparse
                - paddle.ones(qk_gate_topk_idx_batch_no_sparse.shape, qk_gate_topk_idx_batch_no_sparse.dtype)
            ).abs().max() < 1e-6

            for j in range(limit_topk, self.num_blocks):
                qk_gate_topk_idx_batch_sparse = qk_gate_topk_idx_batch[j, :, 1 : (j + 1) // self.sparse_step]

                assert (
                    qk_gate_topk_idx_batch_sparse
                    - paddle.ones(qk_gate_topk_idx_batch_sparse.shape, qk_gate_topk_idx_batch_sparse.dtype)
                    * self.sparse_step
                ).abs().max() < 1e-6
        print("[consistency]plas attention: qk_gate_topk matches.")

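    # Rebuilds the dense attention mask implied by qk_gate_topk_idx (causal inside the
    # current query block, selected KV blocks unmasked, everything else masked out past
    # the dense prefix) and compares moba_encoder_attn against naive_attn under that
    # mask. The indexing assumes a single sequence (max_num_seqs == 1), as set in setUp.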
    def compare_attn(self, attn_out, qk_gate_topk_idx):
        x = (
            paddle.tensor.triu(paddle.ones([self.plas_block_size, self.plas_block_size], dtype="bfloat16"), 1)
            * -1000000
        )
        limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size
        for bi in range(self.max_num_seqs):
            q_input = self.q_input[bi * self.seq_len : (bi + 1) * self.seq_len].unsqueeze(axis=0)
            k_input = self.k_input[bi * self.seq_len : (bi + 1) * self.seq_len].unsqueeze(axis=0)
            v_input = self.v_input[bi * self.seq_len : (bi + 1) * self.seq_len].unsqueeze(axis=0)
            mask = paddle.tensor.triu(paddle.ones([self.seq_len, self.seq_len], dtype="bfloat16"), 1) * -1000000
            mask[self.plas_use_encoder_seq_limit - self.plas_block_size :] = -1000000
            for i in range(limit_topk - 1, self.num_blocks):
                n_block = i
                # Causal mask within the current query block itself.
                mask[
                    i * self.plas_block_size : i * self.plas_block_size + self.plas_block_size,
                    n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size,
                ] = x
                idx = 0
                n_block -= int(qk_gate_topk_idx[i, 0, idx])
                idx += 1
                # Walk backwards through the selected KV blocks and unmask them.
                while n_block >= 0:
                    mask[
                        i * self.plas_block_size : i * self.plas_block_size + self.plas_block_size,
                        n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size,
                    ] = 0
                    n_block -= int(qk_gate_topk_idx[i, 0, idx])
                    idx += 1
            naive_attn_out = naive_attn(q_input, k_input, v_input, mask).squeeze(axis=0).transpose([1, 0, 2])
            assert (attn_out - naive_attn_out).abs().max() < 0.016

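    # Runs the encoder-side op pipeline end to end: get_cur_cu_seq_len_k ->
    # fused_block_mean_and_rope -> moba_mlp_einsum -> moba_qk_gemm ->
    # moba_qk_sort_encoder -> moba_encoder_attn, validating each stage against the
    # Python references above.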
    def test_plas_attention(self):
        if get_cur_cu_seq_len_k is None:
            # The custom GPU ops are not available in this build; nothing to test.
            return

        qkv_out = paddle.randn([self.tokens, self.num_heads + 2 * self.num_kv_heads, self.head_dim], dtype="bfloat16")

        seq_len_encoder = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32")
        seq_len_decoder = paddle.to_tensor([0] * self.max_num_seqs, dtype="int32")
        cu_seq_q = paddle.arange(self.max_num_seqs + 1).astype("int32") * self.seq_len
        cu_seq_k = paddle.arange(self.max_num_seqs + 1).astype("int32") * self.seq_len
        seq_lens_this_time = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32")

        cu_seq_q_pack, cu_seqlens_k, q_pack_tokens = get_cur_cu_seq_len_k(
            seq_len_encoder,
            seq_len_decoder,
            seq_lens_this_time,
            int(self.attn_block_m),
        )

        fused_block_mean_and_rope(
            qkv_out,
            self.cache_k_block_means,
            self.q_input,
            self.k_input,
            self.v_input,
            self.rotary_embs,
            seq_len_encoder,
            seq_len_decoder,
            cu_seq_q,
            cu_seq_k,
            None,
            self.num_heads,
            self.num_kv_heads,
            self.head_dim,
            self.plas_max_seq_length,
            self.seq_len,
            self.seq_len,
            "none",
        )

        self.compare_split_qkv_rope(qkv_out)

        k_gate_weight = moba_mlp_einsum(
            self.k_input,
            self.attn_gate_weight,
            seq_len_encoder,
            seq_len_decoder,
            cu_seq_k,
            self.seq_len,
            self.num_kv_heads,
        )

        self.compare_mlp_einsum(k_gate_weight)

        qk_gate_weight = moba_qk_gemm(
            self.q_input,
            self.cache_k_block_means,
            seq_len_encoder,
            seq_len_decoder,
            cu_seq_q,
            cu_seq_k,
            self.seq_len,
            self.seq_len,
            self.num_heads,
            self.num_kv_heads,
            False,
            self.max_num_seqs,
        )

        self.compare_qk_gemm(qk_gate_weight)

        # Boost every sparse_step-th block score so the top-k selection becomes predictable.
        for i in range(0, self.num_blocks, self.sparse_step):
            qk_gate_weight[:, :, i] = 100

        qk_gate_topk_idx = moba_qk_sort_encoder(
            qk_gate_weight,
            seq_len_encoder,
            seq_len_decoder,
            cu_seq_q,
            cu_seq_k,
            cu_seq_q_pack,
            q_pack_tokens,
            self.seq_len,
            self.seq_len,
            self.num_heads,
            self.num_kv_heads,
            self.plas_encoder_top_k_left,
            self.plas_encoder_top_k_right,
            self.plas_use_encoder_seq_limit,
        )

        self.compare_qk_gate_topk(qk_gate_topk_idx)

        attn_out = paddle.zeros([self.tokens, self.num_heads, self.head_dim], dtype="bfloat16")

        moba_encoder_attn(
            self.q_input,
            self.k_input,
            self.v_input,
            qk_gate_topk_idx,
            cu_seq_q,
            cu_seq_k,
            cu_seq_q_pack,
            seq_len_encoder,
            seq_len_decoder,
            attn_out,
            self.seq_len,
            self.seq_len,
            self.num_heads,
            self.num_kv_heads,
            self.head_dim,
            self.plas_max_seq_length,
        )

        self.compare_attn(attn_out, qk_gate_topk_idx)

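    # End-to-end smoke test: serves an ERNIE checkpoint with the PLAS_ATTN backend and
    # a plas_attention_config. Needs two GPUs (tensor_parallel_size=2); the model path
    # comes from MODEL_PATH when set, and the queue ports from FD_ENGINE_QUEUE_PORT and
    # FD_CACHE_QUEUE_PORT.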
    def test_server(self):
        if get_cur_cu_seq_len_k is None:
            return
        os.environ["FD_ATTENTION_BACKEND"] = "PLAS_ATTN"
        base_path = os.getenv("MODEL_PATH")
        if base_path:
            model_path = os.path.join(base_path, "./ernie-4_5-21b-a3b-bf16-paddle")
        else:
            model_path = "./ernie-4_5-21b-a3b-bf16-paddle"

        plas_attention_config = {
            "plas_encoder_top_k_left": 50,
            "plas_encoder_top_k_right": 60,
            "plas_decoder_top_k_left": 100,
            "plas_decoder_top_k_right": 120,
        }

        # Load the model.
        llm = LLM(
            model=model_path,
            tensor_parallel_size=2,
            max_model_len=131072,
            engine_worker_queue_port=int(os.getenv("FD_ENGINE_QUEUE_PORT")),
            cache_queue_port=int(os.getenv("FD_CACHE_QUEUE_PORT")),
            max_num_seqs=32,
            quantization="wint4",
            enable_chunked_prefill=True,
            max_num_batched_tokens=8192,
            plas_attention_config=plas_attention_config,
        )

        prompts = ["Hello world!"]
        sampling_params = SamplingParams(temperature=1.0, top_p=0.0, max_tokens=32)
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        for output in outputs:
            print(output.outputs.text)


if __name__ == "__main__":
    if paddle.is_compiled_with_cuda():
        unittest.main()