【FIX】Change the name of sparse attn from moba to plas (#4006) (#4076)

* 【FIX】Change the name of sparse attn from moba to plas (#4006)

* Update docs

* 【docs】 update readme (#4000)

* Update docs

* update readme

* update docs

* 【FIX】Change the name of sparse attn from moba to plas (#3845)

* Update docs

* Update docs

* Update docs

* Update docs

* Rename moba to plas

* code style

* update ci

* code style

* update ci

* code style

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>

* fix max_num_seqs

* fix test load attn

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Authored by yangjianfengo1 on 2025-09-23 10:26:40 +08:00 · committed by GitHub
parent 2c34a557f4 · commit 4325b737e7
14 changed files with 152 additions and 152 deletions

View File

@@ -196,7 +196,7 @@ We selected a subset (longbook_sum_eng) from InfiniteBench as the performance ev
 ## Usage
 ```
-export FD_ATTENTION_BACKEND="MOBA_ATTN"
+export FD_ATTENTION_BACKEND="PLAS_ATTN"
 python -m fastdeploy.entrypoints.openai.api_server \
 --model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -207,13 +207,13 @@ python -m fastdeploy.entrypoints.openai.api_server
 --max-num-batched-tokens 8192 \
 --max-model-len 131072 \
 --max-num-seqs 32 \
---moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
+--plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
 ```
-**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.
+**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.
 **Parameter Description:**
-* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
-* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` means the top-k for sparse encoder attention is chosen from the range [50, 60].
-* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` means the top-k for sparse decoder attention is chosen from the range [100, 120].
+* Setting `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
+* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` means the top-k for sparse encoder attention is chosen from the range [50, 60].
+* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` means the top-k for sparse decoder attention is chosen from the range [100, 120].
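Once the server above is up, a quick smoke test can go through its OpenAI-compatible endpoint. This is a hedged sketch only: the host, port, model name, and prompt below are placeholders (the `--port` flag is elided by this hunk), not values taken from the diff.

```python
# Hypothetical client check against the api_server launched above.
# Assumes it listens on localhost:8188 and serves the OpenAI-compatible API.
import openai

client = openai.OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="default",  # placeholder; the server routes to the loaded model
    messages=[{"role": "user", "content": "Summarize the plot of a long novel chapter."}],
    max_tokens=64,
)
print(resp.choices[0].message.content)
```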

View File

@@ -18,7 +18,7 @@
 <img src="images/plas_training_distill.png" alt="Attention Gate Module" width="60%">
 </div>
 * **Attention Gate Module**: As shown in the figure above, to estimate the importance of each block at low computational cost, we design a lightweight attention gate module. It first compresses each block of K through an MLP layer into a representative low-dimensional representation, $K_c^T=W_{kp}K^T$, where $W_{kp}$ is the MLP layer's weight. Compared with plain mean pooling, a learnable MLP captures the semantic relations and importance distribution among tokens more effectively, yielding a finer-grained representation of each block. Given the compressed representation $K_c$, the importance of each query token with respect to each block is estimated as $Softmax(Q\cdot K_c^T)$. To strengthen the MLP layer's discriminative ability, we take the full attention result after 1D max pooling, $1DMaxPooling(Softmax(Q \cdot K^T))$, as the ground truth; minimizing the distribution gap between the two guides the MLP layer toward feature representations that better match the true attention distribution (a numeric sketch of this scoring step follows below).
 * **Training Data**: Thanks to the efficiency of the model architecture and training paradigm, our method reaches near-lossless accuracy with only 1B training tokens. The training data comes from an internally built corpus that mixes long and short texts, improving the module's adaptability to different sequence lengths.
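To make the gate computation concrete, here is a hedged numpy sketch of the scoring step described above. The einsum mirrors the `attn_gate_weight` contraction used in the test file later in this commit (head dimension dropped for brevity); all shapes are illustrative, not taken from the diff.

```python
# Sketch: score key blocks per query with a learned gate weight, in the
# spirit of Softmax(Q · K_c^T) with K_c^T = W_kp K^T. Shapes are illustrative.
import numpy as np

T, D, B = 1024, 128, 128                 # tokens, head_dim, plas_block_size
num_blocks = T // B
Q = np.random.randn(T, D).astype(np.float32)
K = np.random.randn(T, D).astype(np.float32)
W_kp = np.random.randn(B, D).astype(np.float32)  # per-block-position MLP weight

# Compress each block of B keys into one D-dim representation
# (cf. the test's einsum "nbhd,hbd->nhd", with the head axis dropped).
K_c = np.einsum("nbd,bd->nd", K.reshape(num_blocks, B, D), W_kp)

scores = Q @ K_c.T                                # (T, num_blocks)
scores -= scores.max(axis=-1, keepdims=True)
block_importance = np.exp(scores)
block_importance /= block_importance.sum(axis=-1, keepdims=True)
print(block_importance.shape)                     # (1024, 8)
```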
@@ -36,7 +36,7 @@
 * **Prefill Token Union**: We observe that adjacent query tokens tend to select similar key blocks. Exploiting this locality, we take the union of the key blocks selected by 128 consecutive query tokens and compute sparse attention for those tokens jointly.
 * **Decode Head Union**: Given the wide adoption of GQA in modern models, we find that different query heads within the same group often select overlapping key blocks. We therefore merge the key blocks selected by all query heads in a group into one unified set and compute sparse attention jointly, which also reduces memory-access overhead and further improves decoding efficiency.
 * **Top-K Selection**: Conventional top-k algorithms rely on sorting or direct calls into the Cub library, which incurs significant runtime overhead. To mitigate this, we implement an approximate top-k selection algorithm based on binary search, which markedly reduces latency while preserving accuracy and yields a substantial end-to-end speedup (a sketch of the threshold-search idea follows this list).
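The left/right top-k bounds that appear throughout this commit (`plas_*_top_k_left/right`) fit naturally with such a threshold search: instead of an exact top-k, the kernel only needs the kept-block count to land inside a range. A hedged standalone sketch of that idea (this is not the CUDA kernel, just the selection logic):

```python
# Sketch: approximate top-k via binary search on a score threshold.
# Accepts any kept-block count inside [k_left, k_right], like the config's bounds.
import numpy as np

def approx_topk_mask(scores, k_left, k_right, iters=32):
    lo, hi = float(scores.min()), float(scores.max())
    mid = lo
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        kept = int((scores >= mid).sum())
        if kept > k_right:
            lo = mid          # too many blocks kept: raise the threshold
        elif kept < k_left:
            hi = mid          # too few blocks kept: lower the threshold
        else:
            break
    return scores >= mid

mask = approx_topk_mask(np.random.rand(512).astype(np.float32), 50, 60)
print(int(mask.sum()))        # typically lands in [50, 60]
```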
@@ -200,7 +200,7 @@
 ## Usage
 ```
-export FD_ATTENTION_BACKEND="MOBA_ATTN"
+export FD_ATTENTION_BACKEND="PLAS_ATTN"
 python -m fastdeploy.entrypoints.openai.api_server \
 --model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -211,13 +211,13 @@ python -m fastdeploy.entrypoints.openai.api_server
 --max-num-batched-tokens 8192 \
 --max-model-len 131072 \
 --max-num-seqs 32 \
---moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
+--plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
 ```
-**Note**: If sparse attention is enabled, the system automatically loads the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling is applied to the key representations.
+**Note**: If sparse attention is enabled, the system automatically loads the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling is applied to the key representations.
 **Parameter Description:**
-* `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
-* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` means the top-k for sparse encoder attention is chosen from the range [50, 60].
-* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` means the top-k for sparse decoder attention is chosen from the range [100, 120].
+* `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
+* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` means the top-k for sparse encoder attention is chosen from the range [50, 60].
+* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` means the top-k for sparse decoder attention is chosen from the range [100, 120].

View File

@@ -945,63 +945,63 @@ class GraphOptimizationConfig:
         argument = self.use_cudagraph

-class MobaAttentionConfig:
+class PlasAttentionConfig:
     def __init__(
         self,
         args,
     ):
-        self.moba_encoder_top_k_left: int = None
-        self.moba_encoder_top_k_right: int = None
-        "The sparse top-k of encoder attention lies in [moba_encoder_top_k_left, moba_encoder_top_k_right]"
-        self.moba_decoder_top_k_left: int = None
-        self.moba_decoder_top_k_right: int = None
-        "The sparse top-k of decoder attention lies in [moba_decoder_top_k_left, moba_decoder_top_k_right]"
-        self.moba_use_encoder_seq_limit: int = None
-        "When the number of encoder tokens is less than moba_use_encoder_seq_limit, attention is not sparse"
-        self.moba_use_decoder_seq_limit: int = None
-        "When the number of decoder tokens is less than moba_use_decoder_seq_limit, attention is not sparse"
-        self.moba_block_size: int = 128
-        self.mlp_weight_name: str = "moba_mlp_weight.safetensors"
-        self.moba_max_seq_length: int = 128 * 1024
+        self.plas_encoder_top_k_left: int = None
+        self.plas_encoder_top_k_right: int = None
+        "The sparse top-k of encoder attention lies in [plas_encoder_top_k_left, plas_encoder_top_k_right]"
+        self.plas_decoder_top_k_left: int = None
+        self.plas_decoder_top_k_right: int = None
+        "The sparse top-k of decoder attention lies in [plas_decoder_top_k_left, plas_decoder_top_k_right]"
+        self.plas_use_encoder_seq_limit: int = None
+        "When the number of encoder tokens is less than plas_use_encoder_seq_limit, attention is not sparse"
+        self.plas_use_decoder_seq_limit: int = None
+        "When the number of decoder tokens is less than plas_use_decoder_seq_limit, attention is not sparse"
+        self.plas_block_size: int = 128
+        self.mlp_weight_name: str = "plas_attention_mlp_weight.safetensors"
+        self.plas_max_seq_length: int = 128 * 1024
         if args is not None:
             for key, value in args.items():
                 if hasattr(self, key):
                     setattr(self, key, value)
-        if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None:
-            self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size
-        if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None:
-            self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size
+        if self.plas_use_encoder_seq_limit is None and self.plas_encoder_top_k_left is not None:
+            self.plas_use_encoder_seq_limit = self.plas_encoder_top_k_left * self.plas_block_size
+        if self.plas_use_decoder_seq_limit is None and self.plas_decoder_top_k_left is not None:
+            self.plas_use_decoder_seq_limit = self.plas_decoder_top_k_left * self.plas_block_size
         self.check_legality_parameters()

     def check_legality_parameters(
         self,
     ) -> None:
-        if self.moba_encoder_top_k_left is not None:
-            assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must be greater than 0"
-        if self.moba_encoder_top_k_right is not None:
-            assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must be greater than 0"
-            assert (
-                self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left
-            ), "moba_encoder_top_k_right must be greater than or equal to moba_encoder_top_k_left"
-        if self.moba_decoder_top_k_left is not None:
-            assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must be greater than 0"
-        if self.moba_decoder_top_k_right is not None:
-            assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must be greater than 0"
-            assert (
-                self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left
-            ), "moba_decoder_top_k_right must be greater than or equal to moba_decoder_top_k_left"
-        if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None:
-            assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size
-        if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None:
-            assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size
+        if self.plas_encoder_top_k_left is not None:
+            assert self.plas_encoder_top_k_left > 0, "plas_encoder_top_k_left must be greater than 0"
+        if self.plas_encoder_top_k_right is not None:
+            assert self.plas_encoder_top_k_right > 0, "plas_encoder_top_k_right must be greater than 0"
+            assert (
+                self.plas_encoder_top_k_right >= self.plas_encoder_top_k_left
+            ), "plas_encoder_top_k_right must be greater than or equal to plas_encoder_top_k_left"
+        if self.plas_decoder_top_k_left is not None:
+            assert self.plas_decoder_top_k_left > 0, "plas_decoder_top_k_left must be greater than 0"
+        if self.plas_decoder_top_k_right is not None:
+            assert self.plas_decoder_top_k_right > 0, "plas_decoder_top_k_right must be greater than 0"
+            assert (
+                self.plas_decoder_top_k_right >= self.plas_decoder_top_k_left
+            ), "plas_decoder_top_k_right must be greater than or equal to plas_decoder_top_k_left"
+        if self.plas_use_encoder_seq_limit is not None and self.plas_encoder_top_k_left is not None:
+            assert self.plas_use_encoder_seq_limit >= self.plas_encoder_top_k_left * self.plas_block_size
+        if self.plas_use_decoder_seq_limit is not None and self.plas_decoder_top_k_left is not None:
+            assert self.plas_use_decoder_seq_limit >= self.plas_decoder_top_k_left * self.plas_block_size

     def to_json_string(self):
         """
-        Convert moba_attention_config to json string.
+        Convert plas_attention_config to json string.
         """
         return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
@@ -1396,7 +1396,7 @@ class FDConfig:
         decoding_config: DecodingConfig = None,
         quant_config: QuantConfigBase = None,
         graph_opt_config: GraphOptimizationConfig = None,
-        moba_attention_config: MobaAttentionConfig = None,
+        plas_attention_config: PlasAttentionConfig = None,
         speculative_config: SpeculativeConfig = None,
         tokenizer: str = None,
         max_model_len: int = 8192,
@@ -1427,7 +1427,7 @@ class FDConfig:
         self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
         self.decoding_config: DecodingConfig = decoding_config  # type: ignore
         self.cache_config: CacheConfig = cache_config  # type: ignore
-        self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
+        self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.scheduler_config.max_num_seqs)
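One behavior worth noting from this hunk: when `plas_use_encoder_seq_limit` / `plas_use_decoder_seq_limit` are left unset, they default to `top_k_left * plas_block_size`. A hedged standalone re-statement of just that defaulting rule (the real class lives in `fastdeploy/config.py`):

```python
# Sketch of PlasAttentionConfig's derived defaults, re-implemented standalone.
args = {"plas_encoder_top_k_left": 50, "plas_decoder_top_k_left": 100}
plas_block_size = 128  # default from the class above

encoder_seq_limit = args.get("plas_use_encoder_seq_limit")
if encoder_seq_limit is None and args.get("plas_encoder_top_k_left") is not None:
    encoder_seq_limit = args["plas_encoder_top_k_left"] * plas_block_size

decoder_seq_limit = args.get("plas_use_decoder_seq_limit")
if decoder_seq_limit is None and args.get("plas_decoder_top_k_left") is not None:
    decoder_seq_limit = args["plas_decoder_top_k_left"] * plas_block_size

print(encoder_seq_limit, decoder_seq_limit)  # 6400 12800
```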

View File

@@ -30,9 +30,9 @@ from fastdeploy.config import (
     FDConfig,
     GraphOptimizationConfig,
     LoadConfig,
-    MobaAttentionConfig,
     ModelConfig,
     ParallelConfig,
+    PlasAttentionConfig,
     PoolerConfig,
     RunnerOption,
     SpeculativeConfig,
@@ -361,9 +361,9 @@ class EngineArgs:
     """
     Configuration for graph optimization backend execution.
     """
-    moba_attention_config: Optional[Dict[str, Any]] = None
+    plas_attention_config: Optional[Dict[str, Any]] = None
     """
-    Configuration for moba attention.
+    Configuration for plas attention.
     """
     enable_logprob: bool = False
@@ -601,9 +601,9 @@ class EngineArgs:
             help="",
         )
         model_group.add_argument(
-            "--moba-attention-config",
+            "--plas-attention-config",
             type=json.loads,
-            default=EngineArgs.moba_attention_config,
+            default=EngineArgs.plas_attention_config,
             help="",
         )
         model_group.add_argument(
@@ -993,17 +993,17 @@ class EngineArgs:
                 graph_optimization_args[k] = v
         return GraphOptimizationConfig(graph_optimization_args)

-    def create_moba_attention_config(self) -> MobaAttentionConfig:
+    def create_plas_attention_config(self) -> PlasAttentionConfig:
         """
-        Create and return a MobaAttentionConfig object based on the current settings.
+        Create and return a PlasAttentionConfig object based on the current settings.
         """
         attention_args = asdict(self)
-        if self.moba_attention_config is not None:
-            for k, v in self.moba_attention_config.items():
+        if self.plas_attention_config is not None:
+            for k, v in self.plas_attention_config.items():
                 attention_args[k] = v
-            return MobaAttentionConfig(attention_args)
+            return PlasAttentionConfig(attention_args)
         else:
-            return MobaAttentionConfig(None)
+            return PlasAttentionConfig(None)

     def create_early_stop_config(self) -> EarlyStopConfig:
         """
@@ -1064,7 +1064,7 @@ class EngineArgs:
         scheduler_cfg = self.create_scheduler_config()
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
-        moba_attention_config = self.create_moba_attention_config()
+        plas_attention_config = self.create_plas_attention_config()
         early_stop_cfg = self.create_early_stop_config()
         early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1093,7 +1093,7 @@ class EngineArgs:
             max_long_partial_prefills=self.max_long_partial_prefills,
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             graph_opt_config=graph_opt_cfg,
-            moba_attention_config=moba_attention_config,
+            plas_attention_config=plas_attention_config,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
             early_stop_config=early_stop_cfg,
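Since the flag is declared with `type=json.loads`, the CLI value must be a single JSON string; argparse hands the parsed dict straight to `create_plas_attention_config`. A hedged minimal reproduction of that parsing behavior:

```python
# Sketch: how a JSON-typed argparse flag like --plas-attention-config parses.
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--plas-attention-config", type=json.loads, default=None)

ns = parser.parse_args(
    ["--plas-attention-config",
     '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60}']
)
print(ns.plas_attention_config["plas_encoder_top_k_right"])  # 60
```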

View File

@@ -501,7 +501,7 @@ class LLMEngine:
f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'" f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
f" --reasoning_parser {self.cfg.reasoning_parser}" f" --reasoning_parser {self.cfg.reasoning_parser}"
f" --load_choices {self.cfg.load_config.load_choices}" f" --load_choices {self.cfg.load_config.load_choices}"
f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'" f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
f" --ips {ips}" f" --ips {ips}"
f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}" f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
f" --runner {self.cfg.model_config.runner}" f" --runner {self.cfg.model_config.runner}"

View File

@@ -20,7 +20,7 @@ from .block_multihead_attn_backend import BlockAttentionBackend
 from .flash_attn_backend import FlashAttentionBackend
 from .iluvatar_attn_backend import IluvatarAttnBackend
 from .mla_attention_backend import MLAAttentionBackend
-from .moba_attention_backend import MobaAttentionBackend
+from .moba_attention_backend import PlasAttentionBackend
 from .native_paddle_backend import PaddleNativeAttnBackend
 from .xpu_attn_backend import XPUAttentionBackend
@@ -35,5 +35,5 @@ __all__ = [
     "IluvatarAttnBackend",
     "BlockAttentionBackend",
     "Attention",
-    "MobaAttentionBackend",
+    "PlasAttentionBackend",
 ]

View File

@@ -119,19 +119,19 @@ class Attention(nn.Layer):
         self.init_weight()
         if (
-            fd_config.moba_attention_config is not None
-            and fd_config.moba_attention_config.moba_encoder_top_k_left is not None
-            and fd_config.moba_attention_config.moba_encoder_top_k_right is not None
-            and fd_config.moba_attention_config.moba_decoder_top_k_left is not None
-            and fd_config.moba_attention_config.moba_decoder_top_k_right is not None
+            fd_config.plas_attention_config is not None
+            and fd_config.plas_attention_config.plas_encoder_top_k_left is not None
+            and fd_config.plas_attention_config.plas_encoder_top_k_right is not None
+            and fd_config.plas_attention_config.plas_decoder_top_k_left is not None
+            and fd_config.plas_attention_config.plas_decoder_top_k_right is not None
         ):
             mlp_weight_path = os.path.join(
-                fd_config.model_config.model, fd_config.moba_attention_config.mlp_weight_name
+                fd_config.model_config.model, fd_config.plas_attention_config.mlp_weight_name
             )
-            self.moba_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path)
-            moba_block_size = fd_config.moba_attention_config.moba_block_size
-            moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length
-            if self.moba_use_mlp:
+            self.plas_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path)
+            plas_block_size = fd_config.plas_attention_config.plas_block_size
+            plas_max_seq_length = fd_config.plas_attention_config.plas_max_seq_length
+            if self.plas_use_mlp:
                 mlp_weight = {}
                 with safe_open(mlp_weight_path, framework="np", device="cpu") as f:
                     for key_name in f.keys():
@@ -148,12 +148,12 @@ class Attention(nn.Layer):
                     * self.kv_num_heads : (fd_config.parallel_config.tensor_parallel_rank + 1)
                     * self.kv_num_heads
                 ]
-                assert self.attn_gate_weight.shape[1] % moba_block_size == 0
+                assert self.attn_gate_weight.shape[1] % plas_block_size == 0
             self.cache_k_block_means = paddle.zeros(
                 [
                     fd_config.scheduler_config.max_num_seqs,
-                    moba_max_seq_length // moba_block_size,
+                    plas_max_seq_length // plas_block_size,
                     self.kv_num_heads,
                     self.head_dim,
                 ],
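The weight-loading branch above keys off `safetensors.safe_open` with the numpy framework. A hedged round-trip sketch of the same read pattern (the tensor key and shape here are made up; only the API usage mirrors the layer):

```python
# Sketch: write and re-read an MLP gate weight the way Attention.__init__ does.
import numpy as np
from safetensors import safe_open
from safetensors.numpy import save_file

save_file(
    {"attn_gate_weight": np.zeros((8, 128, 128), dtype=np.float32)},  # hypothetical key/shape
    "plas_attention_mlp_weight.safetensors",
)

mlp_weight = {}
with safe_open("plas_attention_mlp_weight.safetensors", framework="np", device="cpu") as f:
    for key_name in f.keys():
        mlp_weight[key_name] = f.get_tensor(key_name)
print(mlp_weight["attn_gate_weight"].shape)  # (8, 128, 128)
```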

View File

@@ -39,7 +39,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 @dataclass
-class MobaAttentionMetadata(AttentionMetadata):
+class PlasAttentionMetadata(AttentionMetadata):
     """
-    MobaAttentionMetadata
+    PlasAttentionMetadata
     """
@@ -54,7 +54,7 @@ class MobaAttentionMetadata(AttentionMetadata):
     max_dec_len_this_time: int = 0

-class MobaAttentionBackend(AttentionBackend):
+class PlasAttentionBackend(AttentionBackend):
     """
     The backend class that uses the paddle native attention implementation,
     which is used only for testing purposes.
@@ -70,11 +70,11 @@ class MobaAttentionBackend(AttentionBackend):
         decoder_block_shape_q: int = -1,
     ) -> None:
         """
-        MobaAttentionBackend __init__
+        PlasAttentionBackend __init__
         """
         super().__init__()
-        self.attention_metadata: MobaAttentionMetadata = None
-        assert fd_config.moba_attention_config is not None, "moba_attention_config is None"
+        self.attention_metadata: PlasAttentionMetadata = None
+        assert fd_config.plas_attention_config is not None, "plas_attention_config is None"
         self.block_size = fd_config.parallel_config.block_size
         self.max_seq_len = fd_config.parallel_config.max_model_len
         self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
@@ -83,18 +83,18 @@ class MobaAttentionBackend(AttentionBackend):
         self.head_dim = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
         self.attn_block_m = 128
-        self.moba_block_size = fd_config.moba_attention_config.moba_block_size
-        self.moba_encoder_top_k_left = int(fd_config.moba_attention_config.moba_encoder_top_k_left)
-        self.moba_encoder_top_k_right = int(fd_config.moba_attention_config.moba_encoder_top_k_right)
-        self.moba_use_encoder_seq_limit = int(fd_config.moba_attention_config.moba_use_encoder_seq_limit)
-        self.moba_decoder_top_k_left = int(fd_config.moba_attention_config.moba_decoder_top_k_left)
-        self.moba_decoder_top_k_right = int(fd_config.moba_attention_config.moba_decoder_top_k_right)
-        self.moba_use_decoder_seq_limit = int(fd_config.moba_attention_config.moba_use_decoder_seq_limit)
-        self.moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length
+        self.plas_block_size = fd_config.plas_attention_config.plas_block_size
+        self.plas_encoder_top_k_left = int(fd_config.plas_attention_config.plas_encoder_top_k_left)
+        self.plas_encoder_top_k_right = int(fd_config.plas_attention_config.plas_encoder_top_k_right)
+        self.plas_use_encoder_seq_limit = int(fd_config.plas_attention_config.plas_use_encoder_seq_limit)
+        self.plas_decoder_top_k_left = int(fd_config.plas_attention_config.plas_decoder_top_k_left)
+        self.plas_decoder_top_k_right = int(fd_config.plas_attention_config.plas_decoder_top_k_right)
+        self.plas_use_decoder_seq_limit = int(fd_config.plas_attention_config.plas_use_decoder_seq_limit)
+        self.plas_max_seq_length = fd_config.plas_attention_config.plas_max_seq_length

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Init the metadata for a forward pass."""
-        metadata = MobaAttentionMetadata()
+        metadata = PlasAttentionMetadata()
         metadata._dtype = paddle.get_default_dtype()
         metadata.cu_seq_q_pack, metadata.cu_seqlens_k, metadata.q_pack_tokens = get_cur_cu_seq_len_k(
             forward_meta.seq_lens_encoder,
@@ -116,7 +116,7 @@ class MobaAttentionBackend(AttentionBackend):
             [k_token_num + self.attn_block_m, self.kv_num_heads * self.head_dim], dtype=metadata._dtype
         )
         self.attention_metadata = metadata
-        assert self.max_seq_len <= self.moba_max_seq_length
+        assert self.max_seq_len <= self.plas_max_seq_length

     def get_kv_cache_shape(
         self,
@@ -186,13 +186,13 @@ class MobaAttentionBackend(AttentionBackend):
             self.max_seq_len,
             attention_metadata.max_enc_len_this_time,
             attention_metadata.max_dec_len_this_time,
-            self.moba_encoder_top_k_left,
-            self.moba_encoder_top_k_right,
-            self.moba_use_encoder_seq_limit,
-            self.moba_decoder_top_k_left,
-            self.moba_decoder_top_k_right,
-            self.moba_use_decoder_seq_limit,
-            layer.moba_use_mlp,
+            self.plas_encoder_top_k_left,
+            self.plas_encoder_top_k_right,
+            self.plas_use_encoder_seq_limit,
+            self.plas_decoder_top_k_left,
+            self.plas_decoder_top_k_right,
+            self.plas_use_decoder_seq_limit,
+            layer.plas_use_mlp,
             getattr(layer, "cache_quant_type_str", "none"),
         )[0]
         return out

View File

@@ -26,7 +26,7 @@ class _Backend(enum.Enum):
     MLA_ATTN = enum.auto()
     FLASH_ATTN = enum.auto()
     BLOCK_ATTN = enum.auto()
-    MOBA_ATTN = enum.auto()
+    PLAS_ATTN = enum.auto()

 class Platform:

View File

@@ -64,9 +64,9 @@ class CUDAPlatform(Platform):
         elif selected_backend == _Backend.FLASH_ATTN:
             logger.info("Using FLASH ATTN backend.")
             return "fastdeploy.model_executor.layers.attention.FlashAttentionBackend"
-        elif selected_backend == _Backend.MOBA_ATTN:
-            logger.info("Using MOBA ATTN backend.")
-            return "fastdeploy.model_executor.layers.attention.MobaAttentionBackend"
+        elif selected_backend == _Backend.PLAS_ATTN:
+            logger.info("Using PLAS ATTN backend.")
+            return "fastdeploy.model_executor.layers.attention.PlasAttentionBackend"
         else:
             raise ValueError(
                 "Invalid attention backend you specified.\n"

View File

@@ -61,7 +61,7 @@ class RolloutModelConfig:
         graph_optimization_config: str = None,
         early_stop_config: str = None,
         local_rank: int = 0,
-        moba_attention_config: str = None,
+        plas_attention_config: str = None,
         data_parallel_size: int = 1,
         num_nextn_predict_layers: int = 0,
     ):
@@ -109,7 +109,7 @@ class RolloutModelConfig:
         self.local_rank = local_rank
         self.early_stop_config = early_stop_config
         self.ips = None
-        self.moba_attention_config = moba_attention_config
+        self.plas_attention_config = plas_attention_config
         self.num_nextn_predict_layers = num_nextn_predict_layers

     def __str__(self):
def __str__(self): def __str__(self):

View File

@@ -35,9 +35,9 @@ from fastdeploy.config import (
     FDConfig,
     GraphOptimizationConfig,
     LoadConfig,
-    MobaAttentionConfig,
     ModelConfig,
     ParallelConfig,
+    PlasAttentionConfig,
     SpeculativeConfig,
 )
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
@@ -577,10 +577,10 @@ def parse_args():
         help="Configuration of Graph optimization backend.",
     )
     parser.add_argument(
-        "--moba_attention_config",
+        "--plas_attention_config",
         type=json.loads,
         default=None,
-        help="Configuration of moba attention.",
+        help="Configuration of plas attention.",
     )
     parser.add_argument(
         "--guided_decoding_backend",
@@ -723,7 +723,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)
-    moba_attention_config = MobaAttentionConfig(args.moba_attention_config)
+    plas_attention_config = PlasAttentionConfig(args.plas_attention_config)
     early_stop_config = EarlyStopConfig(args.early_stop_config)
@@ -795,7 +795,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         cache_config=cache_config,
         scheduler_config=scheduler_config,
         ips=args.ips,
-        moba_attention_config=moba_attention_config,
+        plas_attention_config=plas_attention_config,
     )
     update_fd_config_for_mm(fd_config)

View File

@@ -57,7 +57,7 @@ def naive_attn(q_input, k_input, v_input, mask):
     return out

-class TestMobaAttention(unittest.TestCase):
+class TestPlasAttention(unittest.TestCase):
     def setUp(self):
         paddle.seed(0)
         self.seq_len = int(8 * 1024)
@@ -65,15 +65,15 @@ class TestMobaAttention(unittest.TestCase):
         self.num_kv_heads = int(1)
         self.head_dim = int(128)
         self.max_num_seqs = 1
-        self.moba_max_seq_length = int(128 * 1024)
-        self.moba_block_size = int(128)
-        self.moba_encoder_top_k_left = 2
-        self.moba_encoder_top_k_right = 3
-        self.moba_use_encoder_seq_limit = int(4 * 1024)
+        self.plas_max_seq_length = int(128 * 1024)
+        self.plas_block_size = int(128)
+        self.plas_encoder_top_k_left = 2
+        self.plas_encoder_top_k_right = 3
+        self.plas_use_encoder_seq_limit = int(4 * 1024)
         self.cache_k_block_means = paddle.zeros(
             [
                 self.max_num_seqs,
-                self.moba_max_seq_length // self.moba_block_size,
+                self.plas_max_seq_length // self.plas_block_size,
                 self.num_kv_heads,
                 self.head_dim,
             ],
@@ -96,12 +96,12 @@ class TestMobaAttention(unittest.TestCase):
         self.rotary_embs = paddle.ones([2, self.seq_len, self.head_dim // 2], dtype="float32")
         self.attn_gate_weight = paddle.randn(
-            [self.num_kv_heads, self.moba_block_size, self.head_dim], dtype="bfloat16"
+            [self.num_kv_heads, self.plas_block_size, self.head_dim], dtype="bfloat16"
         )
         self.gqa_group_size = self.num_heads // self.num_kv_heads
-        self.num_blocks = (self.seq_len + self.moba_block_size - 1) // self.moba_block_size
+        self.num_blocks = (self.seq_len + self.plas_block_size - 1) // self.plas_block_size
         self.sparse_step = 4
@@ -115,38 +115,38 @@ class TestMobaAttention(unittest.TestCase):
         for i in range(self.max_num_seqs):
             k_padding = paddle.zeros(
                 [
-                    (self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size,
+                    (self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size,
                     self.num_kv_heads,
                     self.head_dim,
                 ],
                 dtype="bfloat16",
             )
             k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
-            real_k_block_means = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim])
+            real_k_block_means = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim])
             real_k_block_means = real_k_block_means.mean(axis=1)
             compute_k_block_means = self.cache_k_block_means[i, 0 : real_k_block_means.shape[0]]
             assert (compute_k_block_means - real_k_block_means).abs().max() < 0.003
-        print("[consistency]Moba attention: split_qkv_rope matches.")
+        print("[consistency]plas attention: split_qkv_rope matches.")

     def compare_mlp_einsum(self, k_gate_weight):
         for i in range(self.max_num_seqs):
             k_padding = paddle.zeros(
                 [
-                    (self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size,
+                    (self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size,
                     self.num_kv_heads,
                     self.head_dim,
                 ],
                 dtype="bfloat16",
             )
             k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
-            k_padding = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim])
+            k_padding = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim])
             real_result = paddle.einsum("nbhd,hbd->nhd", k_padding, self.attn_gate_weight)
             compute_result = k_gate_weight[i][0 : real_result.shape[0]]
             assert (real_result - compute_result).abs().max() < 0.5
-        print("[consistency]Moba attention: MLP einsum matches.")
+        print("[consistency]plas attention: MLP einsum matches.")

     def compare_qk_gemm(self, qk_gate_weight):
         for i in range(self.max_num_seqs):
@@ -170,10 +170,10 @@ class TestMobaAttention(unittest.TestCase):
             compute_result = qk_gate_weight[i * self.seq_len : (i + 1) * self.seq_len, :, 0 : self.num_blocks]
             assert (qk_gemm_out - compute_result).abs().max() < 1e-4
-        print("[consistency]Moba attention: qk_gemm matches.")
+        print("[consistency]plas attention: qk_gemm matches.")

     def compare_qk_gate_topk(self, qk_gate_topk_idx):
-        limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size
+        limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size
         for i in range(self.max_num_seqs):
             qk_gate_topk_idx_batch = qk_gate_topk_idx[i * self.num_blocks : (i + 1) * self.num_blocks]
             qk_gate_topk_idx_batch_no_sparse = qk_gate_topk_idx_batch[0 : limit_topk - 1]
@@ -191,40 +191,40 @@ class TestMobaAttention(unittest.TestCase):
                 - paddle.ones(qk_gate_topk_idx_batch_sparse.shape, qk_gate_topk_idx_batch_sparse.dtype)
                 * self.sparse_step
             ).abs().max() < 1e-6
-        print("[consistency]Moba attention: qk_gate_topk matches.")
+        print("[consistency]plas attention: qk_gate_topk matches.")

     def compare_attn(self, attn_out, qk_gate_topk_idx):
         x = (
-            paddle.tensor.triu(paddle.ones([self.moba_block_size, self.moba_block_size], dtype="bfloat16"), 1)
+            paddle.tensor.triu(paddle.ones([self.plas_block_size, self.plas_block_size], dtype="bfloat16"), 1)
             * -1000000
         )
-        limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size
+        limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size
         for i in range(self.max_num_seqs):
             q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
             k_input = self.k_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
             v_input = self.v_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
             mask = paddle.tensor.triu(paddle.ones([self.seq_len, self.seq_len], dtype="bfloat16"), 1) * -1000000
-            mask[self.moba_use_encoder_seq_limit - self.moba_block_size :] = -1000000
+            mask[self.plas_use_encoder_seq_limit - self.plas_block_size :] = -1000000
             for i in range(limit_topk - 1, self.num_blocks):
                 n_block = i
                 mask[
-                    i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size,
-                    n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size,
+                    i * self.plas_block_size : i * self.plas_block_size + self.plas_block_size,
+                    n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size,
                 ] = x
                 idx = 0
                 n_block -= int(qk_gate_topk_idx[i, 0, idx])
                 idx += 1
                 while n_block >= 0:
                     mask[
-                        i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size,
-                        n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size,
+                        i * self.plas_block_size : i * self.plas_block_size + self.plas_block_size,
+                        n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size,
                     ] = 0
                     n_block -= int(qk_gate_topk_idx[i, 0, idx])
                     idx += 1
             naive_attn_out = naive_attn(q_input, k_input, v_input, mask).squeeze(axis=0).transpose([1, 0, 2])
             assert (attn_out - naive_attn_out).abs().max() < 0.016

-    def test_moba_attention(self):
+    def test_plas_attention(self):
         qkv_out = paddle.randn([self.tokens, self.num_heads + 2 * self.num_kv_heads, self.head_dim], dtype="bfloat16")
         seq_len_encoder = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32")
@@ -255,7 +255,7 @@ class TestMobaAttention(unittest.TestCase):
             self.num_heads,
             self.num_kv_heads,
             self.head_dim,
-            self.moba_max_seq_length,
+            self.plas_max_seq_length,
             self.seq_len,
             self.seq_len,
             "none",
@@ -307,9 +307,9 @@ class TestMobaAttention(unittest.TestCase):
             self.seq_len,
             self.num_heads,
             self.num_kv_heads,
-            self.moba_encoder_top_k_left,
-            self.moba_encoder_top_k_right,
-            self.moba_use_encoder_seq_limit,
+            self.plas_encoder_top_k_left,
+            self.plas_encoder_top_k_right,
+            self.plas_use_encoder_seq_limit,
         )
         self.compare_qk_gate_topk(qk_gate_topk_idx)
@@ -332,7 +332,7 @@ class TestMobaAttention(unittest.TestCase):
             self.num_heads,
             self.num_kv_heads,
             self.head_dim,
-            self.moba_max_seq_length,
+            self.plas_max_seq_length,
         )
         self.compare_attn(attn_out, qk_gate_topk_idx)
@@ -340,18 +340,18 @@ class TestMobaAttention(unittest.TestCase):
     def test_server(self):
         if get_cur_cu_seq_len_k is None:
             return
-        os.environ["FD_ATTENTION_BACKEND"] = "MOBA_ATTN"
+        os.environ["FD_ATTENTION_BACKEND"] = "PLAS_ATTN"
         base_path = os.getenv("MODEL_PATH")
         if base_path:
             model_path = os.path.join(base_path, "./ernie-4_5-21b-a3b-bf16-paddle")
         else:
             model_path = "./ernie-4_5-21b-a3b-bf16-paddle"
-        moba_attention_config = {
-            "moba_encoder_top_k_left": 50,
-            "moba_encoder_top_k_right": 60,
-            "moba_decoder_top_k_left": 100,
-            "moba_decoder_top_k_right": 120,
-        }
+        plas_attention_config = {
+            "plas_encoder_top_k_left": 50,
+            "plas_encoder_top_k_right": 60,
+            "plas_decoder_top_k_left": 100,
+            "plas_decoder_top_k_right": 120,
+        }
         # Load the model
@@ -365,7 +365,7 @@ class TestMobaAttention(unittest.TestCase):
             quantization="wint4",
             enable_chunked_prefill=True,
             max_num_batched_tokens=8192,
-            moba_attention_config=moba_attention_config,
+            plas_attention_config=plas_attention_config,
         )
         prompts = ["Hello world!"]
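The reference helper `naive_attn` that these checks compare against is truncated in the hunk above. A hedged reconstruction, consistent only with how the test calls it (batched [batch, seq, heads, head_dim] inputs, an additive [seq, seq] mask, and a [batch, heads, seq, head_dim] output), might look like:

```python
# Hedged sketch of a naive masked-attention reference; not the test's exact code.
import paddle

def naive_attn(q_input, k_input, v_input, mask):
    # q/k/v: [batch, seq, heads, head_dim]; mask: additive [seq, seq] bias.
    q = q_input.transpose([0, 2, 1, 3]).astype("float32")
    k = k_input.transpose([0, 2, 1, 3]).astype("float32")
    v = v_input.transpose([0, 2, 1, 3]).astype("float32")
    # GQA: kv heads broadcast against query heads via batched matmul.
    scores = paddle.matmul(q, k, transpose_y=True) / (q.shape[-1] ** 0.5)
    probs = paddle.nn.functional.softmax(scores + mask.astype("float32"), axis=-1)
    return paddle.matmul(probs, v)  # [batch, heads, seq, head_dim]
```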

View File

@@ -65,7 +65,7 @@ class TestAttentionInitWeight(unittest.TestCase):
         self.fd_config.parallel_config = self.parallel_config
         self.fd_config.cache_config = self.cache_config
         self.fd_config.quant_config = None
-        self.fd_config.moba_attention_config = None
+        self.fd_config.plas_attention_config = None

     def test_init_weight_without_quantization(self):
         """Test init_weight without quantization."""
@@ -141,7 +141,7 @@ class TestAttentionWeightLoader(unittest.TestCase):
         self.fd_config.model_config = self.model_config
         self.fd_config.parallel_config = self.parallel_config
         self.fd_config.cache_config = self.cache_config
-        self.fd_config.moba_attention_config = None
+        self.fd_config.plas_attention_config = None

         # Create mock quant method
         self.mock_quant_method = MockQuantMethod()