diff --git a/docs/features/plas_attention.md b/docs/features/plas_attention.md index dfd85e676..6551b4e8f 100644 --- a/docs/features/plas_attention.md +++ b/docs/features/plas_attention.md @@ -196,7 +196,7 @@ We selected a subset (longbook_sum_eng) from InfiniteBench as the performance ev ## Usage ``` -export FD_ATTENTION_BACKEND="MOBA_ATTN" +export FD_ATTENTION_BACKEND="PLAS_ATTN" python -m fastdeploy.entrypoints.openai.api_server --model baidu/ERNIE-4.5-300B-A47B-Paddle \ @@ -207,13 +207,13 @@ python -m fastdeploy.entrypoints.openai.api_server --max-num-batched-tokens 8192 \ --max-model-len 131072 \ --max-num-seqs 32 \ - --moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}' + --plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}' ``` -**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations. +**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations. **Parameter Description:** -* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention. -* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` indicates that the range of top-k is between 50 and 60 when the encoder is sparse. -* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` indicates that the range of top-k is between 100 and 120 when the decoder is sparse. +* Setting `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention. +* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` indicates that the range of top-k is between 50 and 60 when the encoder is sparse. +* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` indicates that the range of top-k is between 100 and 120 when the decoder is sparse. 
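For offline (non-server) runs, the same configuration can be passed programmatically. The snippet below is a minimal sketch and is not part of this patch: it assumes the patch is applied, that the offline `LLM` entry point (the one exercised by `tests/layers/test_plas_attention.py` later in this diff) is importable as `from fastdeploy import LLM`, and that it forwards `plas_attention_config` and the other keyword arguments to the engine arguments.

```python
import os

# The attention backend must be selected before the engine is created.
os.environ["FD_ATTENTION_BACKEND"] = "PLAS_ATTN"

from fastdeploy import LLM  # assumed import path for the offline API

llm = LLM(
    model="baidu/ERNIE-4.5-300B-A47B-Paddle",
    max_model_len=131072,
    max_num_seqs=32,
    max_num_batched_tokens=8192,
    # Same keys as the --plas-attention-config JSON accepted by the server.
    plas_attention_config={
        "plas_encoder_top_k_left": 50,
        "plas_encoder_top_k_right": 60,
        "plas_decoder_top_k_left": 100,
        "plas_decoder_top_k_right": 120,
    },
)

# generate() takes a list of prompts; the exact output structure depends on
# the installed FastDeploy version.
outputs = llm.generate(["Summarize the following long document: ..."])
print(outputs)
```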
diff --git a/docs/zh/features/plas_attention.md b/docs/zh/features/plas_attention.md index f415f49b1..a49cb25fd 100644 --- a/docs/zh/features/plas_attention.md +++ b/docs/zh/features/plas_attention.md @@ -18,7 +18,7 @@ Attention Gate Module -* **Attention Gate Module**: 如上图所示,为了以较低的计算开销估计每个块的重要性,我们设计了一个轻量级的注意力门模块。该模块首先通过一个 MLP 层压缩每个 K 个块,生成一个具有代表性的低维表示:$K_c^T=W_{kp}K^T$,其中 $W_{kp}$ 表示 MLP 层的权重。与直接应用均值池化相比,可学习的 MLP 可以更有效地捕捉不同 token 之间的语义关系和重要性分布,从而提供每个块的精细表示。在获得压缩表示 $K_c$ 之后,通过以下公式估计每个查询 token 相对于每个块的重要性:$Softmax(Q\cdot K_c^T)$。为了增强 MLP 层的判别能力,我们使用一维最大池化后的完整注意力结果 $1DMaxPooling(Softmax(Q \cdot K^T))$ 作为 ground truth。通过最小化两者之间的分布差异,引导 MLP 层学习更符合真实注意力分布的特征表示。 +* **Attention Gate Module**: 如上图所示,为了以较低的计算开销估计每个块的重要性,我们设计了一个轻量级的注意力门模块。该模块首先通过一个MLP层压缩每个K个块,生成一个具有代表性的低维表示: $K_c^T=W_{kp}K^T$ ,其中 $W_{kp}$ 表示 MLP 层的权重。与直接应用均值池化相比,可学习的 MLP 可以更有效地捕捉不同 token 之间的语义关系和重要性分布,从而提供每个块的精细表示。在获得压缩表示 $K_c$ 之后,通过以下公式估计每个查询 token 相对于每个块的重要性:$Softmax(Q\cdot K_c^T)$。为了增强 MLP 层的判别能力,我们使用一维最大池化后的完整注意力结果 $1DMaxPooling(Softmax(Q \cdot K^T))$ 作为 ground truth。通过最小化两者之间的分布差异,引导 MLP 层学习更符合真实注意力分布的特征表示。 * **Training Data**: 得益于模型架构和训练范式的高效性,我们的方法仅使用 10 亿个 token 进行训练,便实现了近乎无损的精度。训练数据源自内部构建的包含长文本和短文本的混合语料库,从而增强了模块对不同序列长度的适应性。 @@ -36,7 +36,7 @@ * **Prefill Token Union**: 我们观察到相邻的查询标记倾向于选择相似的关键块。利用这种局部性,我们取连续 128 个查询标记选择的关键块的并集,并联合计算这些标记的稀疏注意力机制。 -* **Decode Head Union**: 鉴于 GQA 在现代模型中的广泛应用,我们发现同一组内的不同查询头经常选择重叠的关键块。因此,我们将同一组内所有查询头选择的关键块合并为一个统一的集合,并联合计算稀疏注意力机制。这种方式也减少了内存访问开销,并进一步提高了解码效率。 +* **Decode Head Union**: 鉴于GQA在现代模型中的广泛应用,我们发现同一组内的不同查询头经常选择重叠的关键块。因此,我们将同一组内所有查询头选择的关键块合并为一个统一的集合,并联合计算稀疏注意力机制。这种方式也减少了内存访问开销,并进一步提高了解码效率。 * **Top-K Selection**: 传统的 Top-k 算法基于排序或直接调用 Cub 库,会带来显著的运行时开销。为了缓解这个问题,我们实现了一个基于二分查找的近似 Top-k 选择算法,该算法在保持准确率的同时显著降低了延迟,最终实现了性能的显著提升。 @@ -200,7 +200,7 @@ ## 使用方式 ``` -export FD_ATTENTION_BACKEND="MOBA_ATTN" +export FD_ATTENTION_BACKEND="PLAS_ATTN" python -m fastdeploy.entrypoints.openai.api_server --model baidu/ERNIE-4.5-300B-A47B-Paddle \ @@ -211,13 +211,13 @@ python -m fastdeploy.entrypoints.openai.api_server --max-num-batched-tokens 8192 \ --max-model-len 131072 \ --max-num-seqs 32 \ - --moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}' + --plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}' ``` -**Note**: 如果启用了稀疏注意力机制,系统将自动从权重目录中的`moba_mlp_weight.safetensors`文件加载 MLP 权重。如果未找到 MLP 权重文件,则将对关键表示应用均值池化 +**Note**: 如果启用了稀疏注意力机制,系统将自动从权重目录中的`plas_attention_mlp_weight.safetensors`文件加载 MLP 权重。如果未找到 MLP 权重文件,则将对关键表示应用均值池化 **Parameter Description:** -* `FD_ATTENTION_BACKEND="MOBA_ATTN"` 启用 MOBA sparse attention. -* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` 表示当encoder时,top-k的范围在50到60之间。 -* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` 表示当decoder时,top-k的范围在100到120之间。 +* `FD_ATTENTION_BACKEND="PLAS_ATTN"` 启用 PLAS sparse attention. 
+* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` 表示当encoder稀疏时,top-k的范围在50到60之间。 +* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` 表示当decoder稀疏时,top-k的范围在100到120之间。 diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 6dcf87a36..e9df1c52e 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -945,63 +945,63 @@ class GraphOptimizationConfig: argument = self.use_cudagraph -class MobaAttentionConfig: +class PlasAttentionConfig: def __init__( self, args, ): - self.moba_encoder_top_k_left: int = None - self.moba_encoder_top_k_right: int = None - "The sparse topk of encoder attention is located at [moba_encoder_top_k_left, moba_encoder top_k_right]" - self.moba_decoder_top_k_left: int = None - self.moba_decoder_top_k_right: int = None - "The sparse topk of decoder attention is located at [moba_decoder_top_k_left, moba_decoder top_k_right]" - self.moba_use_encoder_seq_limit: int = None - "When the number of encdoer token is less than moba_use_encoder_seq_limit, it is not sparse" - self.moba_use_decoder_seq_limit: int = None - "When the number of decdoer token is less than moba_use_decoder_seq_limit, it is not sparse" - self.moba_block_size: int = 128 - self.mlp_weight_name: str = "moba_mlp_weight.safetensors" - self.moba_max_seq_length: int = 128 * 1024 + self.plas_encoder_top_k_left: int = None + self.plas_encoder_top_k_right: int = None + "The sparse topk of encoder attention is located at [plas_encoder_top_k_left, plas_encoder_top_k_right]" + self.plas_decoder_top_k_left: int = None + self.plas_decoder_top_k_right: int = None + "The sparse topk of decoder attention is located at [plas_decoder_top_k_left, plas_decoder_top_k_right]" + self.plas_use_encoder_seq_limit: int = None + "When the number of encoder tokens is less than plas_use_encoder_seq_limit, attention is not sparse" + self.plas_use_decoder_seq_limit: int = None + "When the number of decoder tokens is less than plas_use_decoder_seq_limit, attention is not sparse" + self.plas_block_size: int = 128 + self.mlp_weight_name: str = "plas_attention_mlp_weight.safetensors" + self.plas_max_seq_length: int = 128 * 1024 if args is not None: for key, value in args.items(): if hasattr(self, key): setattr(self, key, value) - if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None: - self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size - if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None: - self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size + if self.plas_use_encoder_seq_limit is None and self.plas_encoder_top_k_left is not None: + self.plas_use_encoder_seq_limit = self.plas_encoder_top_k_left * self.plas_block_size + if self.plas_use_decoder_seq_limit is None and self.plas_decoder_top_k_left is not None: + self.plas_use_decoder_seq_limit = self.plas_decoder_top_k_left * self.plas_block_size self.check_legality_parameters() def check_legality_parameters( self, ) -> None: - if self.moba_encoder_top_k_left is not None: - assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must large than 0" + if self.plas_encoder_top_k_left is not None: + assert self.plas_encoder_top_k_left > 0, "plas_encoder_top_k_left must be larger than 0" - if self.moba_encoder_top_k_right is not None: - assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must large than 0" + if self.plas_encoder_top_k_right is not None: + assert self.plas_encoder_top_k_right > 0, "plas_encoder_top_k_right must be
larger than 0" assert ( - self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left - ), "moba_encoder_top_k_right must large than moba_encoder_top_k_left" + self.plas_encoder_top_k_right >= self.plas_encoder_top_k_left + ), "plas_encoder_top_k_right must be no less than plas_encoder_top_k_left" - if self.moba_decoder_top_k_left is not None: - assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must large than 0" + if self.plas_decoder_top_k_left is not None: + assert self.plas_decoder_top_k_left > 0, "plas_decoder_top_k_left must be larger than 0" - if self.moba_decoder_top_k_right is not None: - assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must large than 0" + if self.plas_decoder_top_k_right is not None: + assert self.plas_decoder_top_k_right > 0, "plas_decoder_top_k_right must be larger than 0" assert ( - self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left - ), "moba_decoder_top_k_right must large than moba_decoder_top_k_left" + self.plas_decoder_top_k_right >= self.plas_decoder_top_k_left + ), "plas_decoder_top_k_right must be no less than plas_decoder_top_k_left" - if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None: - assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size - if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None: - assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size + if self.plas_use_encoder_seq_limit is not None and self.plas_encoder_top_k_left is not None: + assert self.plas_use_encoder_seq_limit >= self.plas_encoder_top_k_left * self.plas_block_size + if self.plas_use_decoder_seq_limit is not None and self.plas_decoder_top_k_left is not None: + assert self.plas_use_decoder_seq_limit >= self.plas_decoder_top_k_left * self.plas_block_size def to_json_string(self): """ - Convert moba_attention_config to json string. + Convert plas_attention_config to json string.
""" return json.dumps({key: value for key, value in self.__dict__.items() if value is not None}) @@ -1396,7 +1396,7 @@ class FDConfig: decoding_config: DecodingConfig = None, quant_config: QuantConfigBase = None, graph_opt_config: GraphOptimizationConfig = None, - moba_attention_config: MobaAttentionConfig = None, + plas_attention_config: PlasAttentionConfig = None, speculative_config: SpeculativeConfig = None, tokenizer: str = None, max_model_len: int = 8192, @@ -1427,7 +1427,7 @@ class FDConfig: self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config self.decoding_config: DecodingConfig = decoding_config # type: ignore self.cache_config: CacheConfig = cache_config # type: ignore - self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config + self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config # Initialize cuda graph capture list if self.graph_opt_config.cudagraph_capture_sizes is None: self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.scheduler_config.max_num_seqs) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 82dacc1c2..418b424c2 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -30,9 +30,9 @@ from fastdeploy.config import ( FDConfig, GraphOptimizationConfig, LoadConfig, - MobaAttentionConfig, ModelConfig, ParallelConfig, + PlasAttentionConfig, PoolerConfig, RunnerOption, SpeculativeConfig, @@ -361,9 +361,9 @@ class EngineArgs: """ Configuration for graph optimization backend execution. """ - moba_attention_config: Optional[Dict[str, Any]] = None + plas_attention_config: Optional[Dict[str, Any]] = None """ - Configuration for moba attention. + Configuration for plas attention. """ enable_logprob: bool = False @@ -601,9 +601,9 @@ class EngineArgs: help="", ) model_group.add_argument( - "--moba-attention-config", + "--plas-attention-config", type=json.loads, - default=EngineArgs.moba_attention_config, + default=EngineArgs.plas_attention_config, help="", ) model_group.add_argument( @@ -993,17 +993,17 @@ class EngineArgs: graph_optimization_args[k] = v return GraphOptimizationConfig(graph_optimization_args) - def create_moba_attention_config(self) -> MobaAttentionConfig: + def create_plas_attention_config(self) -> PlasAttentionConfig: """ - Create and retuan a MobaAttentionConfig object based on the current settings. + Create and return a PlasAttentionConfig object based on the current settings.
""" attention_args = asdict(self) - if self.moba_attention_config is not None: - for k, v in self.moba_attention_config.items(): + if self.plas_attention_config is not None: + for k, v in self.plas_attention_config.items(): attention_args[k] = v - return MobaAttentionConfig(attention_args) + return PlasAttentionConfig(attention_args) else: - return MobaAttentionConfig(None) + return PlasAttentionConfig(None) def create_early_stop_config(self) -> EarlyStopConfig: """ @@ -1064,7 +1064,7 @@ class EngineArgs: scheduler_cfg = self.create_scheduler_config() graph_opt_cfg = self.create_graph_optimization_config() graph_opt_cfg.update_use_cudagraph(self.use_cudagraph) - moba_attention_config = self.create_moba_attention_config() + plas_attention_config = self.create_plas_attention_config() early_stop_cfg = self.create_early_stop_config() early_stop_cfg.update_enable_early_stop(self.enable_early_stop) @@ -1093,7 +1093,7 @@ class EngineArgs: max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, graph_opt_config=graph_opt_cfg, - moba_attention_config=moba_attention_config, + plas_attention_config=plas_attention_config, guided_decoding_backend=self.guided_decoding_backend, disable_any_whitespace=self.guided_decoding_disable_any_whitespace, early_stop_config=early_stop_cfg, diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 84890b1e1..8689490d8 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -501,7 +501,7 @@ class LLMEngine: f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" f" --reasoning_parser {self.cfg.reasoning_parser}" f" --load_choices {self.cfg.load_config.load_choices}" - f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'" + f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'" f" --ips {ips}" f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}" f" --runner {self.cfg.model_config.runner}" diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 7157ac63b..cbc6152aa 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -20,7 +20,7 @@ from .block_multihead_attn_backend import BlockAttentionBackend from .flash_attn_backend import FlashAttentionBackend from .iluvatar_attn_backend import IluvatarAttnBackend from .mla_attention_backend import MLAAttentionBackend -from .moba_attention_backend import MobaAttentionBackend +from .moba_attention_backend import PlasAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend from .xpu_attn_backend import XPUAttentionBackend @@ -35,5 +35,5 @@ __all__ = [ "IluvatarAttnBackend", "BlockAttentionBackend", "Attention", - "MobaAttentionBackend", + "PlasAttentionBackend", ] diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index d3730c9f3..4c3352868 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -119,19 +119,19 @@ class Attention(nn.Layer): self.init_weight() if ( - fd_config.moba_attention_config is not None - and fd_config.moba_attention_config.moba_encoder_top_k_left is not None - and fd_config.moba_attention_config.moba_encoder_top_k_right is not None - and fd_config.moba_attention_config.moba_decoder_top_k_left is 
not None - and fd_config.moba_attention_config.moba_decoder_top_k_right is not None + fd_config.plas_attention_config is not None + and fd_config.plas_attention_config.plas_encoder_top_k_left is not None + and fd_config.plas_attention_config.plas_encoder_top_k_right is not None + and fd_config.plas_attention_config.plas_decoder_top_k_left is not None + and fd_config.plas_attention_config.plas_decoder_top_k_right is not None ): mlp_weight_path = os.path.join( - fd_config.model_config.model, fd_config.moba_attention_config.mlp_weight_name + fd_config.model_config.model, fd_config.plas_attention_config.mlp_weight_name ) - self.moba_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path) - moba_block_size = fd_config.moba_attention_config.moba_block_size - moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length - if self.moba_use_mlp: + self.plas_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path) + plas_block_size = fd_config.plas_attention_config.plas_block_size + plas_max_seq_length = fd_config.plas_attention_config.plas_max_seq_length + if self.plas_use_mlp: mlp_weight = {} with safe_open(mlp_weight_path, framework="np", device="cpu") as f: for key_name in f.keys(): @@ -148,12 +148,12 @@ class Attention(nn.Layer): * self.kv_num_heads : (fd_config.parallel_config.tensor_parallel_rank + 1) * self.kv_num_heads ] - assert self.attn_gate_weight.shape[1] % moba_block_size == 0 + assert self.attn_gate_weight.shape[1] % plas_block_size == 0 self.cache_k_block_means = paddle.zeros( [ fd_config.scheduler_config.max_num_seqs, - moba_max_seq_length // moba_block_size, + plas_max_seq_length // plas_block_size, self.kv_num_heads, self.head_dim, ], diff --git a/fastdeploy/model_executor/layers/attention/moba_attention_backend.py b/fastdeploy/model_executor/layers/attention/moba_attention_backend.py index 47c65624f..82ac4880b 100644 --- a/fastdeploy/model_executor/layers/attention/moba_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/moba_attention_backend.py @@ -39,7 +39,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( @dataclass -class MobaAttentionMetadata(AttentionMetadata): +class PlasAttentionMetadata(AttentionMetadata): """ AppendAttentionMetadata """ @@ -54,7 +54,7 @@ class MobaAttentionMetadata(AttentionMetadata): max_dec_len_this_time: int = 0 -class MobaAttentionBackend(AttentionBackend): +class PlasAttentionBackend(AttentionBackend): """ The backend class that uses paddle native attention implementation. Which is used only for testing purpose. 
@@ -70,11 +70,11 @@ class MobaAttentionBackend(AttentionBackend): decoder_block_shape_q: int = -1, ) -> None: """ - MobaAttentionBackend __init__ + PlasAttentionBackend __init__ """ super().__init__() - self.attention_metadata: MobaAttentionMetadata = None - assert fd_config.moba_attention_config is not None, "moba_attention_config is None" + self.attention_metadata: PlasAttentionMetadata = None + assert fd_config.plas_attention_config is not None, "plas_attention_config is None" self.block_size = fd_config.parallel_config.block_size self.max_seq_len = fd_config.parallel_config.max_model_len self.max_num_seqs = fd_config.scheduler_config.max_num_seqs @@ -83,18 +83,18 @@ class MobaAttentionBackend(AttentionBackend): self.head_dim = fd_config.model_config.head_dim self.num_layers: int = fd_config.model_config.num_hidden_layers self.attn_block_m = 128 - self.moba_block_size = fd_config.moba_attention_config.moba_block_size - self.moba_encoder_top_k_left = int(fd_config.moba_attention_config.moba_encoder_top_k_left) - self.moba_encoder_top_k_right = int(fd_config.moba_attention_config.moba_encoder_top_k_right) - self.moba_use_encoder_seq_limit = int(fd_config.moba_attention_config.moba_use_encoder_seq_limit) - self.moba_decoder_top_k_left = int(fd_config.moba_attention_config.moba_decoder_top_k_left) - self.moba_decoder_top_k_right = int(fd_config.moba_attention_config.moba_decoder_top_k_right) - self.moba_use_decoder_seq_limit = int(fd_config.moba_attention_config.moba_use_decoder_seq_limit) - self.moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length + self.plas_block_size = fd_config.plas_attention_config.plas_block_size + self.plas_encoder_top_k_left = int(fd_config.plas_attention_config.plas_encoder_top_k_left) + self.plas_encoder_top_k_right = int(fd_config.plas_attention_config.plas_encoder_top_k_right) + self.plas_use_encoder_seq_limit = int(fd_config.plas_attention_config.plas_use_encoder_seq_limit) + self.plas_decoder_top_k_left = int(fd_config.plas_attention_config.plas_decoder_top_k_left) + self.plas_decoder_top_k_right = int(fd_config.plas_attention_config.plas_decoder_top_k_right) + self.plas_use_decoder_seq_limit = int(fd_config.plas_attention_config.plas_use_decoder_seq_limit) + self.plas_max_seq_length = fd_config.plas_attention_config.plas_max_seq_length def init_attention_metadata(self, forward_meta: ForwardMeta): """Init the metadata for a forward pass.""" - metadata = MobaAttentionMetadata() + metadata = PlasAttentionMetadata() metadata._dtype = paddle.get_default_dtype() metadata.cu_seq_q_pack, metadata.cu_seqlens_k, metadata.q_pack_tokens = get_cur_cu_seq_len_k( forward_meta.seq_lens_encoder, @@ -116,7 +116,7 @@ class MobaAttentionBackend(AttentionBackend): [k_token_num + self.attn_block_m, self.kv_num_heads * self.head_dim], dtype=metadata._dtype ) self.attention_metadata = metadata - assert self.max_seq_len <= self.moba_max_seq_length + assert self.max_seq_len <= self.plas_max_seq_length def get_kv_cache_shape( self, @@ -186,13 +186,13 @@ class MobaAttentionBackend(AttentionBackend): self.max_seq_len, attention_metadata.max_enc_len_this_time, attention_metadata.max_dec_len_this_time, - self.moba_encoder_top_k_left, - self.moba_encoder_top_k_right, - self.moba_use_encoder_seq_limit, - self.moba_decoder_top_k_left, - self.moba_decoder_top_k_right, - self.moba_use_decoder_seq_limit, - layer.moba_use_mlp, + self.plas_encoder_top_k_left, + self.plas_encoder_top_k_right, + self.plas_use_encoder_seq_limit, + self.plas_decoder_top_k_left, + 
self.plas_decoder_top_k_right, + self.plas_use_decoder_seq_limit, + layer.plas_use_mlp, getattr(layer, "cache_quant_type_str", "none"), )[0] return out diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index a0e13f9c7..478bb7b62 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -26,7 +26,7 @@ class _Backend(enum.Enum): MLA_ATTN = enum.auto() FLASH_ATTN = enum.auto() BLOCK_ATTN = enum.auto() - MOBA_ATTN = enum.auto() + PLAS_ATTN = enum.auto() class Platform: diff --git a/fastdeploy/platforms/cuda.py b/fastdeploy/platforms/cuda.py index a9e070755..9720e7ace 100644 --- a/fastdeploy/platforms/cuda.py +++ b/fastdeploy/platforms/cuda.py @@ -64,9 +64,9 @@ class CUDAPlatform(Platform): elif selected_backend == _Backend.FLASH_ATTN: logger.info("Using FLASH ATTN backend.") return "fastdeploy.model_executor.layers.attention.FlashAttentionBackend" - elif selected_backend == _Backend.MOBA_ATTN: - logger.info("Using MOBA ATTN backend.") - return "fastdeploy.model_executor.layers.attention.MobaAttentionBackend" + elif selected_backend == _Backend.PLAS_ATTN: + logger.info("Using PLAS ATTN backend.") + return "fastdeploy.model_executor.layers.attention.PlasAttentionBackend" else: raise ValueError( "Invalid attention backend you specified.\n" diff --git a/fastdeploy/rl/rollout_config.py b/fastdeploy/rl/rollout_config.py index 9a6290c6a..9ae5883d0 100644 --- a/fastdeploy/rl/rollout_config.py +++ b/fastdeploy/rl/rollout_config.py @@ -61,7 +61,7 @@ class RolloutModelConfig: graph_optimization_config: str = None, early_stop_config: str = None, local_rank: int = 0, - moba_attention_config: str = None, + plas_attention_config: str = None, data_parallel_size: int = 1, num_nextn_predict_layers: int = 0, ): @@ -109,7 +109,7 @@ class RolloutModelConfig: self.local_rank = local_rank self.early_stop_config = early_stop_config self.ips = None - self.moba_attention_config = moba_attention_config + self.plas_attention_config = plas_attention_config self.num_nextn_predict_layers = num_nextn_predict_layers def __str__(self): diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 186dd58ea..902bf9461 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -35,9 +35,9 @@ from fastdeploy.config import ( FDConfig, GraphOptimizationConfig, LoadConfig, - MobaAttentionConfig, ModelConfig, ParallelConfig, + PlasAttentionConfig, SpeculativeConfig, ) from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer @@ -577,10 +577,10 @@ def parse_args(): help="Configuration of Graph optimization backend.", ) parser.add_argument( - "--moba_attention_config", + "--plas_attention_config", type=json.loads, default=None, - help="Configuration of moba attention.", + help="Configuration of plas attention.", ) parser.add_argument( "--guided_decoding_backend", @@ -723,7 +723,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config) - moba_attention_config = MobaAttentionConfig(args.moba_attention_config) + plas_attention_config = PlasAttentionConfig(args.plas_attention_config) early_stop_config = EarlyStopConfig(args.early_stop_config) @@ -795,7 +795,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: cache_config=cache_config, scheduler_config=scheduler_config, ips=args.ips, - moba_attention_config=moba_attention_config, + plas_attention_config=plas_attention_config, )
update_fd_config_for_mm(fd_config) diff --git a/tests/layers/test_moba_attention.py b/tests/layers/test_plas_attention.py similarity index 83% rename from tests/layers/test_moba_attention.py rename to tests/layers/test_plas_attention.py index b19485042..9de05578a 100644 --- a/tests/layers/test_moba_attention.py +++ b/tests/layers/test_plas_attention.py @@ -57,7 +57,7 @@ def naive_attn(q_input, k_input, v_input, mask): return out -class TestMobaAttention(unittest.TestCase): +class TestPlasAttention(unittest.TestCase): def setUp(self): paddle.seed(0) self.seq_len = int(8 * 1024) @@ -65,15 +65,15 @@ class TestMobaAttention(unittest.TestCase): self.num_kv_heads = int(1) self.head_dim = int(128) self.max_num_seqs = 1 - self.moba_max_seq_length = int(128 * 1024) - self.moba_block_size = int(128) - self.moba_encoder_top_k_left = 2 - self.moba_encoder_top_k_right = 3 - self.moba_use_encoder_seq_limit = int(4 * 1024) + self.plas_max_seq_length = int(128 * 1024) + self.plas_block_size = int(128) + self.plas_encoder_top_k_left = 2 + self.plas_encoder_top_k_right = 3 + self.plas_use_encoder_seq_limit = int(4 * 1024) self.cache_k_block_means = paddle.zeros( [ self.max_num_seqs, - self.moba_max_seq_length // self.moba_block_size, + self.plas_max_seq_length // self.plas_block_size, self.num_kv_heads, self.head_dim, ], @@ -96,12 +96,12 @@ class TestMobaAttention(unittest.TestCase): self.rotary_embs = paddle.ones([2, self.seq_len, self.head_dim // 2], dtype="float32") self.attn_gate_weight = paddle.randn( - [self.num_kv_heads, self.moba_block_size, self.head_dim], dtype="bfloat16" + [self.num_kv_heads, self.plas_block_size, self.head_dim], dtype="bfloat16" ) self.gqa_group_size = self.num_heads // self.num_kv_heads - self.num_blocks = (self.seq_len + self.moba_block_size - 1) // self.moba_block_size + self.num_blocks = (self.seq_len + self.plas_block_size - 1) // self.plas_block_size self.sparse_step = 4 @@ -115,38 +115,38 @@ class TestMobaAttention(unittest.TestCase): for i in range(self.max_num_seqs): k_padding = paddle.zeros( [ - (self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size, + (self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size, self.num_kv_heads, self.head_dim, ], dtype="bfloat16", ) k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len] - real_k_block_means = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim]) + real_k_block_means = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim]) real_k_block_means = real_k_block_means.mean(axis=1) compute_k_block_means = self.cache_k_block_means[i, 0 : real_k_block_means.shape[0]] assert (compute_k_block_means - real_k_block_means).abs().max() < 0.003 - print("[consistency]Moba attention: split_qkv_rope matches.") + print("[consistency]plas attention: split_qkv_rope matches.") def compare_mlp_einsum(self, k_gate_weight): for i in range(self.max_num_seqs): k_padding = paddle.zeros( [ - (self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size, + (self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size, self.num_kv_heads, self.head_dim, ], dtype="bfloat16", ) k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len] - k_padding = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim]) + k_padding = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim]) 
real_result = paddle.einsum("nbhd,hbd->nhd", k_padding, self.attn_gate_weight) compute_result = k_gate_weight[i][0 : real_result.shape[0]] assert (real_result - compute_result).abs().max() < 0.5 - print("[consistency]Moba attention: MLP einsum matches.") + print("[consistency]plas attention: MLP einsum matches.") def compare_qk_gemm(self, qk_gate_weight): for i in range(self.max_num_seqs): @@ -170,10 +170,10 @@ class TestMobaAttention(unittest.TestCase): conpute_result = qk_gate_weight[i * self.seq_len : (i + 1) * self.seq_len, :, 0 : self.num_blocks] assert (qk_gemm_out - conpute_result).abs().max() < 1e-4 - print("[consistency]Moba attention: qk_gemm matches.") + print("[consistency]plas attention: qk_gemm matches.") def compare_qk_gate_topk(self, qk_gate_topk_idx): - limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size + limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size for i in range(self.max_num_seqs): qk_gate_topk_idx_batch = qk_gate_topk_idx[i * self.num_blocks : (i + 1) * self.num_blocks] qk_gate_topk_idx_batch_no_sparse = qk_gate_topk_idx_batch[0 : limit_topk - 1] @@ -191,40 +191,40 @@ class TestMobaAttention(unittest.TestCase): - paddle.ones(qk_gate_topk_idx_batch_sparse.shape, qk_gate_topk_idx_batch_sparse.dtype) * self.sparse_step ).abs().max() < 1e-6 - print("[consistency]Moba attention: qk_gate_topk matches.") + print("[consistency]plas attention: qk_gate_topk matches.") def compare_attn(self, attn_out, qk_gate_topk_idx): x = ( - paddle.tensor.triu(paddle.ones([self.moba_block_size, self.moba_block_size], dtype="bfloat16"), 1) + paddle.tensor.triu(paddle.ones([self.plas_block_size, self.plas_block_size], dtype="bfloat16"), 1) * -1000000 ) - limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size + limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size for i in range(self.max_num_seqs): q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0) k_input = self.k_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0) v_input = self.v_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0) mask = paddle.tensor.triu(paddle.ones([self.seq_len, self.seq_len], dtype="bfloat16"), 1) * -1000000 - mask[self.moba_use_encoder_seq_limit - self.moba_block_size :] = -1000000 + mask[self.plas_use_encoder_seq_limit - self.plas_block_size :] = -1000000 for i in range(limit_topk - 1, self.num_blocks): n_block = i mask[ - i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size, - n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size, + i * self.plas_block_size : i * self.plas_block_size + self.plas_block_size, + n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size, ] = x idx = 0 n_block -= int(qk_gate_topk_idx[i, 0, idx]) idx += 1 while n_block >= 0: mask[ - i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size, - n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size, + i * self.plas_block_size : i * self.plas_block_size + self.plas_block_size, + n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size, ] = 0 n_block -= int(qk_gate_topk_idx[i, 0, idx]) idx += 1 naive_attn_out = naive_attn(q_input, k_input, v_input, mask).squeeze(axis=0).transpose([1, 0, 2]) assert (attn_out - naive_attn_out).abs().max() < 0.016 - def test_moba_attention(self): + def test_plas_attention(self): qkv_out = paddle.randn([self.tokens, 
self.num_heads + 2 * self.num_kv_heads, self.head_dim], dtype="bfloat16") seq_len_encoder = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32") @@ -255,7 +255,7 @@ class TestMobaAttention(unittest.TestCase): self.num_heads, self.num_kv_heads, self.head_dim, - self.moba_max_seq_length, + self.plas_max_seq_length, self.seq_len, self.seq_len, "none", @@ -307,9 +307,9 @@ class TestMobaAttention(unittest.TestCase): self.seq_len, self.num_heads, self.num_kv_heads, - self.moba_encoder_top_k_left, - self.moba_encoder_top_k_right, - self.moba_use_encoder_seq_limit, + self.plas_encoder_top_k_left, + self.plas_encoder_top_k_right, + self.plas_use_encoder_seq_limit, ) self.compare_qk_gate_topk(qk_gate_topk_idx) @@ -332,7 +332,7 @@ class TestMobaAttention(unittest.TestCase): self.num_heads, self.num_kv_heads, self.head_dim, - self.moba_max_seq_length, + self.plas_max_seq_length, ) self.compare_attn(attn_out, qk_gate_topk_idx) @@ -340,18 +340,18 @@ class TestMobaAttention(unittest.TestCase): def test_server(self): if get_cur_cu_seq_len_k is None: return - os.environ["FD_ATTENTION_BACKEND"] = "MOBA_ATTN" + os.environ["FD_ATTENTION_BACKEND"] = "PLAS_ATTN" base_path = os.getenv("MODEL_PATH") if base_path: model_path = os.path.join(base_path, "./ernie-4_5-21b-a3b-bf16-paddle") else: model_path = "./ernie-4_5-21b-a3b-bf16-paddle" - moba_attention_config = { - "moba_encoder_top_k_left": 50, - "moba_encoder_top_k_right": 60, - "moba_decoder_top_k_left": 100, - "moba_decoder_top_k_right": 120, + plas_attention_config = { + "plas_encoder_top_k_left": 50, + "plas_encoder_top_k_right": 60, + "plas_decoder_top_k_left": 100, + "plas_decoder_top_k_right": 120, } # 加载模型 @@ -365,7 +365,7 @@ class TestMobaAttention(unittest.TestCase): quantization="wint4", enable_chunked_prefill=True, max_num_batched_tokens=8192, - moba_attention_config=moba_attention_config, + plas_attention_config=plas_attention_config, ) prompts = ["Hello world!"] diff --git a/tests/model_loader/test_load_attention.py b/tests/model_loader/test_load_attention.py index d35d6e442..cb3f8d89e 100644 --- a/tests/model_loader/test_load_attention.py +++ b/tests/model_loader/test_load_attention.py @@ -65,7 +65,7 @@ class TestAttentionInitWeight(unittest.TestCase): self.fd_config.parallel_config = self.parallel_config self.fd_config.cache_config = self.cache_config self.fd_config.quant_config = None - self.fd_config.moba_attention_config = None + self.fd_config.plas_attention_config = None def test_init_weight_without_quantization(self): """Test init_weight without quantization.""" @@ -141,7 +141,7 @@ class TestAttentionWeightLoader(unittest.TestCase): self.fd_config.model_config = self.model_config self.fd_config.parallel_config = self.parallel_config self.fd_config.cache_config = self.cache_config - self.fd_config.moba_attention_config = None + self.fd_config.plas_attention_config = None # Create mock quant method self.mock_quant_method = MockQuantMethod()
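As a quick illustration of how the renamed `PlasAttentionConfig` fills in its defaults, the sketch below constructs it directly with the same top-k bounds used in the docs and tests above. It is a minimal example, assuming this patch is applied and `fastdeploy.config` imports cleanly in your environment.

```python
from fastdeploy.config import PlasAttentionConfig

cfg = PlasAttentionConfig(
    {
        "plas_encoder_top_k_left": 50,
        "plas_encoder_top_k_right": 60,
        "plas_decoder_top_k_left": 100,
        "plas_decoder_top_k_right": 120,
    }
)

# With the default plas_block_size of 128, the unset sequence limits default to
# top_k_left * block_size; shorter inputs are processed with dense attention:
#   encoder: 50 * 128 = 6400 tokens, decoder: 100 * 128 = 12800 tokens.
assert cfg.plas_use_encoder_seq_limit == 6400
assert cfg.plas_use_decoder_seq_limit == 12800

# check_legality_parameters() has already validated the top-k ranges;
# to_json_string() is what LLMEngine forwards to the worker process via
# --plas_attention_config.
print(cfg.to_json_string())
```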