[Feature] support flash_mask_attention backend (#5134)

* [Feature] suppert flash_mask_attention backend

* fix unittest

* clean code
This commit is contained in:
lizhenyun01
2025-11-28 10:12:16 +08:00
committed by GitHub
parent b935101008
commit aba4fc657f
13 changed files with 542 additions and 69 deletions

View File

@@ -18,6 +18,7 @@ from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend
from .block_multihead_attn_backend import BlockAttentionBackend
from .flash_attn_backend import FlashAttentionBackend
from .flash_mask_attn_backend import FlashMaskAttentionBackend
from .iluvatar_attn_backend import IluvatarAttnBackend
from .mla_attention_backend import MLAAttentionBackend
from .moba_attention_backend import PlasAttentionBackend
@@ -36,4 +37,5 @@ __all__ = [
"BlockAttentionBackend",
"Attention",
"PlasAttentionBackend",
"FlashMaskAttentionBackend",
]

View File

@@ -0,0 +1,316 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import paddle
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.ops import (
flash_mask_attention,
get_block_shape_and_split_kv_block,
gqa_rope_write_cache,
init_kv_signal_per_query,
init_signal_layerwise,
open_shm_and_get_meta_signal,
pre_cache_len_concat,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
else:
merge_prefill_decode_output = None
import os
@dataclass
class FlashMaskAttentionMetadata(AttentionMetadata):
"""
FlashAttentionMetadata
"""
rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
pre_cache_batch_ids = None
pre_cache_tile_ids_per_batch = None
pre_cache_num_blocks_cpu = None
kv_token_num_cpu = None
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
_fuse_kernel_compute_dtype: str = "bf16"
_dtype: paddle.dtype = paddle.bfloat16
max_len_tensor_cpu: paddle.Tensor = None
max_len_tensor_cpu_decoder: paddle.Tensor = None
class FlashMaskAttentionBackend(AttentionBackend):
"""
FlashAttentionBackend backend implementation
"""
__infer_dynamic_dims_fields__ = ["attention_metadata"]
attention_metadata: FlashMaskAttentionMetadata
flash_attn_func: callable = None
def __init__(
self,
fd_config: FDConfig,
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
):
"""
FlashAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: FlashMaskAttentionMetadata = None
self.max_seq_len = fd_config.model_config.max_model_len
self.causal = getattr(fd_config.model_config, "causal", True)
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.group_size: int = self.num_heads // self.kv_num_heads
self.head_dim = fd_config.model_config.head_dim
self.attn_outputsize_tp = self.num_heads * self.head_dim
self.block_size = fd_config.cache_config.block_size
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q
self.speculative_method = fd_config.speculative_config.method
self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
self.start_layer_index: int = fd_config.model_config.start_layer_index
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
self.zero_seq_enc_lens_for_decode = paddle.zeros(
shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
)
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: str = None,
):
"""
Calculate kv cache shape
"""
key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
key_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
value_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
return key_cache_shape, value_cache_shape
def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata = FlashMaskAttentionMetadata()
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
)
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.max_len_tensor_cpu[2],
self.block_size,
)
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.pd_disaggregation_mode == "per_chunk":
if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run:
init_kv_signal_per_query(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_decoder,
self.rank,
self.num_layers + self.num_layers_draft_model,
)
elif self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag
)
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
self.attention_metadata = metadata
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
):
metadata = self.attention_metadata
if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index,
)
if metadata.max_len_tensor_cpu[1] > 0:
res_encoder = paddle.zeros([qkv.shape[0], self.num_heads * self.head_dim], dtype=qkv.dtype)
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0].item(),
self.max_seq_len,
getattr(layer, "rms_norm_eps", 1e-6),
getattr(layer, "cache_quant_type_str", "none"),
self.rope_3d,
)
flash_mask_attention(
q,
k,
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
forward_meta.seq_lens_encoder,
res_encoder,
forward_meta.attn_mask_offsets,
self.num_heads,
self.kv_num_heads,
self.head_dim,
self.max_seq_len,
q.shape[0],
k.shape[0],
)
return res_encoder
else:
raise NotImplementedError("FlashMaskAttentionBackend is not supported for decode.")

View File

@@ -15,6 +15,7 @@
"""
from .append_attention import append_attention, append_attention_with_output
from .flash_mask_attention import flash_mask_attention
from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_kv_signal_per_query import init_kv_signal_per_query
@@ -31,4 +32,5 @@ __all__ = [
"gqa_rope_write_cache",
"pre_cache_len_concat",
"init_kv_signal_per_query",
"flash_mask_attention",
]

View File

@@ -0,0 +1,60 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
from fastdeploy.platforms import current_platform
def flash_mask_attention(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_k: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
attn_out: paddle.Tensor,
attn_mask_offsets: Optional[paddle.Tensor] = None,
num_heads: int = 0,
kv_num_heads: int = 0,
head_dim: int = 128,
max_seq_len: int = 0,
q_token_num: int = 0,
kv_token_num: int = 0,
):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import flash_mask_attention
flash_mask_attention(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
seq_lens_encoder,
attn_out,
attn_mask_offsets,
num_heads,
kv_num_heads,
head_dim,
max_seq_len,
q_token_num,
kv_token_num,
)
else:
raise NotImplementedError

View File

@@ -39,6 +39,8 @@ def gqa_rope_write_cache(
cache_batch_ids: paddle.Tensor,
cache_tile_ids_per_batch: paddle.Tensor,
cache_num_blocks: paddle.Tensor,
q_norm_weight: Optional[paddle.Tensor] = None,
k_norm_weight: Optional[paddle.Tensor] = None,
cache_k_quant_scales: Optional[paddle.Tensor] = None,
cache_v_quant_scales: Optional[paddle.Tensor] = None,
cache_k_dequant_scales: Optional[paddle.Tensor] = None,
@@ -48,6 +50,7 @@ def gqa_rope_write_cache(
kv_signal_data: Optional[paddle.Tensor] = None,
kv_token_num: int = 1,
max_seq_len: int = 0,
rms_norm_eps: float = 1e-6,
cache_quant_type: str = "none",
rope_3d: bool = False,
):
@@ -72,6 +75,8 @@ def gqa_rope_write_cache(
cache_batch_ids,
cache_tile_ids_per_batch,
cache_num_blocks,
q_norm_weight,
k_norm_weight,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
@@ -81,6 +86,7 @@ def gqa_rope_write_cache(
kv_signal_data,
kv_token_num,
max_seq_len,
rms_norm_eps,
cache_quant_type,
rope_3d,
)