[Cherry-Pick][Feature] support flash_mask_attention backend(#5134) (#5256)

* [Feature] suppert flash_mask_attention backend

* fix unittest

* clean code
This commit is contained in:
lizhenyun01
2025-11-28 10:13:00 +08:00
committed by GitHub
parent 9b0c65ba57
commit fd1313cdb4
13 changed files with 542 additions and 69 deletions

View File

@@ -24,6 +24,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const float *q_norm_weight,
const float *k_norm_weight,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
@@ -38,37 +40,46 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int kv_num_head,
const int seq_len,
const int last_dim,
const bool rope_3d) {
const bool rope_3d,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
using LoadFloat = AlignedVector<float, VecSize>;
LoadT src_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
LoadFloat tmp_vec;
LoadFloat q_norm_vec, k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
const int half_lastdim = last_dim / 2;
const int offset = (q_num_head + kv_num_head * 2) * last_dim;
for (int64_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
const int offset =
(q_num_head + kv_num_head * 2) * last_dim; // for all q,k,v
const int all_head_num = elem_cnt / last_dim;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num;
gloabl_hi += all_warp_num) {
int64_t linear_index =
gloabl_hi * last_dim + threadIdx.x * VecSize; // 全局index
const int token_idx =
linear_index / offset; // token id(第几个token,不分qkv)
const int ori_bi = batch_id_per_token[token_idx]; // 第几个batch
if (seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
const int ori_seq_id =
(token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx =
rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
(token_idx - cu_seqlens_q[ori_bi]) +
seq_lens_decoder
[ori_bi]; // 在当前seq中的id(拼接了seq到一个batch的情况下有效)
const int64_t emb_idx =
ori_seq_id * half_lastdim + h_bias / 2; // embedding的id
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
int64_t base_split_idx;
T *out_p = nullptr;
if (hi < q_num_head) {
@@ -84,21 +95,67 @@ __global__ void GQAVariableLengthRotarySplitKernel(
base_split_idx = kv_write_idx * kv_num_head * last_dim +
(hi - q_num_head - kv_num_head) * last_dim + h_bias;
}
Load<T, VecSize>(&qkv[base_idx], &src_vec);
// do rope
if (hi < q_num_head + kv_num_head) {
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
// TODO check this correct or not
int64_t new_emb_idx =
rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
if (q_norm_weight && k_norm_weight) {
if (hi < q_num_head + kv_num_head) { // only q and k need rope
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
src_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
src_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
tmp_vec[2 * i] = tmp1;
tmp_vec[2 * i + 1] = tmp2;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
}
}
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2); // 单个head的标准差
if (hi < q_num_head + kv_num_head) { // only q and k need norm
float row_variance = max(warp_m2 / last_dim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < q_num_head) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize],
&q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
src_vec[i] =
static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
}
} else {
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize],
&k_norm_vec);
for (int i = 0; i < VecSize; i++) {
src_vec[i] =
static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
}
}
}
} else {
if (hi < q_num_head + kv_num_head) {
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
src_vec[2 * i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
src_vec[2 * i + 1] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
}
}
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
@@ -114,6 +171,8 @@ void gqa_rotary_qk_split_variable(
T *v,
const T *qkv_input,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const float *q_norm_weight,
const float *k_norm_weight,
const int *batch_id_per_token,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
@@ -126,24 +185,31 @@ void gqa_rotary_qk_split_variable(
const int input_output_len,
const int dim_head,
const bool rope_3d,
const float rms_norm_eps,
const cudaStream_t &stream) {
assert(dim_head == 128 && "dim_head must be 128");
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int PackSize = 16 / sizeof(T);
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
const int pack_num = elem_nums / PackSize;
const int blocksize = 128;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_size(kWarpSize, blocksize / kWarpSize);
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel<T, PackSize>,
grid_size,
blocksize,
block_size,
0,
stream,
qkv_input,
cos_emb,
sin_emb,
q_norm_weight,
k_norm_weight,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
@@ -158,7 +224,8 @@ void gqa_rotary_qk_split_variable(
kv_num_heads,
seq_len,
dim_head,
rope_3d);
rope_3d,
rms_norm_eps);
}
template <typename T,
@@ -1054,6 +1121,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor &cache_batch_ids,
const paddle::Tensor &cache_tile_ids,
const paddle::Tensor &cache_num_blocks,
const paddle::optional<paddle::Tensor> &q_norm_weight,
const paddle::optional<paddle::Tensor> &k_norm_weight,
const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
@@ -1063,6 +1132,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor> &kv_signal_data,
const int kv_token_num,
const int max_seq_len,
const float rms_norm_eps,
const std::string &cache_quant_type,
const bool rope_3d) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
@@ -1113,6 +1183,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
@@ -1125,6 +1197,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rope_3d,
rms_norm_eps,
stream);
if (token_num < kv_token_num) {
@@ -1259,6 +1332,8 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
"cache_batch_ids",
"cache_tile_ids_per_batch",
"cache_num_blocks",
paddle::Optional("q_norm_weight"),
paddle::Optional("k_norm_weight"),
paddle::Optional("cache_k_quant_scales"),
paddle::Optional("cache_v_quant_scales"),
paddle::Optional("cache_k_dequant_scales"),
@@ -1271,5 +1346,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
{"value_cache", "value_cache_out"}})
.Attrs({"kv_token_num: int",
"max_seq_len: int",
"cache_quant_type: std::string"})
"rms_norm_eps: float",
"cache_quant_type: std::string",
"rope_3d: bool"})
.SetKernelFn(PD_KERNEL(GQARopeWriteCacheKernel));

View File

@@ -178,6 +178,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::Tensor& cache_batch_ids,
const paddle::Tensor& cache_tile_ids,
const paddle::Tensor& cache_num_blocks,
const paddle::optional<paddle::Tensor>& q_norm_weight,
const paddle::optional<paddle::Tensor>& k_norm_weight,
const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
@@ -187,6 +189,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int kv_token_num,
const int max_seq_len,
const float rms_norm_eps,
const std::string& cache_quant_type,
const bool rope_3d);

View File

@@ -46,12 +46,11 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
const int q_token_num,
const int k_token_num) {
constexpr int kBlockM = 128;
constexpr int kBlockN = 128;
const int batch_size = seq_len_encoder.dims()[0];
const int batch_size = cu_seq_k.dims()[0] - 1;
Flash_mask_params params;
memset(&params, 0, sizeof(Flash_mask_params));
@@ -63,8 +62,8 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
params.head_num = head_num;
params.kv_head_num = kv_head_num;
params.max_seq_len_q = max_enc_len_this_time;
params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time;
params.q_token_num = q_token_num;
params.k_token_num = k_token_num;
params.batch_size = batch_size;
params.gqa_group_size = head_num / kv_head_num;
constexpr float kLog2e = 1.4426950408889634074;
@@ -132,8 +131,8 @@ void FlashAttentionMask(const paddle::Tensor& q_input,
const int kv_head_num,
const int head_dim,
const int max_seq_len,
const int max_enc_len_this_time,
const int max_dec_len_this_time) {
const int q_token_num,
const int k_token_num) {
if (q_input.dtype() == paddle::DataType::FLOAT16) {
using T = phi::dtype::float16;
DispatchFlashAttentionMask<T>(q_input,
@@ -148,8 +147,8 @@ void FlashAttentionMask(const paddle::Tensor& q_input,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time);
q_token_num,
k_token_num);
} else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
using T = phi::dtype::bfloat16;
DispatchFlashAttentionMask<T>(q_input,
@@ -164,12 +163,12 @@ void FlashAttentionMask(const paddle::Tensor& q_input,
kv_head_num,
head_dim,
max_seq_len,
max_enc_len_this_time,
max_dec_len_this_time);
q_token_num,
k_token_num);
}
}
PD_BUILD_STATIC_OP(flash_attention_mask)
PD_BUILD_STATIC_OP(flash_mask_attention)
.Inputs({"q_input",
"k_input",
"v_input",
@@ -182,8 +181,8 @@ PD_BUILD_STATIC_OP(flash_attention_mask)
"kv_head_num: int",
"head_dim: int",
"max_seq_len: int",
"max_enc_len_this_time: int",
"max_dec_len_this_time: int"})
"q_token_num: int",
"k_token_num: int"})
.Outputs({"out"})
.SetInplaceMap({{"attn_out", "out"}})
.SetKernelFn(PD_KERNEL(FlashAttentionMask));

View File

@@ -59,19 +59,26 @@ __global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
auto &shared_storage =
*reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);
__align__(16) __shared__ int mask[kBlockM];
__align__(16) __shared__ int mask_end[kBlockM];
__align__(16) __shared__ int mask_start[kBlockM];
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int bidb = blockIdx.z;
if (data_params.seq_len_encoder[bidb] <= 0) {
return;
}
if constexpr (NeedMask) {
const int *mask_this_batch =
data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;
const int2 *mask_this_batch =
reinterpret_cast<int2 *>(data_params.mask) +
(data_params.cu_seq_q[bidb] + m_block * kBlockM);
for (int i = threadIdx.x; i < kBlockM;
i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
mask[i] = mask_this_batch[i];
int2 mask_value = mask_this_batch[i];
mask_start[i] = mask_value.x;
mask_end[i] = mask_value.y;
}
}
@@ -119,7 +126,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
const int n_block_max =
NeedMask
? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN)
? cute::ceil_div(mask_end[min(kBlockM - 1, real_seq - 1)], kBlockN)
: min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q,
kBlockN),
cute::ceil_div(seq_len_k, kBlockN));
@@ -170,7 +177,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
smem_pipe_read_v,
tOrO,
softmax,
mask,
mask_end,
n_block_max,
threadIdx.x - NumCopyThreads,
m_block,
@@ -207,18 +214,15 @@ void run_flash_mask(Flash_mask_params &params, cudaStream_t stream) {
typename CollectiveMainloop::Params mainloop_params =
CollectiveMainloop::to_underlying_arguments(
{static_cast<Element const *>(params.q_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_q * params.batch_size,
params.head_num),
get_gmem_layout<kHeadDim>(params.q_token_num, params.head_num),
static_cast<Element const *>(params.k_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k * params.batch_size,
params.kv_head_num),
get_gmem_layout<kHeadDim>(params.k_token_num, params.kv_head_num),
static_cast<Element const *>(params.v_ptr),
get_gmem_layout<kHeadDim>(params.max_seq_len_k * params.batch_size,
params.kv_head_num),
get_gmem_layout<kHeadDim>(params.k_token_num, params.kv_head_num),
params.scale_softmax_log2});
int num_blocks_m =
cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);
cutlass::ceil_div(params.q_token_num, Kernel_traits::kBlockM);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) *
size<0>(ClusterShape{});

View File

@@ -35,8 +35,8 @@ struct Flash_mask_params {
int *seq_len_encoder;
int head_num;
int kv_head_num;
int max_seq_len_q;
int max_seq_len_k;
int q_token_num;
int k_token_num;
int batch_size;
int gqa_group_size;
float scale_softmax_log2;

View File

@@ -18,6 +18,7 @@ from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend
from .block_multihead_attn_backend import BlockAttentionBackend
from .flash_attn_backend import FlashAttentionBackend
from .flash_mask_attn_backend import FlashMaskAttentionBackend
from .iluvatar_attn_backend import IluvatarAttnBackend
from .mla_attention_backend import MLAAttentionBackend
from .moba_attention_backend import PlasAttentionBackend
@@ -36,4 +37,5 @@ __all__ = [
"BlockAttentionBackend",
"Attention",
"PlasAttentionBackend",
"FlashMaskAttentionBackend",
]

View File

@@ -0,0 +1,316 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional
import paddle
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend,
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.ops import (
flash_mask_attention,
get_block_shape_and_split_kv_block,
gqa_rope_write_cache,
init_kv_signal_per_query,
init_signal_layerwise,
open_shm_and_get_meta_signal,
pre_cache_len_concat,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
if TYPE_CHECKING:
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.platforms import current_platform
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
else:
merge_prefill_decode_output = None
import os
@dataclass
class FlashMaskAttentionMetadata(AttentionMetadata):
"""
FlashAttentionMetadata
"""
rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
max_seqlen_q: int = 0
max_seqlen_k: int = 0
pre_cache_batch_ids = None
pre_cache_tile_ids_per_batch = None
pre_cache_num_blocks_cpu = None
kv_token_num_cpu = None
# pd_disaggregation
kv_signal_metadata: Optional[paddle.Tensor] = None
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
_fuse_kernel_compute_dtype: str = "bf16"
_dtype: paddle.dtype = paddle.bfloat16
max_len_tensor_cpu: paddle.Tensor = None
max_len_tensor_cpu_decoder: paddle.Tensor = None
class FlashMaskAttentionBackend(AttentionBackend):
"""
FlashAttentionBackend backend implementation
"""
__infer_dynamic_dims_fields__ = ["attention_metadata"]
attention_metadata: FlashMaskAttentionMetadata
flash_attn_func: callable = None
def __init__(
self,
fd_config: FDConfig,
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
):
"""
FlashAttentionBackend __init__
"""
super().__init__()
self.attention_metadata: FlashMaskAttentionMetadata = None
self.max_seq_len = fd_config.model_config.max_model_len
self.causal = getattr(fd_config.model_config, "causal", True)
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.group_size: int = self.num_heads // self.kv_num_heads
self.head_dim = fd_config.model_config.head_dim
self.attn_outputsize_tp = self.num_heads * self.head_dim
self.block_size = fd_config.cache_config.block_size
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q
self.speculative_method = fd_config.speculative_config.method
self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
self.start_layer_index: int = fd_config.model_config.start_layer_index
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
self.rank, self.device_id = init_rank_and_device_id(fd_config)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
fd_config.model_config, "use_3d_rope", False
)
if fd_config.speculative_config.model_type != "main":
self.rope_3d = False
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
self.zero_seq_enc_lens_for_decode = paddle.zeros(
shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
)
def get_attntion_meta(self):
"""get_attntion_meta"""
return self.attention_metadata
def get_kv_cache_shape(
self,
max_num_blocks: int,
kv_cache_quant_type: str = None,
):
"""
Calculate kv cache shape
"""
key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
key_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
value_cache_shape = [
max_num_blocks,
self.kv_num_heads,
self.block_size,
self.head_dim // 2,
]
return key_cache_shape, value_cache_shape
def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata = FlashMaskAttentionMetadata()
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_device,
forward_meta.decoder_chunk_size_device,
forward_meta.max_len_tensor_cpu,
forward_meta.encoder_batch_ids,
forward_meta.encoder_tile_ids_per_batch,
forward_meta.encoder_num_blocks_x_cpu,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
)
(
metadata.cu_seqlens_k,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
metadata.kv_token_num_cpu,
) = pre_cache_len_concat(
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.max_len_tensor_cpu[2],
self.block_size,
)
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
if self.pd_disaggregation_mode == "per_chunk":
if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run:
init_kv_signal_per_query(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_decoder,
self.rank,
self.num_layers + self.num_layers_draft_model,
)
elif self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
self.rank, int(self.device_id), self.keep_pd_step_flag
)
if metadata._dtype == "bfloat16":
metadata._fuse_kernel_compute_dtype = "bf16"
elif metadata._dtype == "float16":
metadata._fuse_kernel_compute_dtype = "fp16"
elif metadata._dtype == "float32":
metadata._fuse_kernel_compute_dtype = "fp32"
metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
metadata.max_len_tensor_cpu_decoder[1] = 0
self.attention_metadata = metadata
def forward_mixed(
self,
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
qkv: paddle.Tensor,
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
):
metadata = self.attention_metadata
if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index,
)
if metadata.max_len_tensor_cpu[1] > 0:
res_encoder = paddle.zeros([qkv.shape[0], self.num_heads * self.head_dim], dtype=qkv.dtype)
q, k, v, _ = gqa_rope_write_cache(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
metadata.rotary_embs,
forward_meta.seq_lens_this_time,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.batch_id_per_token,
metadata.block_tables,
forward_meta.kv_batch_ids,
forward_meta.kv_tile_ids_per_batch,
forward_meta.kv_num_blocks_x_cpu,
metadata.pre_cache_batch_ids,
metadata.pre_cache_tile_ids_per_batch,
metadata.pre_cache_num_blocks_cpu,
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
getattr(layer, "cache_v_zp", None),
metadata.kv_signal_data_list[layer.layer_id],
metadata.kv_token_num_cpu[0].item(),
self.max_seq_len,
getattr(layer, "rms_norm_eps", 1e-6),
getattr(layer, "cache_quant_type_str", "none"),
self.rope_3d,
)
flash_mask_attention(
q,
k,
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
forward_meta.seq_lens_encoder,
res_encoder,
forward_meta.attn_mask_offsets,
self.num_heads,
self.kv_num_heads,
self.head_dim,
self.max_seq_len,
q.shape[0],
k.shape[0],
)
return res_encoder
else:
raise NotImplementedError("FlashMaskAttentionBackend is not supported for decode.")

View File

@@ -15,6 +15,7 @@
"""
from .append_attention import append_attention, append_attention_with_output
from .flash_mask_attention import flash_mask_attention
from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_kv_signal_per_query import init_kv_signal_per_query
@@ -31,4 +32,5 @@ __all__ = [
"gqa_rope_write_cache",
"pre_cache_len_concat",
"init_kv_signal_per_query",
"flash_mask_attention",
]

View File

@@ -0,0 +1,60 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from typing import Optional
import paddle
from fastdeploy.platforms import current_platform
def flash_mask_attention(
q: paddle.Tensor,
k: paddle.Tensor,
v: paddle.Tensor,
cu_seqlens_q: paddle.Tensor,
cu_seqlens_k: paddle.Tensor,
seq_lens_encoder: paddle.Tensor,
attn_out: paddle.Tensor,
attn_mask_offsets: Optional[paddle.Tensor] = None,
num_heads: int = 0,
kv_num_heads: int = 0,
head_dim: int = 128,
max_seq_len: int = 0,
q_token_num: int = 0,
kv_token_num: int = 0,
):
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import flash_mask_attention
flash_mask_attention(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
seq_lens_encoder,
attn_out,
attn_mask_offsets,
num_heads,
kv_num_heads,
head_dim,
max_seq_len,
q_token_num,
kv_token_num,
)
else:
raise NotImplementedError

View File

@@ -39,6 +39,8 @@ def gqa_rope_write_cache(
cache_batch_ids: paddle.Tensor,
cache_tile_ids_per_batch: paddle.Tensor,
cache_num_blocks: paddle.Tensor,
q_norm_weight: Optional[paddle.Tensor] = None,
k_norm_weight: Optional[paddle.Tensor] = None,
cache_k_quant_scales: Optional[paddle.Tensor] = None,
cache_v_quant_scales: Optional[paddle.Tensor] = None,
cache_k_dequant_scales: Optional[paddle.Tensor] = None,
@@ -48,6 +50,7 @@ def gqa_rope_write_cache(
kv_signal_data: Optional[paddle.Tensor] = None,
kv_token_num: int = 1,
max_seq_len: int = 0,
rms_norm_eps: float = 1e-6,
cache_quant_type: str = "none",
rope_3d: bool = False,
):
@@ -72,6 +75,8 @@ def gqa_rope_write_cache(
cache_batch_ids,
cache_tile_ids_per_batch,
cache_num_blocks,
q_norm_weight,
k_norm_weight,
cache_k_quant_scales,
cache_v_quant_scales,
cache_k_dequant_scales,
@@ -81,6 +86,7 @@ def gqa_rope_write_cache(
kv_signal_data,
kv_token_num,
max_seq_len,
rms_norm_eps,
cache_quant_type,
rope_3d,
)

View File

@@ -28,6 +28,7 @@ class _Backend(enum.Enum):
BLOCK_ATTN = enum.auto()
PLAS_ATTN = enum.auto()
HPU_ATTN = enum.auto()
FLASH_MASK_ATTN = enum.auto()
class Platform:

View File

@@ -67,6 +67,9 @@ class CUDAPlatform(Platform):
elif selected_backend == _Backend.PLAS_ATTN:
logger.info("Using PLAS ATTN backend.")
return "fastdeploy.model_executor.layers.attention.PlasAttentionBackend"
elif selected_backend == _Backend.FLASH_MASK_ATTN:
logger.info("Using FLASH MASK ATTN backend.")
return "fastdeploy.model_executor.layers.attention.FlashMaskAttentionBackend"
else:
raise ValueError(
"Invalid attention backend you specified.\n"

View File

@@ -19,7 +19,7 @@ import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import flash_attention_mask
from fastdeploy.model_executor.ops.gpu import flash_mask_attention
class TestFlashMaskAttention(unittest.TestCase):
@@ -27,7 +27,7 @@ class TestFlashMaskAttention(unittest.TestCase):
self.bsz = 1
self.num_head = 8
self.num_kv_head = 1
self.q_seq_len = 1024
self.q_seq_len = 888
self.k_seq_len = 1024
self.head_dim = 128
np.random.seed(self.q_seq_len)
@@ -71,7 +71,7 @@ class TestFlashMaskAttention(unittest.TestCase):
v_input_pad[0 : v_input.shape[0]] = v_input
mask = paddle.to_tensor(mask).astype("int32")
flash_attention_mask(
flash_mask_attention(
q_input,
k_input,
v_input_pad,
@@ -88,7 +88,7 @@ class TestFlashMaskAttention(unittest.TestCase):
int(k_input.shape[0]),
)
def test_flash_attention_mask(self):
def test_flash_mask_attention(self):
q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
k_input = np.random.normal(
0, 0.5, size=(self.bsz, self.q_seq_len + self.k_seq_len, self.num_kv_head, self.head_dim)
@@ -104,8 +104,8 @@ class TestFlashMaskAttention(unittest.TestCase):
mask = np.array([i + 1 for i in range(0, self.q_seq_len)]) + self.k_seq_len
mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len
naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
paddle_attn_out = paddle.zeros(naive_attn_out.shape, dtype="bfloat16")
naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask).transpose([0, 2, 1, 3])
paddle_attn_out = paddle.zeros(q_input.shape, dtype="bfloat16")
self.paddle_flash_attn_mask(q_input, k_input, v_input, paddle_attn_out, mask)
max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())