Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-04 16:22:57 +08:00
[Feature] block sparse attention (#3209)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* support sparse attn
* fix bug
* code style
* fix moba attn get kv shape
* fix A100 compilation
* codestyle
* code style
* code style
* code style
* fix conflict
* add unit test
* code style
* add eblite loading time
* fix bug
* for ci
* for ci
* for ci
* for ci
* support mlp block size 128
* add unit tests for the small ops
* fix mlp unit test
* move environment variables into the config
* fix rollout config
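The diff below wires this MoBA-style block sparse attention path into the Attention layer: it loads an optional MLP gate weight from safetensors and keeps a cache of per-block key means. As a rough mental model only, here is a NumPy sketch of that kind of block selection; it is not the kernel added by this PR, and select_blocks, block_size, and top_k are illustrative names and values, not identifiers from the repository.

    import numpy as np

    def select_blocks(q, k, block_size=128, top_k=2):
        # q: [head_dim]; k: [seq_len, head_dim]; seq_len assumed divisible by block_size.
        num_blocks = k.shape[0] // block_size
        k_block_means = k.reshape(num_blocks, block_size, -1).mean(axis=1)  # [num_blocks, head_dim]
        gate_scores = k_block_means @ q                                     # one score per key block
        return np.sort(np.argsort(gate_scores)[-top_k:])                    # indices of the kept blocks

    q = np.random.randn(64).astype(np.float32)
    k = np.random.randn(1024, 64).astype(np.float32)
    print(select_blocks(q, k))  # e.g. [2 7]: this query attends inside only 2 of the 8 key blocks

Attention cost then scales with top_k * block_size rather than with the full sequence length, which is the point of the sparse path.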
@@ -28,6 +28,11 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod

if TYPE_CHECKING:
    from fastdeploy.model_executor.forward_meta import ForwardMeta

import os

from safetensors import safe_open

from fastdeploy.model_executor.layers.utils import get_tensor

@@ -113,6 +118,42 @@ class Attention(nn.Layer):
            self.k_norm_key = f"{self.prefix}.k_norm"
        self.init_weight()

        if fd_config.moba_attention_config is not None:
            mlp_weight_path = os.path.join(
                fd_config.model_config.model, fd_config.moba_attention_config.mlp_weight_name
            )
            self.moba_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path)
            moba_block_size = fd_config.moba_attention_config.moba_block_size
            moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length
            if self.moba_use_mlp:
                mlp_weight = {}
                with safe_open(mlp_weight_path, framework="np", device="cpu") as f:
                    for key_name in f.keys():
                        weight = f.get_tensor(key_name)
                        weight = paddle.Tensor(weight, zero_copy=True)
                        weight = weight._copy_to(paddle.framework._current_expected_place(), False)
                        mlp_weight[key_name] = weight

                if self.layer_id < fd_config.model_config.num_hidden_layers - 1:
                    self.attn_gate_weight = mlp_weight[
                        f"ernie.layers.{self.layer_id}.self_attn.attn_gate.weight"
                    ].astype(paddle.get_default_dtype())[
                        fd_config.parallel_config.tensor_parallel_rank
                        * self.kv_num_heads : (fd_config.parallel_config.tensor_parallel_rank + 1)
                        * self.kv_num_heads
                    ]
                    assert self.attn_gate_weight.shape[1] % moba_block_size == 0

            self.cache_k_block_means = paddle.zeros(
                [
                    fd_config.parallel_config.max_num_seqs,
                    moba_max_seq_length // moba_block_size,
                    self.kv_num_heads,
                    self.head_dim,
                ],
                dtype=paddle.get_default_dtype(),
            )

    def init_weight(self):
        self.q_norm_weight = self.create_parameter(
            shape=[self.qk_head_dim],
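To make the shapes set up above concrete, here is a small worked example with assumed placeholder config values; the numbers are hypothetical, and only the arithmetic mirrors the diff. cache_k_block_means stores one mean key vector per (sequence, key block, kv head), and attn_gate_weight is sliced to the kv heads owned by the local tensor-parallel rank.

    # Hypothetical values for illustration only; not taken from the repository.
    moba_block_size = 128        # moba_attention_config.moba_block_size
    moba_max_seq_length = 8192   # moba_attention_config.moba_max_seq_length
    max_num_seqs = 32            # parallel_config.max_num_seqs
    kv_num_heads = 8             # kv heads per tensor-parallel rank
    head_dim = 128
    tensor_parallel_rank = 1

    # cache_k_block_means: one mean key vector per (sequence, key block, kv head).
    cache_shape = [max_num_seqs, moba_max_seq_length // moba_block_size, kv_num_heads, head_dim]
    print(cache_shape)  # [32, 64, 8, 128]

    # attn_gate_weight rows are grouped by kv head, so rank r keeps rows
    # [r * kv_num_heads, (r + 1) * kv_num_heads), matching the slice in the diff.
    rows = slice(tensor_parallel_rank * kv_num_heads, (tensor_parallel_rank + 1) * kv_num_heads)
    print(rows.start, rows.stop)  # 8 16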