diff --git a/docs/features/images/plas_inference_union.png b/docs/features/images/plas_inference_union.png new file mode 100644 index 000000000..4edc174c1 Binary files /dev/null and b/docs/features/images/plas_inference_union.png differ diff --git a/docs/features/images/plas_training_distill.png b/docs/features/images/plas_training_distill.png new file mode 100644 index 000000000..7b1996bfe Binary files /dev/null and b/docs/features/images/plas_training_distill.png differ diff --git a/docs/features/moba_sparse_attention.md b/docs/features/moba_sparse_attention.md deleted file mode 100644 index 8004bf4ac..000000000 --- a/docs/features/moba_sparse_attention.md +++ /dev/null @@ -1,31 +0,0 @@ -# moba_sparse_attention - -## Introduction - -We propose Lite MoBA and improve it based on MoBA. Specifically, we still draw on the MoE structure to divide KV into multiple blocks, introduce a learnable MLP layer to adaptively select important blocks. We use Full Attention's 1D Max Pooling Attention Map as Ground Truth. Then, we employ KLDivLoss to distill and train the MLP layer weights. Lite MoBA can be directly applied to post - training, where only the weights of the MLP are learnable and the weights of the original model remain unchanged. - -Compared to NSA or MoBA, our Lite MoBA is more scalable and pluggable, without the need to change traditional attention architectures or interfere with model weight training in the Pre - training and Post - training stages. It only requires a small amount of training on the MLP layer in the final stage of the model to achieve almost lossless accuracy. Since MoBA updates the weights of the entire model, even when Full Attention is automatically invoked for inputs shorter than BlockSize x BlockNum, it still cannot avoid the impact of model updates on the model's effectiveness in text processing. In contrast, our pluggable Lite MoBA can achieve Full Attention that is truly equivalent to that of the original model in short text scenarios. 
- -Compared with MoBA, in terms of effectiveness, its use of Average Pooling to represent inter - block relationships appears relatively limited and has poor handling of outlier representations. Our ablation experiments also demonstrated that the effectiveness of Average Pooling is inferior to that of the learnable MLP. In terms of training performance, since only the MLP weights need to be updated and the model weights do not need to be updated, a large amount of video memory will be saved during training (which needs to be tested). In terms of inference performance, when the input length is 128K, Block Size = 1024, and Block Num = 16, the performance is improved by 322% compared to Flash Attention 3. - -## Usage - -```bash -export FD_ATTENTION_BACKEND="MOBA_ATTN" - -python -m fastdeploy.entrypoints.openai.api_server - --model baidu/ERNIE-4.5-300B-A47B-Paddle \ - --port 8188 \ - --tensor-parallel-size 4 \ - --quantization wint4 \ - --enable-chunked-prefill \ - --max-num-batched-tokens 8192 \ - --max-model-len 131072 \ - --max-num-seqs 32 \ - --moba-attention-config '{"moba_encoder_top_k_left": 60, "moba_encoder_top_k_right": 80, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}' -``` -## Environmental Variables Description - -* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention. -* `moba_encoder_top_k_left=60, moba_encoder_top_k_right=80` indicates that the range of top - k is between 80 and 100 when the encoder is sparse. -* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=100` indicates that the range of top - k is between 120 and 140 when the decoder is sparse. diff --git a/docs/features/plas_attention.md b/docs/features/plas_attention.md new file mode 100644 index 000000000..8384de3b5 --- /dev/null +++ b/docs/features/plas_attention.md @@ -0,0 +1,219 @@ +# PLAS + +## Introduction + +We propose **PLAS (Pluggable Lightweight Attention for Sparsity)**, an improvement over MoBA. 
Specifically, we adopt an MoE-inspired structure that partitions the KV cache into multiple blocks and introduces a learnable MLP layer to adaptively select important blocks. PLAS can be applied directly during post-training, where only the MLP weights are learnable and the original model weights remain unchanged.

Compared to NSA/MoBA, PLAS offers greater scalability and pluggability. It requires no changes to the traditional attention architecture and does not interfere with model weight training during pre-training or post-training; only a small amount of MLP training at the final stage is needed to achieve nearly lossless accuracy. Because NSA/MoBA update all model weights, they inevitably affect performance on short texts, even though they automatically switch to full attention when the input length is shorter than BlockSize × Top-K. In contrast, PLAS achieves full attention that is truly equivalent to the original model in short-text scenarios.

In terms of training efficiency, the training cost is very low because only the MLP weights need to be updated. In terms of inference performance, with an input length of 128K, Block Size = 128, and Top-K = 55, PLAS achieves a **386% speedup** over Flash Attention 3.

## Method

### Training

Following NSA and MoBA, we partition the KV cache into multiple blocks. During both the prefill and decode stages, instead of computing attention over all KV, we dynamically select the top-K blocks with the highest estimated attention scores for each query token, enabling efficient sparse attention computation.

<div align="center">
+Attention Gate Module +
</div>

* **Attention Gate Module**: As illustrated in the figure above, we design a lightweight attention gate module to estimate the importance of each block at low computational cost. The module first compresses each K block through an MLP layer into a representative low-dimensional vector: $K_c^T = W_{kp} K^T$, where $W_{kp}$ denotes the MLP weights. Compared with plain mean pooling, the learnable MLP captures the semantic relationships and importance distribution among tokens more effectively, yielding a finer representation of each block. Given the compressed representation $K_c$, the importance of each block to a query token is estimated as $Softmax(Q \cdot K_c^T)$. To strengthen the discriminative ability of the MLP layer, we use the full attention map after 1D max pooling, $1DMaxPooling(Softmax(Q \cdot K^T))$, as the ground truth; minimizing the divergence between the two distributions guides the MLP layer toward feature representations that align with the true attention distribution.
* **Training Data**: Benefiting from the efficiency of both the model architecture and the training paradigm, our approach reaches near-lossless accuracy with only 1B training tokens. The training data comes from an internally constructed corpus mixing long and short texts, which improves the module's adaptability to varying sequence lengths.
* **Other**: We observe that the final decoder layer has a significant impact on overall model accuracy. We therefore exclude this layer from sparse attention during training and revert it to full attention during inference.

### Inference

During sparse attention computation, each query token may dynamically select different KV blocks, leading to highly irregular memory access patterns in HBM.
Processing each query token independently is feasible, but the computation becomes so fine-grained that the tensor cores cannot be fully utilized, significantly reducing GPU efficiency.

<div align="center">
+Token/Head Union +
</div>

To optimize performance in both the prefill and decode stages, we design joint strategies tailored to the characteristics of each stage:

* **Prefill Token Union**: We observe that adjacent query tokens tend to select similar key blocks. Exploiting this locality, we take the union of the key blocks selected by every 128 consecutive query tokens and compute sparse attention jointly for those tokens.
* **Decode Head Union**: Given the widespread adoption of GQA in modern models, we find that different query heads within the same group often select overlapping key blocks. We therefore merge the key blocks selected by all query heads in a group into a single set and compute sparse attention jointly, which also reduces memory access overhead and further improves decoding efficiency.
* **Top-K Selection**: Conventional top-k algorithms based on sorting, or on direct calls into the CUB library, introduce significant runtime overhead. To mitigate this, we implement an approximate top-k selection algorithm based on binary search, which markedly reduces latency while maintaining accuracy.

## Evaluation

### Experiments

We evaluated the accuracy of full attention and sparse attention on LongBenchV2 and Ruler (with context lengths of 32K, 64K, and 128K).

| Model | Attention | LongBenchV2 | Ruler 32K | Ruler 64K | Ruler 128K |
|---|---|---|---|---|---|
| ERNIE-4.5-21B-A3B | FullAttention | 31.48 | 76.74 | 56.40 | 25.48 |
| ERNIE-4.5-21B-A3B | SparseAttention | 31.45 | 75.93 | 55.38 | 25.05 |
| ERNIE-4.5-300B-A47B | FullAttention | 41.02 | 94.70 | 83.56 | 58.18 |
| ERNIE-4.5-300B-A47B | SparseAttention | 41.05 | 94.50 | 82.32 | 57.85 |

### Performance

We selected a subset (longbook_sum_eng) of InfiniteBench as the performance evaluation dataset. For inputs longer than 128K, we truncate the sequence, keeping the first 64K and the last 64K tokens.

| Model | Attention | QPS | Decode Speed (token/s) | Time to First Token (s) | Time per Output Token (ms) | End-to-End Latency (s) | Mean Input Length | Mean Output Length |
|---|---|---|---|---|---|---|---|---|
| ERNIE-4.5-21B-A3B | FullAttention | 0.101 | 13.32 | 8.082 | 87.05 | 61.400 | 113182.32 | 627.76 |
| ERNIE-4.5-21B-A3B | SparseAttention | 0.150 (+48%) | 18.12 (+36%) | 5.466 (-48%) | 66.35 (-31%) | 42.157 (-46%) | 113182.32 | 590.23 |
| ERNIE-4.5-300B-A47B | FullAttention | 0.066 | 5.07 | 13.812 | 206.70 | 164.704 | 113182.32 | 725.97 |
| ERNIE-4.5-300B-A47B | SparseAttention | 0.081 (+23%) | 6.75 (+33%) | 10.584 (-30%) | 154.84 (-34%) | 132.745 (-24%) | 113182.32 | 748.25 |

## Usage

```bash
export FD_ATTENTION_BACKEND="MOBA_ATTN"

python -m fastdeploy.entrypoints.openai.api_server \
    --model baidu/ERNIE-4.5-300B-A47B-Paddle \
    --port 8188 \
    --tensor-parallel-size 4 \
    --quantization wint4 \
    --enable-chunked-prefill \
    --max-num-batched-tokens 8192 \
    --max-model-len 131072 \
    --max-num-seqs 32 \
    --moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
```

**Note**: If sparse attention is enabled, the system automatically loads the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling is applied to the key representations instead.

**Parameter Description:**

* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` means the top-k range is between 50 and 60 when the encoder is sparse.
* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` means the top-k range is between 100 and 120 when the decoder is sparse.
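To make the block-selection gate and its documented mean-pooling fallback concrete, here is a minimal NumPy sketch. This is illustrative only, not the FastDeploy implementation (which runs in fused GPU kernels); the function name `select_blocks` and all shapes are assumptions, and `w_kp` stands in for the learned MLP compression weights.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def select_blocks(q, k, w_kp=None, block_size=128, top_k=55):
    """Score KV blocks for one query vector and return the top-k block indices.

    Each K block is compressed to a single representative vector: with the
    learned gate weights w_kp (one weight per in-block position) when given,
    or with mean pooling as the documented fallback when no MLP weight exists.
    """
    n_blocks = k.shape[0] // block_size
    blocks = k[: n_blocks * block_size].reshape(n_blocks, block_size, -1)
    if w_kp is None:
        k_c = blocks.mean(axis=1)                    # fallback: mean pooling
    else:
        k_c = np.einsum("b,nbd->nd", w_kp, blocks)   # learned compression K_c
    scores = softmax(q @ k_c.T)                      # block importance for q
    return np.argsort(scores)[::-1][:top_k]

# Toy example: the second block is clearly more relevant to the query.
q = np.array([1.0, 0.0])
k = np.zeros((8, 2))
k[4:8] = np.array([5.0, 0.0])
print(select_blocks(q, k, block_size=4, top_k=1))   # -> [1]
```

In the real system this scoring runs per query token (prefill) or per head group (decode), and the selected indices drive the sparse attention kernels.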
diff --git a/docs/zh/features/images/plas_inference_union.png b/docs/zh/features/images/plas_inference_union.png new file mode 100644 index 000000000..4edc174c1 Binary files /dev/null and b/docs/zh/features/images/plas_inference_union.png differ diff --git a/docs/zh/features/images/plas_training_distill.png b/docs/zh/features/images/plas_training_distill.png new file mode 100644 index 000000000..7b1996bfe Binary files /dev/null and b/docs/zh/features/images/plas_training_distill.png differ diff --git a/docs/zh/features/plas_attention.md b/docs/zh/features/plas_attention.md new file mode 100644 index 000000000..09a98e6f4 --- /dev/null +++ b/docs/zh/features/plas_attention.md @@ -0,0 +1,223 @@ +# PLAS + +## 介绍 + +我们提出了**PLAS(Pluggable Lightweight Attention for Sparsity)**,这是对 MoBA 的改进。具体来说,我们采用了受 MoE 启发的结构,将 KV 划分为多个块,并引入了一个可学习的 MLP 层来自适应地选择重要块。PLAS 可以直接在训练后应用,此时只有 MLP 权重可学习,而原始模型权重保持不变。 + +与 NSA/MoBA 相比,我们的 PLAS 具有更高的可扩展性和可插拔性。它无需修改传统的注意力架构,也无需在训练前或训练后干扰模型权重训练。最终阶段只需对 MLP 层进行少量训练即可实现几乎无损的准确率。由于 NSA/MoBA 会更新整个模型权重,因此不可避免地会影响短文本的性能——即使它在输入长度小于 BlockSize × Top-K 时会自动切换到完全注意力机制。相比之下,我们的 PLAS 在短文本场景下可以实现与原始模型真正等同的完全注意力机制。 + +在训练效率方面,由于仅需更新 MLP 权重,训练成本极低。在推理性能方面,当输入长度为 128K、Block Size = 128、Top-K = 55 时,PLAS 相比 Flash Attention 3 实现了**386% 的加速**。 + +## 方法 + +### 训练 + +借鉴 NSA 和 MoBA 的方法,我们将键值对 (KV) 划分为多个块。在预填充和解码阶段,我们不再对所有键值进行注意力计算,而是动态地为每个查询 token 选择注意力得分最高的前 K 个块,从而实现高效的稀疏注意力计算。 + +
+Attention Gate Module +
</div>

* **Attention Gate Module**: 如上图所示,为了以较低的计算开销估计每个块的重要性,我们设计了一个轻量级的注意力门模块。该模块首先通过一个 MLP 层压缩每个 K 块,生成具有代表性的低维表示:$K_c^T=W_{kp}K^T$,其中 $W_{kp}$ 表示 MLP 层的权重。与直接使用均值池化相比,可学习的 MLP 能更有效地捕捉不同 token 之间的语义关系和重要性分布,从而为每个块提供更精细的表示。在获得压缩表示 $K_c$ 之后,通过 $Softmax(Q\cdot K_c^T)$ 估计每个查询 token 相对于每个块的重要性。为了增强 MLP 层的判别能力,我们使用一维最大池化后的完整注意力结果 $1DMaxPooling(Softmax(Q \cdot K^T))$ 作为 ground truth,通过最小化两者之间的分布差异,引导 MLP 层学习更符合真实注意力分布的特征表示。

* **Training Data**: 得益于模型架构和训练范式的高效性,我们的方法仅使用 10 亿 token 进行训练,便实现了近乎无损的精度。训练数据源自内部构建的长短文本混合语料库,从而增强了模块对不同序列长度的适应性。

* **Other**: 我们观察到,最后一层解码层对模型整体准确率有显著影响。因此,我们在训练时将该层排除在稀疏注意力计算之外,并在推理时将其恢复为完全注意力。

### 推理优化

在稀疏注意力计算过程中,每个查询 token 可能会动态选择不同的 KV 块,导致对 HBM 的内存访问模式非常不规则。虽然可以简单地对每个查询 token 单独处理,但这会使计算粒度过细,无法充分利用张量核心,从而显著降低 GPU 的计算效率。

<div align="center">
+Token/Head Union +
</div>

为了优化预填充和解码阶段的性能,我们针对两个阶段各自的特点设计了相应的联合策略:

* **Prefill Token Union**: 我们观察到相邻的查询 token 倾向于选择相似的 key 块。利用这种局部性,我们取连续 128 个查询 token 所选 key 块的并集,并对这些 token 联合计算稀疏注意力。

* **Decode Head Union**: 鉴于 GQA 在现代模型中的广泛应用,我们发现同一组内的不同查询头经常选择重叠的 key 块。因此,我们将同一组内所有查询头选择的 key 块合并为一个统一的集合,并联合计算稀疏注意力。这种方式同样减少了内存访问开销,进一步提高了解码效率。

* **Top-K Selection**: 传统的 top-k 算法基于排序或直接调用 CUB 库,会带来显著的运行时开销。为了缓解这个问题,我们实现了一个基于二分查找的近似 top-k 选择算法,在保持准确率的同时显著降低了延迟,最终带来明显的性能提升。

## 评估

### 实验

我们在 LongBenchV2 和 Ruler(上下文长度分别为 32K、64K 和 128K)上评估了全注意力和稀疏注意力的精度。

| Model | Attention | LongBenchV2 | Ruler 32K | Ruler 64K | Ruler 128K |
|---|---|---|---|---|---|
| ERNIE-4.5-21B-A3B | FullAttention | 31.48 | 76.74 | 56.40 | 25.48 |
| ERNIE-4.5-21B-A3B | SparseAttention | 31.45 | 75.93 | 55.38 | 25.05 |
| ERNIE-4.5-300B-A47B | FullAttention | 41.02 | 94.70 | 83.56 | 58.18 |
| ERNIE-4.5-300B-A47B | SparseAttention | 41.05 | 94.50 | 82.32 | 57.85 |

### 性能

我们从 InfiniteBench 中选取了一个子集(longbook_sum_eng)作为性能评估数据集。对于长度超过 128K 的输入,我们截断序列,保留前 64K 和后 64K 个 token。

| Model | Attention | QPS | Decode Speed (token/s) | Time to First Token (s) | Time per Output Token (ms) | End-to-End Latency (s) | Mean Input Length | Mean Output Length |
|---|---|---|---|---|---|---|---|---|
| ERNIE-4.5-21B-A3B | FullAttention | 0.101 | 13.32 | 8.082 | 87.05 | 61.400 | 113182.32 | 627.76 |
| ERNIE-4.5-21B-A3B | SparseAttention | 0.150 (+48%) | 18.12 (+36%) | 5.466 (-48%) | 66.35 (-31%) | 42.157 (-46%) | 113182.32 | 590.23 |
| ERNIE-4.5-300B-A47B | FullAttention | 0.066 | 5.07 | 13.812 | 206.70 | 164.704 | 113182.32 | 725.97 |
| ERNIE-4.5-300B-A47B | SparseAttention | 0.081 (+23%) | 6.75 (+33%) | 10.584 (-30%) | 154.84 (-34%) | 132.745 (-24%) | 113182.32 | 748.25 |

## 使用方式

```bash
export FD_ATTENTION_BACKEND="MOBA_ATTN"

python -m fastdeploy.entrypoints.openai.api_server \
    --model baidu/ERNIE-4.5-300B-A47B-Paddle \
    --port 8188 \
    --tensor-parallel-size 4 \
    --quantization wint4 \
    --enable-chunked-prefill \
    --max-num-batched-tokens 8192 \
    --max-model-len 131072 \
    --max-num-seqs 32 \
    --moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
```

**Note**: 如果启用了稀疏注意力,系统会自动从权重目录中的 `moba_mlp_weight.safetensors` 文件加载 MLP 权重;如果未找到该权重文件,则对 key 表示使用均值池化。

**Parameter Description:**

* 设置 `FD_ATTENTION_BACKEND="MOBA_ATTN"` 可启用 MOBA 稀疏注意力。
* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` 表示 encoder 稀疏时,top-k 的范围在 50 到 60 之间。
* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` 表示 decoder 稀疏时,top-k 的范围在 100 到 120 之间。
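前文 Top-K Selection 中提到的"二分查找近似 top-k"思想可以用下面的最小 NumPy 草图说明。这只是示意性实现:函数名 `approx_topk_threshold` 与各参数均为假设,真实实现是 GPU 内核代码,此处仅展示用二分搜索阈值代替全排序的思路。

```python
import numpy as np

def approx_topk_threshold(scores, k, iters=20):
    """在得分的取值区间上二分搜索阈值 t,使得 scores >= t 的元素个数
    接近 k,从而避免全排序的开销(近似 top-k,示意实现)。"""
    lo, hi = float(scores.min()), float(scores.max())
    for _ in range(iters):
        mid = (lo + hi) / 2.0
        if (scores >= mid).sum() > k:
            lo = mid   # 入选元素多于 k:阈值偏低,上调下界
        else:
            hi = mid   # 入选元素不超过 k:阈值足够高,下调上界
    return hi

scores = np.array([0.10, 0.90, 0.30, 0.80, 0.05, 0.70])
t = approx_topk_threshold(scores, k=3)
print(np.where(scores >= t)[0])   # -> [1 3 5]
```

与排序或调用 CUB 的精确 top-k 相比,这种按阈值筛选的方式在块数较多时开销更低,且当得分分布存在并列时结果是近似的,这与文中"在保持准确率的同时显著降低延迟"的设计取舍一致。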