Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-09-26 20:41:53 +08:00
【FIX】Change the name of sparse attn from moba to plas (#4006)

* Update docs
* 【docs】 update readme (#4000)
* Update docs
* update readme
* update docs
* 【FIX】Change the name of sparse attn from moba to plas (#3845)
* Update docs
* Update docs
* Update docs
* Update docs
* Change moba to plas
* code style
* update ci
* code style
* update ci
* code style

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>

* fix max_num_seqs
* fix test load attn

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
@@ -196,7 +196,7 @@ We selected a subset (longbook_sum_eng) from InfiniteBench as the performance ev
 ## Usage
 
 ```
-export FD_ATTENTION_BACKEND="MOBA_ATTN"
+export FD_ATTENTION_BACKEND="PLAS_ATTN"
 
 python -m fastdeploy.entrypoints.openai.api_server \
     --model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -207,13 +207,13 @@ python -m fastdeploy.entrypoints.openai.api_server
     --max-num-batched-tokens 8192 \
     --max-model-len 131072 \
     --max-num-seqs 32 \
-    --moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
+    --plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
 ```
 
-**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.
+**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.
 
 **Parameter Description:**
 
-* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
-* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` indicates that the range of top-k is between 50 and 60 when the encoder is sparse.
-* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` indicates that the range of top-k is between 100 and 120 when the decoder is sparse.
+* Setting `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
+* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` indicates that the range of top-k is between 50 and 60 when the encoder is sparse.
+* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` indicates that the range of top-k is between 100 and 120 when the decoder is sparse.
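For a quick sanity check outside the server, the same settings can be exercised through the offline `LLM` entry point, as the test added in this commit does. A minimal sketch, assuming `from fastdeploy import LLM` and treating the model name and prompt as placeholders:

```python
import os

from fastdeploy import LLM  # entry-point import assumed

# Select the PLAS sparse-attention backend before the engine starts.
os.environ["FD_ATTENTION_BACKEND"] = "PLAS_ATTN"

plas_attention_config = {
    "plas_encoder_top_k_left": 50,
    "plas_encoder_top_k_right": 60,
    "plas_decoder_top_k_left": 100,
    "plas_decoder_top_k_right": 120,
}

llm = LLM(
    model="baidu/ERNIE-4.5-300B-A47B-Paddle",  # placeholder model
    max_num_batched_tokens=8192,
    max_model_len=131072,
    max_num_seqs=32,
    plas_attention_config=plas_attention_config,
)
print(llm.generate(["Summarize this 100k-token document ..."]))
```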
@@ -18,7 +18,7 @@
 <img src="images/plas_training_distill.png" alt="Attention Gate Module" width="60%">
 </div>
 
 * **Attention Gate Module**: As shown in the figure above, to estimate the importance of each block at low computational cost, we design a lightweight attention gate module. The module first compresses each K block through an MLP layer to produce a representative low-dimensional representation: $K_c^T=W_{kp}K^T$, where $W_{kp}$ denotes the weights of the MLP layer. Compared with directly applying mean pooling, a learnable MLP captures the semantic relationships and importance distribution among different tokens more effectively, providing a fine-grained representation of each block. After obtaining the compressed representation $K_c$, the importance of each query token with respect to each block is estimated by $Softmax(Q\cdot K_c^T)$. To strengthen the discriminative ability of the MLP layer, we use the full-attention result after 1D max pooling, $1DMaxPooling(Softmax(Q \cdot K^T))$, as the ground truth; minimizing the distribution gap between the two guides the MLP layer toward feature representations that better match the true attention distribution.
 
 * **Training Data**: Thanks to the efficiency of the model architecture and training paradigm, our method achieves near-lossless accuracy using only one billion training tokens. The training data is drawn from an internally constructed corpus mixing long and short texts, which strengthens the module's adaptability to varying sequence lengths.
 
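The gate computation above can be made concrete with a short Paddle sketch. The einsum layout for the MLP compression follows the unit test added later in this commit; the shapes, the softmax over flattened block scores, and the KL distillation loss are illustrative assumptions:

```python
import paddle
import paddle.nn.functional as F

# Illustrative sizes: n blocks of B tokens, h KV heads, head dim d, q query tokens.
n, B, h, d, q = 16, 64, 2, 32, 256

k_blocks = paddle.randn([n, B, h, d])   # keys grouped into blocks
w_kp = paddle.randn([h, B, d])          # gate MLP weight, laid out as in the test
queries = paddle.randn([q, h, d])

# Compress each K block into one representative vector: K_c^T = W_kp K^T.
k_c = paddle.einsum("nbhd,hbd->nhd", k_blocks, w_kp)              # [n, h, d]

# Estimated per-block importance for every query token: Softmax(Q . K_c^T).
est = F.softmax(paddle.einsum("qhd,nhd->qhn", queries, k_c), axis=-1)

# Ground truth: 1D max pooling of the full attention map over each block.
full = F.softmax(
    paddle.einsum("qhd,nbhd->qhnb", queries, k_blocks).reshape([q, h, n * B]), axis=-1
)
target = full.reshape([q, h, n, B]).max(axis=-1)                  # [q, h, n]

# Distillation objective: align the estimate with the pooled full attention.
loss = F.kl_div(paddle.log(est + 1e-9), target, reduction="mean")
```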
@@ -36,7 +36,7 @@
 
 * **Prefill Token Union**: We observe that adjacent query tokens tend to select similar key blocks. Exploiting this locality, we take the union of the key blocks selected by 128 consecutive query tokens and compute the sparse attention of these tokens jointly.
 
 * **Decode Head Union**: Given the wide adoption of GQA in modern models, we find that different query heads within the same group often select overlapping key blocks. We therefore merge the key blocks selected by all query heads in a group into a single unified set and compute the sparse attention jointly; this also reduces memory-access overhead and further improves decoding efficiency.
 
 * **Top-K Selection**: Conventional top-k algorithms rely on sorting or direct calls to the Cub library, which introduces significant runtime overhead. To mitigate this, we implement an approximate top-k selection based on binary search, which substantially reduces latency while preserving accuracy and ultimately delivers a clear performance gain.
 
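The binary-search top-k in the last bullet can be sketched in a few lines of NumPy. This shows only the idea, not the CUDA kernel: bisect a score threshold until the number of surviving blocks lands inside the configured `[top_k_left, top_k_right]` window, so no sort is needed:

```python
import numpy as np

def approx_topk_mask(scores: np.ndarray, k_left: int, k_right: int, iters: int = 16) -> np.ndarray:
    """Keep roughly k blocks with k_left <= k <= k_right, without sorting."""
    lo, hi = float(scores.min()), float(scores.max())
    for _ in range(iters):
        mid = (lo + hi) / 2.0
        kept = int((scores >= mid).sum())
        if kept > k_right:      # too many survivors: raise the threshold
            lo = mid
        elif kept < k_left:     # too few survivors: lower the threshold
            hi = mid
        else:                   # survivor count is inside the accepted window
            break
    return scores >= mid

# Example: keep between 50 and 60 of 1024 block scores, as in the encoder config.
mask = approx_topk_mask(np.random.rand(1024), k_left=50, k_right=60)
```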
@@ -200,7 +200,7 @@
 ## Usage
 
 ```
-export FD_ATTENTION_BACKEND="MOBA_ATTN"
+export FD_ATTENTION_BACKEND="PLAS_ATTN"
 
 python -m fastdeploy.entrypoints.openai.api_server \
     --model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -211,13 +211,13 @@ python -m fastdeploy.entrypoints.openai.api_server
     --max-num-batched-tokens 8192 \
     --max-model-len 131072 \
     --max-num-seqs 32 \
-    --moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
+    --plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
 ```
 
-**Note**: If sparse attention is enabled, the system automatically loads the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling is applied to the key representations.
+**Note**: If sparse attention is enabled, the system automatically loads the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling is applied to the key representations.
 
 **Parameter Description:**
 
-* `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
-* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` means the top-k range is between 50 and 60 when the encoder is sparse.
-* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` means the top-k range is between 100 and 120 when the decoder is sparse.
+* `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
+* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` means the top-k range is between 50 and 60 when the encoder is sparse.
+* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` means the top-k range is between 100 and 120 when the decoder is sparse.
@@ -945,63 +945,63 @@ class GraphOptimizationConfig:
         argument = self.use_cudagraph
 
 
-class MobaAttentionConfig:
+class PlasAttentionConfig:
     def __init__(
         self,
         args,
     ):
-        self.moba_encoder_top_k_left: int = None
-        self.moba_encoder_top_k_right: int = None
-        "The sparse topk of encoder attention is located at [moba_encoder_top_k_left, moba_encoder top_k_right]"
-        self.moba_decoder_top_k_left: int = None
-        self.moba_decoder_top_k_right: int = None
-        "The sparse topk of decoder attention is located at [moba_decoder_top_k_left, moba_decoder top_k_right]"
-        self.moba_use_encoder_seq_limit: int = None
-        "When the number of encdoer token is less than moba_use_encoder_seq_limit, it is not sparse"
-        self.moba_use_decoder_seq_limit: int = None
-        "When the number of decdoer token is less than moba_use_decoder_seq_limit, it is not sparse"
-        self.moba_block_size: int = 128
-        self.mlp_weight_name: str = "moba_mlp_weight.safetensors"
-        self.moba_max_seq_length: int = 128 * 1024
+        self.plas_encoder_top_k_left: int = None
+        self.plas_encoder_top_k_right: int = None
+        "The sparse topk of encoder attention is located at [plas_encoder_top_k_left, plas_encoder_top_k_right]"
+        self.plas_decoder_top_k_left: int = None
+        self.plas_decoder_top_k_right: int = None
+        "The sparse topk of decoder attention is located at [plas_decoder_top_k_left, plas_decoder_top_k_right]"
+        self.plas_use_encoder_seq_limit: int = None
+        "When the number of encoder tokens is less than plas_use_encoder_seq_limit, attention is not sparse"
+        self.plas_use_decoder_seq_limit: int = None
+        "When the number of decoder tokens is less than plas_use_decoder_seq_limit, attention is not sparse"
+        self.plas_block_size: int = 128
+        self.mlp_weight_name: str = "plas_attention_mlp_weight.safetensors"
+        self.plas_max_seq_length: int = 128 * 1024
         if args is not None:
             for key, value in args.items():
                 if hasattr(self, key):
                     setattr(self, key, value)
-        if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None:
-            self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size
-        if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None:
-            self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size
+        if self.plas_use_encoder_seq_limit is None and self.plas_encoder_top_k_left is not None:
+            self.plas_use_encoder_seq_limit = self.plas_encoder_top_k_left * self.plas_block_size
+        if self.plas_use_decoder_seq_limit is None and self.plas_decoder_top_k_left is not None:
+            self.plas_use_decoder_seq_limit = self.plas_decoder_top_k_left * self.plas_block_size
         self.check_legality_parameters()
 
     def check_legality_parameters(
         self,
     ) -> None:
-        if self.moba_encoder_top_k_left is not None:
-            assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must large than 0"
+        if self.plas_encoder_top_k_left is not None:
+            assert self.plas_encoder_top_k_left > 0, "plas_encoder_top_k_left must be larger than 0"
 
-        if self.moba_encoder_top_k_right is not None:
-            assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must large than 0"
-            assert (
-                self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left
-            ), "moba_encoder_top_k_right must large than moba_encoder_top_k_left"
+        if self.plas_encoder_top_k_right is not None:
+            assert self.plas_encoder_top_k_right > 0, "plas_encoder_top_k_right must be larger than 0"
+            assert (
+                self.plas_encoder_top_k_right >= self.plas_encoder_top_k_left
+            ), "plas_encoder_top_k_right must be no less than plas_encoder_top_k_left"
 
-        if self.moba_decoder_top_k_left is not None:
-            assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must large than 0"
+        if self.plas_decoder_top_k_left is not None:
+            assert self.plas_decoder_top_k_left > 0, "plas_decoder_top_k_left must be larger than 0"
 
-        if self.moba_decoder_top_k_right is not None:
-            assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must large than 0"
-            assert (
-                self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left
-            ), "moba_decoder_top_k_right must large than moba_decoder_top_k_left"
+        if self.plas_decoder_top_k_right is not None:
+            assert self.plas_decoder_top_k_right > 0, "plas_decoder_top_k_right must be larger than 0"
+            assert (
+                self.plas_decoder_top_k_right >= self.plas_decoder_top_k_left
+            ), "plas_decoder_top_k_right must be no less than plas_decoder_top_k_left"
 
-        if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None:
-            assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size
-        if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None:
-            assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size
+        if self.plas_use_encoder_seq_limit is not None and self.plas_encoder_top_k_left is not None:
+            assert self.plas_use_encoder_seq_limit >= self.plas_encoder_top_k_left * self.plas_block_size
+        if self.plas_use_decoder_seq_limit is not None and self.plas_decoder_top_k_left is not None:
+            assert self.plas_use_decoder_seq_limit >= self.plas_decoder_top_k_left * self.plas_block_size
 
     def to_json_string(self):
         """
-        Convert moba_attention_config to json string.
+        Convert plas_attention_config to json string.
         """
         return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
 
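A usage sketch of the renamed config (the import location follows the `from fastdeploy.config import (...)` hunks further down). Limits that are not given explicitly default to `top_k_left * plas_block_size`:

```python
from fastdeploy.config import PlasAttentionConfig

cfg = PlasAttentionConfig(
    {
        "plas_encoder_top_k_left": 50,
        "plas_encoder_top_k_right": 60,
        "plas_decoder_top_k_left": 100,
        "plas_decoder_top_k_right": 120,
    }
)

# Derived defaults: top_k_left * plas_block_size (the block size is 128).
assert cfg.plas_use_encoder_seq_limit == 50 * 128   # 6400
assert cfg.plas_use_decoder_seq_limit == 100 * 128  # 12800
print(cfg.to_json_string())
```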
@@ -1396,7 +1396,7 @@ class FDConfig:
         decoding_config: DecodingConfig = None,
         quant_config: QuantConfigBase = None,
         graph_opt_config: GraphOptimizationConfig = None,
-        moba_attention_config: MobaAttentionConfig = None,
+        plas_attention_config: PlasAttentionConfig = None,
         speculative_config: SpeculativeConfig = None,
         tokenizer: str = None,
         max_model_len: int = 8192,
@@ -1427,7 +1427,7 @@ class FDConfig:
         self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
         self.decoding_config: DecodingConfig = decoding_config  # type: ignore
         self.cache_config: CacheConfig = cache_config  # type: ignore
-        self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
+        self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.scheduler_config.max_num_seqs)
@@ -30,9 +30,9 @@ from fastdeploy.config import (
     FDConfig,
     GraphOptimizationConfig,
     LoadConfig,
-    MobaAttentionConfig,
     ModelConfig,
     ParallelConfig,
+    PlasAttentionConfig,
     PoolerConfig,
     RunnerOption,
     SpeculativeConfig,
@@ -361,9 +361,9 @@ class EngineArgs:
     """
     Configuration for graph optimization backend execution.
     """
-    moba_attention_config: Optional[Dict[str, Any]] = None
+    plas_attention_config: Optional[Dict[str, Any]] = None
     """
-    Configuration for moba attention.
+    Configuration for plas attention.
     """
 
     enable_logprob: bool = False
@@ -601,9 +601,9 @@ class EngineArgs:
         help="",
     )
     model_group.add_argument(
-        "--moba-attention-config",
+        "--plas-attention-config",
         type=json.loads,
-        default=EngineArgs.moba_attention_config,
+        default=EngineArgs.plas_attention_config,
         help="",
     )
     model_group.add_argument(
@@ -993,17 +993,17 @@ class EngineArgs:
                 graph_optimization_args[k] = v
         return GraphOptimizationConfig(graph_optimization_args)
 
-    def create_moba_attention_config(self) -> MobaAttentionConfig:
+    def create_plas_attention_config(self) -> PlasAttentionConfig:
         """
-        Create and retuan a MobaAttentionConfig object based on the current settings.
+        Create and return a PlasAttentionConfig object based on the current settings.
         """
         attention_args = asdict(self)
-        if self.moba_attention_config is not None:
-            for k, v in self.moba_attention_config.items():
+        if self.plas_attention_config is not None:
+            for k, v in self.plas_attention_config.items():
                 attention_args[k] = v
-            return MobaAttentionConfig(attention_args)
+            return PlasAttentionConfig(attention_args)
         else:
-            return MobaAttentionConfig(None)
+            return PlasAttentionConfig(None)
 
     def create_early_stop_config(self) -> EarlyStopConfig:
         """
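A hypothetical round trip through the renamed helper (the `EngineArgs` module path and the non-attention fields are assumptions; the config keys follow the hunks above):

```python
from fastdeploy.engine.args_utils import EngineArgs  # module path assumed

args = EngineArgs(
    model="./ernie-4_5-21b-a3b-bf16-paddle",  # placeholder model path
    plas_attention_config={
        "plas_encoder_top_k_left": 50,
        "plas_encoder_top_k_right": 60,
        "plas_decoder_top_k_left": 100,
        "plas_decoder_top_k_right": 120,
    },
)
cfg = args.create_plas_attention_config()
print(cfg.to_json_string())
```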
@@ -1064,7 +1064,7 @@ class EngineArgs:
         scheduler_cfg = self.create_scheduler_config()
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
-        moba_attention_config = self.create_moba_attention_config()
+        plas_attention_config = self.create_plas_attention_config()
 
         early_stop_cfg = self.create_early_stop_config()
         early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1093,7 +1093,7 @@ class EngineArgs:
             max_long_partial_prefills=self.max_long_partial_prefills,
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             graph_opt_config=graph_opt_cfg,
-            moba_attention_config=moba_attention_config,
+            plas_attention_config=plas_attention_config,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
             early_stop_config=early_stop_cfg,
@@ -501,7 +501,7 @@ class LLMEngine:
             f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
             f" --reasoning_parser {self.cfg.reasoning_parser}"
             f" --load_choices {self.cfg.load_config.load_choices}"
-            f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
+            f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
             f" --ips {ips}"
             f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
             f" --runner {self.cfg.model_config.runner}"
@@ -20,7 +20,7 @@ from .block_multihead_attn_backend import BlockAttentionBackend
 from .flash_attn_backend import FlashAttentionBackend
 from .iluvatar_attn_backend import IluvatarAttnBackend
 from .mla_attention_backend import MLAAttentionBackend
-from .moba_attention_backend import MobaAttentionBackend
+from .moba_attention_backend import PlasAttentionBackend
 from .native_paddle_backend import PaddleNativeAttnBackend
 from .xpu_attn_backend import XPUAttentionBackend
 
@@ -35,5 +35,5 @@ __all__ = [
     "IluvatarAttnBackend",
     "BlockAttentionBackend",
     "Attention",
-    "MobaAttentionBackend",
+    "PlasAttentionBackend",
 ]
@@ -119,19 +119,19 @@ class Attention(nn.Layer):
 
         self.init_weight()
         if (
-            fd_config.moba_attention_config is not None
-            and fd_config.moba_attention_config.moba_encoder_top_k_left is not None
-            and fd_config.moba_attention_config.moba_encoder_top_k_right is not None
-            and fd_config.moba_attention_config.moba_decoder_top_k_left is not None
-            and fd_config.moba_attention_config.moba_decoder_top_k_right is not None
+            fd_config.plas_attention_config is not None
+            and fd_config.plas_attention_config.plas_encoder_top_k_left is not None
+            and fd_config.plas_attention_config.plas_encoder_top_k_right is not None
+            and fd_config.plas_attention_config.plas_decoder_top_k_left is not None
+            and fd_config.plas_attention_config.plas_decoder_top_k_right is not None
         ):
             mlp_weight_path = os.path.join(
-                fd_config.model_config.model, fd_config.moba_attention_config.mlp_weight_name
+                fd_config.model_config.model, fd_config.plas_attention_config.mlp_weight_name
             )
-            self.moba_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path)
-            moba_block_size = fd_config.moba_attention_config.moba_block_size
-            moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length
-            if self.moba_use_mlp:
+            self.plas_use_mlp = mlp_weight_path is not None and os.path.exists(mlp_weight_path)
+            plas_block_size = fd_config.plas_attention_config.plas_block_size
+            plas_max_seq_length = fd_config.plas_attention_config.plas_max_seq_length
+            if self.plas_use_mlp:
                 mlp_weight = {}
                 with safe_open(mlp_weight_path, framework="np", device="cpu") as f:
                     for key_name in f.keys():
@@ -148,12 +148,12 @@ class Attention(nn.Layer):
                         * self.kv_num_heads : (fd_config.parallel_config.tensor_parallel_rank + 1)
                         * self.kv_num_heads
                     ]
-                assert self.attn_gate_weight.shape[1] % moba_block_size == 0
+                assert self.attn_gate_weight.shape[1] % plas_block_size == 0
 
             self.cache_k_block_means = paddle.zeros(
                 [
                     fd_config.scheduler_config.max_num_seqs,
-                    moba_max_seq_length // moba_block_size,
+                    plas_max_seq_length // plas_block_size,
                     self.kv_num_heads,
                     self.head_dim,
                 ],
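Standalone, the weight-loading pattern above amounts to the following sketch (the safetensors calls are the real API; the tensor layout `[total_kv_heads, block_size, head_dim]` and the rank arithmetic are inferred from the surrounding hunks, so treat this as an approximation):

```python
import paddle
from safetensors import safe_open

def load_plas_gate_weights(path: str, tp_rank: int, kv_heads_per_rank: int) -> dict:
    """Load gate MLP weights and keep only this tensor-parallel rank's KV heads."""
    weights = {}
    with safe_open(path, framework="np", device="cpu") as f:
        for key_name in f.keys():
            full = f.get_tensor(key_name)  # assumed [total_kv_heads, block_size, head_dim]
            shard = full[tp_rank * kv_heads_per_rank : (tp_rank + 1) * kv_heads_per_rank]
            weights[key_name] = paddle.to_tensor(shard)
    return weights

# Example: rank 0 of a 4-way TP setup with 8 total KV heads keeps heads 0..1.
# gate = load_plas_gate_weights("plas_attention_mlp_weight.safetensors", 0, 2)
```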
@@ -39,7 +39,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 
 
 @dataclass
-class MobaAttentionMetadata(AttentionMetadata):
+class PlasAttentionMetadata(AttentionMetadata):
     """
     PlasAttentionMetadata
     """
@@ -54,7 +54,7 @@ class MobaAttentionMetadata(AttentionMetadata):
     max_dec_len_this_time: int = 0
 
 
-class MobaAttentionBackend(AttentionBackend):
+class PlasAttentionBackend(AttentionBackend):
     """
     The backend class that uses the paddle native attention implementation,
     which is used only for testing purposes.
@@ -70,11 +70,11 @@ class MobaAttentionBackend(AttentionBackend):
         decoder_block_shape_q: int = -1,
     ) -> None:
         """
-        MobaAttentionBackend __init__
+        PlasAttentionBackend __init__
         """
         super().__init__()
-        self.attention_metadata: MobaAttentionMetadata = None
-        assert fd_config.moba_attention_config is not None, "moba_attention_config is None"
+        self.attention_metadata: PlasAttentionMetadata = None
+        assert fd_config.plas_attention_config is not None, "plas_attention_config is None"
         self.block_size = fd_config.parallel_config.block_size
         self.max_seq_len = fd_config.parallel_config.max_model_len
         self.max_num_seqs = fd_config.scheduler_config.max_num_seqs
@@ -83,18 +83,18 @@ class MobaAttentionBackend(AttentionBackend):
         self.head_dim = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
         self.attn_block_m = 128
-        self.moba_block_size = fd_config.moba_attention_config.moba_block_size
-        self.moba_encoder_top_k_left = int(fd_config.moba_attention_config.moba_encoder_top_k_left)
-        self.moba_encoder_top_k_right = int(fd_config.moba_attention_config.moba_encoder_top_k_right)
-        self.moba_use_encoder_seq_limit = int(fd_config.moba_attention_config.moba_use_encoder_seq_limit)
-        self.moba_decoder_top_k_left = int(fd_config.moba_attention_config.moba_decoder_top_k_left)
-        self.moba_decoder_top_k_right = int(fd_config.moba_attention_config.moba_decoder_top_k_right)
-        self.moba_use_decoder_seq_limit = int(fd_config.moba_attention_config.moba_use_decoder_seq_limit)
-        self.moba_max_seq_length = fd_config.moba_attention_config.moba_max_seq_length
+        self.plas_block_size = fd_config.plas_attention_config.plas_block_size
+        self.plas_encoder_top_k_left = int(fd_config.plas_attention_config.plas_encoder_top_k_left)
+        self.plas_encoder_top_k_right = int(fd_config.plas_attention_config.plas_encoder_top_k_right)
+        self.plas_use_encoder_seq_limit = int(fd_config.plas_attention_config.plas_use_encoder_seq_limit)
+        self.plas_decoder_top_k_left = int(fd_config.plas_attention_config.plas_decoder_top_k_left)
+        self.plas_decoder_top_k_right = int(fd_config.plas_attention_config.plas_decoder_top_k_right)
+        self.plas_use_decoder_seq_limit = int(fd_config.plas_attention_config.plas_use_decoder_seq_limit)
+        self.plas_max_seq_length = fd_config.plas_attention_config.plas_max_seq_length
 
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Init the metadata for a forward pass."""
-        metadata = MobaAttentionMetadata()
+        metadata = PlasAttentionMetadata()
         metadata._dtype = paddle.get_default_dtype()
         metadata.cu_seq_q_pack, metadata.cu_seqlens_k, metadata.q_pack_tokens = get_cur_cu_seq_len_k(
             forward_meta.seq_lens_encoder,
@@ -116,7 +116,7 @@ class MobaAttentionBackend(AttentionBackend):
             [k_token_num + self.attn_block_m, self.kv_num_heads * self.head_dim], dtype=metadata._dtype
         )
         self.attention_metadata = metadata
-        assert self.max_seq_len <= self.moba_max_seq_length
+        assert self.max_seq_len <= self.plas_max_seq_length
 
     def get_kv_cache_shape(
         self,
@@ -186,13 +186,13 @@ class MobaAttentionBackend(AttentionBackend):
             self.max_seq_len,
             attention_metadata.max_enc_len_this_time,
             attention_metadata.max_dec_len_this_time,
-            self.moba_encoder_top_k_left,
-            self.moba_encoder_top_k_right,
-            self.moba_use_encoder_seq_limit,
-            self.moba_decoder_top_k_left,
-            self.moba_decoder_top_k_right,
-            self.moba_use_decoder_seq_limit,
-            layer.moba_use_mlp,
+            self.plas_encoder_top_k_left,
+            self.plas_encoder_top_k_right,
+            self.plas_use_encoder_seq_limit,
+            self.plas_decoder_top_k_left,
+            self.plas_decoder_top_k_right,
+            self.plas_use_decoder_seq_limit,
+            layer.plas_use_mlp,
             getattr(layer, "cache_quant_type_str", "none"),
         )[0]
         return out
@@ -26,7 +26,7 @@ class _Backend(enum.Enum):
     MLA_ATTN = enum.auto()
     FLASH_ATTN = enum.auto()
     BLOCK_ATTN = enum.auto()
-    MOBA_ATTN = enum.auto()
+    PLAS_ATTN = enum.auto()
 
 
 class Platform:
@@ -64,9 +64,9 @@ class CUDAPlatform(Platform):
         elif selected_backend == _Backend.FLASH_ATTN:
             logger.info("Using FLASH ATTN backend.")
             return "fastdeploy.model_executor.layers.attention.FlashAttentionBackend"
-        elif selected_backend == _Backend.MOBA_ATTN:
-            logger.info("Using MOBA ATTN backend.")
-            return "fastdeploy.model_executor.layers.attention.MobaAttentionBackend"
+        elif selected_backend == _Backend.PLAS_ATTN:
+            logger.info("Using PLAS ATTN backend.")
+            return "fastdeploy.model_executor.layers.attention.PlasAttentionBackend"
         else:
             raise ValueError(
                 "Invalid attention backend you specified.\n"
@@ -61,7 +61,7 @@ class RolloutModelConfig:
         graph_optimization_config: str = None,
         early_stop_config: str = None,
         local_rank: int = 0,
-        moba_attention_config: str = None,
+        plas_attention_config: str = None,
         data_parallel_size: int = 1,
         num_nextn_predict_layers: int = 0,
     ):
@@ -109,7 +109,7 @@ class RolloutModelConfig:
         self.local_rank = local_rank
         self.early_stop_config = early_stop_config
         self.ips = None
-        self.moba_attention_config = moba_attention_config
+        self.plas_attention_config = plas_attention_config
         self.num_nextn_predict_layers = num_nextn_predict_layers
 
     def __str__(self):
@@ -35,9 +35,9 @@ from fastdeploy.config import (
     FDConfig,
     GraphOptimizationConfig,
     LoadConfig,
-    MobaAttentionConfig,
     ModelConfig,
     ParallelConfig,
+    PlasAttentionConfig,
     SpeculativeConfig,
 )
 from fastdeploy.input.ernie4_5_tokenizer import Ernie4_5Tokenizer
@@ -577,10 +577,10 @@ def parse_args():
         help="Configuration of Graph optimization backend.",
     )
     parser.add_argument(
-        "--moba_attention_config",
+        "--plas_attention_config",
         type=json.loads,
         default=None,
-        help="Configuration of moba attention.",
+        help="Configuration of plas attention.",
     )
     parser.add_argument(
         "--guided_decoding_backend",
@@ -723,7 +723,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
 
     graph_opt_config = GraphOptimizationConfig(args.graph_optimization_config)
 
-    moba_attention_config = MobaAttentionConfig(args.moba_attention_config)
+    plas_attention_config = PlasAttentionConfig(args.plas_attention_config)
 
     early_stop_config = EarlyStopConfig(args.early_stop_config)
 
@@ -795,7 +795,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
         cache_config=cache_config,
         scheduler_config=scheduler_config,
         ips=args.ips,
-        moba_attention_config=moba_attention_config,
+        plas_attention_config=plas_attention_config,
     )
     update_fd_config_for_mm(fd_config)
 
@@ -57,7 +57,7 @@ def naive_attn(q_input, k_input, v_input, mask):
     return out
 
 
-class TestMobaAttention(unittest.TestCase):
+class TestPlasAttention(unittest.TestCase):
     def setUp(self):
         paddle.seed(0)
         self.seq_len = int(8 * 1024)
@@ -65,15 +65,15 @@ class TestMobaAttention(unittest.TestCase):
         self.num_kv_heads = int(1)
         self.head_dim = int(128)
         self.max_num_seqs = 1
-        self.moba_max_seq_length = int(128 * 1024)
-        self.moba_block_size = int(128)
-        self.moba_encoder_top_k_left = 2
-        self.moba_encoder_top_k_right = 3
-        self.moba_use_encoder_seq_limit = int(4 * 1024)
+        self.plas_max_seq_length = int(128 * 1024)
+        self.plas_block_size = int(128)
+        self.plas_encoder_top_k_left = 2
+        self.plas_encoder_top_k_right = 3
+        self.plas_use_encoder_seq_limit = int(4 * 1024)
         self.cache_k_block_means = paddle.zeros(
             [
                 self.max_num_seqs,
-                self.moba_max_seq_length // self.moba_block_size,
+                self.plas_max_seq_length // self.plas_block_size,
                 self.num_kv_heads,
                 self.head_dim,
             ],
@@ -96,12 +96,12 @@ class TestMobaAttention(unittest.TestCase):
         self.rotary_embs = paddle.ones([2, self.seq_len, self.head_dim // 2], dtype="float32")
 
         self.attn_gate_weight = paddle.randn(
-            [self.num_kv_heads, self.moba_block_size, self.head_dim], dtype="bfloat16"
+            [self.num_kv_heads, self.plas_block_size, self.head_dim], dtype="bfloat16"
         )
 
         self.gqa_group_size = self.num_heads // self.num_kv_heads
 
-        self.num_blocks = (self.seq_len + self.moba_block_size - 1) // self.moba_block_size
+        self.num_blocks = (self.seq_len + self.plas_block_size - 1) // self.plas_block_size
 
         self.sparse_step = 4
 
@@ -115,38 +115,38 @@ class TestMobaAttention(unittest.TestCase):
         for i in range(self.max_num_seqs):
             k_padding = paddle.zeros(
                 [
-                    (self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size,
+                    (self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size,
                     self.num_kv_heads,
                     self.head_dim,
                 ],
                 dtype="bfloat16",
             )
             k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
-            real_k_block_means = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim])
+            real_k_block_means = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim])
             real_k_block_means = real_k_block_means.mean(axis=1)
             compute_k_block_means = self.cache_k_block_means[i, 0 : real_k_block_means.shape[0]]
             assert (compute_k_block_means - real_k_block_means).abs().max() < 0.003
 
-        print("[consistency]Moba attention: split_qkv_rope matches.")
+        print("[consistency]plas attention: split_qkv_rope matches.")
 
     def compare_mlp_einsum(self, k_gate_weight):
         for i in range(self.max_num_seqs):
             k_padding = paddle.zeros(
                 [
-                    (self.seq_len + self.moba_block_size - 1) // self.moba_block_size * self.moba_block_size,
+                    (self.seq_len + self.plas_block_size - 1) // self.plas_block_size * self.plas_block_size,
                     self.num_kv_heads,
                     self.head_dim,
                 ],
                 dtype="bfloat16",
            )
             k_padding[0 : self.seq_len] = self.k_input[i * self.seq_len : (i + 1) * self.seq_len]
-            k_padding = k_padding.reshape([-1, self.moba_block_size, self.num_kv_heads, self.head_dim])
+            k_padding = k_padding.reshape([-1, self.plas_block_size, self.num_kv_heads, self.head_dim])
             real_result = paddle.einsum("nbhd,hbd->nhd", k_padding, self.attn_gate_weight)
             compute_result = k_gate_weight[i][0 : real_result.shape[0]]
 
             assert (real_result - compute_result).abs().max() < 0.5
 
-        print("[consistency]Moba attention: MLP einsum matches.")
+        print("[consistency]plas attention: MLP einsum matches.")
 
     def compare_qk_gemm(self, qk_gate_weight):
         for i in range(self.max_num_seqs):
@@ -170,10 +170,10 @@ class TestMobaAttention(unittest.TestCase):
             conpute_result = qk_gate_weight[i * self.seq_len : (i + 1) * self.seq_len, :, 0 : self.num_blocks]
             assert (qk_gemm_out - conpute_result).abs().max() < 1e-4
 
-        print("[consistency]Moba attention: qk_gemm matches.")
+        print("[consistency]plas attention: qk_gemm matches.")
 
     def compare_qk_gate_topk(self, qk_gate_topk_idx):
-        limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size
+        limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size
         for i in range(self.max_num_seqs):
             qk_gate_topk_idx_batch = qk_gate_topk_idx[i * self.num_blocks : (i + 1) * self.num_blocks]
             qk_gate_topk_idx_batch_no_sparse = qk_gate_topk_idx_batch[0 : limit_topk - 1]
@@ -191,40 +191,40 @@ class TestMobaAttention(unittest.TestCase):
                 - paddle.ones(qk_gate_topk_idx_batch_sparse.shape, qk_gate_topk_idx_batch_sparse.dtype)
                 * self.sparse_step
             ).abs().max() < 1e-6
-        print("[consistency]Moba attention: qk_gate_topk matches.")
+        print("[consistency]plas attention: qk_gate_topk matches.")
 
     def compare_attn(self, attn_out, qk_gate_topk_idx):
         x = (
-            paddle.tensor.triu(paddle.ones([self.moba_block_size, self.moba_block_size], dtype="bfloat16"), 1)
+            paddle.tensor.triu(paddle.ones([self.plas_block_size, self.plas_block_size], dtype="bfloat16"), 1)
             * -1000000
         )
-        limit_topk = self.moba_use_encoder_seq_limit // self.moba_block_size
+        limit_topk = self.plas_use_encoder_seq_limit // self.plas_block_size
         for i in range(self.max_num_seqs):
             q_input = self.q_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
             k_input = self.k_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
             v_input = self.v_input[i * self.seq_len : (i + 1) * self.seq_len].unsqueeze(axis=0)
             mask = paddle.tensor.triu(paddle.ones([self.seq_len, self.seq_len], dtype="bfloat16"), 1) * -1000000
-            mask[self.moba_use_encoder_seq_limit - self.moba_block_size :] = -1000000
+            mask[self.plas_use_encoder_seq_limit - self.plas_block_size :] = -1000000
             for i in range(limit_topk - 1, self.num_blocks):
                 n_block = i
                 mask[
-                    i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size,
-                    n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size,
+                    i * self.plas_block_size : i * self.plas_block_size + self.plas_block_size,
+                    n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size,
                 ] = x
                 idx = 0
                 n_block -= int(qk_gate_topk_idx[i, 0, idx])
                 idx += 1
                 while n_block >= 0:
                     mask[
-                        i * self.moba_block_size : i * self.moba_block_size + self.moba_block_size,
-                        n_block * self.moba_block_size : n_block * self.moba_block_size + self.moba_block_size,
+                        i * self.plas_block_size : i * self.plas_block_size + self.plas_block_size,
+                        n_block * self.plas_block_size : n_block * self.plas_block_size + self.plas_block_size,
                     ] = 0
                     n_block -= int(qk_gate_topk_idx[i, 0, idx])
                     idx += 1
             naive_attn_out = naive_attn(q_input, k_input, v_input, mask).squeeze(axis=0).transpose([1, 0, 2])
             assert (attn_out - naive_attn_out).abs().max() < 0.016
 
-    def test_moba_attention(self):
+    def test_plas_attention(self):
         qkv_out = paddle.randn([self.tokens, self.num_heads + 2 * self.num_kv_heads, self.head_dim], dtype="bfloat16")
 
         seq_len_encoder = paddle.to_tensor([self.seq_len] * self.max_num_seqs, dtype="int32")
@@ -255,7 +255,7 @@ class TestMobaAttention(unittest.TestCase):
             self.num_heads,
             self.num_kv_heads,
             self.head_dim,
-            self.moba_max_seq_length,
+            self.plas_max_seq_length,
             self.seq_len,
             self.seq_len,
             "none",
@@ -307,9 +307,9 @@ class TestMobaAttention(unittest.TestCase):
             self.seq_len,
             self.num_heads,
             self.num_kv_heads,
-            self.moba_encoder_top_k_left,
-            self.moba_encoder_top_k_right,
-            self.moba_use_encoder_seq_limit,
+            self.plas_encoder_top_k_left,
+            self.plas_encoder_top_k_right,
+            self.plas_use_encoder_seq_limit,
         )
 
         self.compare_qk_gate_topk(qk_gate_topk_idx)
@@ -332,7 +332,7 @@ class TestMobaAttention(unittest.TestCase):
             self.num_heads,
             self.num_kv_heads,
             self.head_dim,
-            self.moba_max_seq_length,
+            self.plas_max_seq_length,
         )
 
         self.compare_attn(attn_out, qk_gate_topk_idx)
@@ -340,18 +340,18 @@ class TestMobaAttention(unittest.TestCase):
     def test_server(self):
         if get_cur_cu_seq_len_k is None:
             return
-        os.environ["FD_ATTENTION_BACKEND"] = "MOBA_ATTN"
+        os.environ["FD_ATTENTION_BACKEND"] = "PLAS_ATTN"
         base_path = os.getenv("MODEL_PATH")
         if base_path:
             model_path = os.path.join(base_path, "./ernie-4_5-21b-a3b-bf16-paddle")
         else:
             model_path = "./ernie-4_5-21b-a3b-bf16-paddle"
 
-        moba_attention_config = {
-            "moba_encoder_top_k_left": 50,
-            "moba_encoder_top_k_right": 60,
-            "moba_decoder_top_k_left": 100,
-            "moba_decoder_top_k_right": 120,
+        plas_attention_config = {
+            "plas_encoder_top_k_left": 50,
+            "plas_encoder_top_k_right": 60,
+            "plas_decoder_top_k_left": 100,
+            "plas_decoder_top_k_right": 120,
         }
 
         # Load the model
@@ -365,7 +365,7 @@ class TestMobaAttention(unittest.TestCase):
             quantization="wint4",
             enable_chunked_prefill=True,
             max_num_batched_tokens=8192,
-            moba_attention_config=moba_attention_config,
+            plas_attention_config=plas_attention_config,
         )
 
         prompts = ["Hello world!"]
@@ -65,7 +65,7 @@ class TestAttentionInitWeight(unittest.TestCase):
         self.fd_config.parallel_config = self.parallel_config
         self.fd_config.cache_config = self.cache_config
         self.fd_config.quant_config = None
-        self.fd_config.moba_attention_config = None
+        self.fd_config.plas_attention_config = None
 
     def test_init_weight_without_quantization(self):
         """Test init_weight without quantization."""
@@ -141,7 +141,7 @@ class TestAttentionWeightLoader(unittest.TestCase):
         self.fd_config.model_config = self.model_config
         self.fd_config.parallel_config = self.parallel_config
         self.fd_config.cache_config = self.cache_config
-        self.fd_config.moba_attention_config = None
+        self.fd_config.plas_attention_config = None
 
         # Create mock quant method
         self.mock_quant_method = MockQuantMethod()