diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu index 4f847e8de..53b7e6266 100644 --- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu +++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
@@ -24,6 +24,8 @@ __global__ void GQAVariableLengthRotarySplitKernel( const T *qkv, const float *cos_emb, const float *sin_emb, + const float *q_norm_weight, + const float *k_norm_weight, const int *batch_id_per_token, const int *cu_seqlens_q, const int *seq_lens,
@@ -38,37 +40,46 @@ __global__ void GQAVariableLengthRotarySplitKernel( const int kv_num_head, const int seq_len, const int last_dim, - const bool rope_3d) { + const bool rope_3d, + const float rms_norm_eps) { using LoadT = AlignedVector; constexpr int HalfVecSize = VecSize / 2; using LoadEmbT = AlignedVector; + using LoadFloat = AlignedVector; LoadT src_vec; LoadEmbT cos_emb_vec; LoadEmbT sin_emb_vec; - int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + LoadFloat tmp_vec; + LoadFloat q_norm_vec, k_norm_vec; + int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y; + int64_t all_warp_num = gridDim.x * blockDim.y; const int half_lastdim = last_dim / 2; - const int offset = (q_num_head + kv_num_head * 2) * last_dim; - for (int64_t linear_index = global_thread_idx * VecSize, - step = gridDim.x * blockDim.x * VecSize; - linear_index < elem_cnt; - linear_index += step) { - const int token_idx = linear_index / offset; - const int ori_bi = batch_id_per_token[token_idx]; + const int offset = + (q_num_head + kv_num_head * 2) * last_dim; // for all q,k,v + const int all_head_num = elem_cnt / last_dim; + for (int global_hi = global_warp_idx; global_hi < all_head_num; + global_hi += all_warp_num) { + int64_t linear_index = + global_hi * last_dim + threadIdx.x * VecSize; // global linear index + const int token_idx = + linear_index / offset; // token index (which token, before the q/k/v split) + const int ori_bi = batch_id_per_token[token_idx]; // which batch this token belongs to if (seq_lens[ori_bi] == 0) continue; const int bias = linear_index % offset; const int hi = bias / last_dim; const int h_bias = bias % last_dim; const int ori_seq_id = - (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; - const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id; - - const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2; - int64_t new_emb_idx = - rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx; + (token_idx - cu_seqlens_q[ori_bi]) + + seq_lens_decoder + [ori_bi]; // index within the current sequence (valid when sequences are packed into one batch) + const int64_t emb_idx = + ori_seq_id * half_lastdim + h_bias / 2; // index into the rotary embedding table + const int64_t base_idx = token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim + h_bias; + Load(&qkv[base_idx], &src_vec); + const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id; int64_t base_split_idx; T *out_p = nullptr; if (hi < q_num_head) {
@@ -84,21 +95,67 @@ __global__ void GQAVariableLengthRotarySplitKernel( base_split_idx = kv_write_idx * kv_num_head * last_dim + (hi - q_num_head - kv_num_head) * last_dim + h_bias; } - Load(&qkv[base_idx], &src_vec); - // do rope - if (hi < q_num_head + kv_num_head) { - Load(&cos_emb[new_emb_idx], &cos_emb_vec); - Load(&sin_emb[new_emb_idx], &sin_emb_vec); + + // TODO: check whether this is correct + int64_t new_emb_idx = + rope_3d ?
emb_idx + ori_bi * last_dim * seq_len : emb_idx; + float thread_m2 = 0.0f; + float warp_m2 = 0.0f; + + if (q_norm_weight && k_norm_weight) { + if (hi < q_num_head + kv_num_head) { // only q and k need rope + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); #pragma unroll - for (int i = 0; i < HalfVecSize; i++) { - const float input_left = static_cast(src_vec[2 * i]); - const float input_right = static_cast(src_vec[2 * i + 1]); - const float cos_tmp = cos_emb_vec[i]; - const float sin_tmp = sin_emb_vec[i]; - src_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); - src_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = static_cast(src_vec[2 * i]); + const float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + float tmp1 = input_left * cos_tmp - input_right * sin_tmp; + float tmp2 = input_right * cos_tmp + input_left * sin_tmp; + tmp_vec[2 * i] = tmp1; + tmp_vec[2 * i + 1] = tmp2; + thread_m2 += tmp1 * tmp1 + tmp2 * tmp2; + } + } + WelfordWarpAllReduce(thread_m2, &warp_m2); // per-head sum of squares, reduced across the warp + + if (hi < q_num_head + kv_num_head) { // only q and k need norm + float row_variance = max(warp_m2 / last_dim, 0.0f); + float row_inv_var = Rsqrt(row_variance + rms_norm_eps); + if (hi < q_num_head) { + Load(&q_norm_weight[threadIdx.x * VecSize], + &q_norm_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + src_vec[i] = + static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); + } + } else { + Load(&k_norm_weight[threadIdx.x * VecSize], + &k_norm_vec); + for (int i = 0; i < VecSize; i++) { + src_vec[i] = + static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); + } + } + } + } else { + if (hi < q_num_head + kv_num_head) { + Load(&cos_emb[new_emb_idx], &cos_emb_vec); + Load(&sin_emb[new_emb_idx], &sin_emb_vec); +#pragma unroll + for (int i = 0; i < HalfVecSize; i++) { + const float input_left = static_cast(src_vec[2 * i]); + const float input_right = static_cast(src_vec[2 * i + 1]); + const float cos_tmp = cos_emb_vec[i]; + const float sin_tmp = sin_emb_vec[i]; + src_vec[2 * i] = - static_cast(input_left * cos_tmp - input_right * sin_tmp); + src_vec[2 * i + 1] = - static_cast(input_right * cos_tmp + input_left * sin_tmp); + } } } Store(src_vec, &qkv_out[base_idx]);
@@ -114,6 +171,8 @@ void gqa_rotary_qk_split_variable( T *v, const T *qkv_input, const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2] + const float *q_norm_weight, + const float *k_norm_weight, const int *batch_id_per_token, const int *seq_lens_encoder, const int *seq_lens_decoder,
@@ -126,24 +185,31 @@ void gqa_rotary_qk_split_variable( const int input_output_len, const int dim_head, const bool rope_3d, + const float rms_norm_eps, const cudaStream_t &stream) { + assert(dim_head == 128 && "dim_head must be 128"); int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head; - constexpr int PackSize = 16 / sizeof(T); + + constexpr int HEAD_DIM = 128; + constexpr int PackSize = HEAD_DIM / kWarpSize; const int pack_num = elem_nums / PackSize; const int blocksize = 128; int grid_size = 1; GetNumBlocks<128>(pack_num, &grid_size); + dim3 block_size(kWarpSize, blocksize / kWarpSize); const float *cos_emb = rotary_emb; const float *sin_emb = rotary_emb + input_output_len * dim_head / 2; launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel, grid_size, - blocksize, + block_size, 0,
stream, qkv_input, cos_emb, sin_emb, + q_norm_weight, + k_norm_weight, batch_id_per_token, cu_seqlens_q, seq_lens_encoder, @@ -158,7 +224,8 @@ void gqa_rotary_qk_split_variable( kv_num_heads, seq_len, dim_head, - rope_3d); + rope_3d, + rms_norm_eps); } template GQARopeWriteCacheKernel( const paddle::Tensor &cache_batch_ids, const paddle::Tensor &cache_tile_ids, const paddle::Tensor &cache_num_blocks, + const paddle::optional &q_norm_weight, + const paddle::optional &k_norm_weight, const paddle::optional &cache_k_quant_scales, const paddle::optional &cache_v_quant_scales, const paddle::optional &cache_k_dequant_scales, @@ -1063,6 +1132,7 @@ std::vector GQARopeWriteCacheKernel( const paddle::optional &kv_signal_data, const int kv_token_num, const int max_seq_len, + const float rms_norm_eps, const std::string &cache_quant_type, const bool rope_3d) { typedef PDTraits traits_; @@ -1113,6 +1183,8 @@ std::vector GQARopeWriteCacheKernel( v.data(), qkv.data(), rotary_embs.data(), + q_norm_weight ? q_norm_weight.get().data() : nullptr, + k_norm_weight ? k_norm_weight.get().data() : nullptr, batch_id_per_token.data(), seq_lens_encoder.data(), seq_lens_decoder.data(), @@ -1125,6 +1197,7 @@ std::vector GQARopeWriteCacheKernel( rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2], head_dim, rope_3d, + rms_norm_eps, stream); if (token_num < kv_token_num) { @@ -1259,6 +1332,8 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache) "cache_batch_ids", "cache_tile_ids_per_batch", "cache_num_blocks", + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight"), paddle::Optional("cache_k_quant_scales"), paddle::Optional("cache_v_quant_scales"), paddle::Optional("cache_k_dequant_scales"), @@ -1271,5 +1346,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache) {"value_cache", "value_cache_out"}}) .Attrs({"kv_token_num: int", "max_seq_len: int", - "cache_quant_type: std::string"}) + "rms_norm_eps: float", + "cache_quant_type: std::string", + "rope_3d: bool"}) .SetKernelFn(PD_KERNEL(GQARopeWriteCacheKernel)); diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 5deaa6bc9..93dbaad2d 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -178,6 +178,8 @@ std::vector GQARopeWriteCacheKernel( const paddle::Tensor& cache_batch_ids, const paddle::Tensor& cache_tile_ids, const paddle::Tensor& cache_num_blocks, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, const paddle::optional& cache_k_quant_scales, const paddle::optional& cache_v_quant_scales, const paddle::optional& cache_k_dequant_scales, @@ -187,6 +189,7 @@ std::vector GQARopeWriteCacheKernel( const paddle::optional& kv_signal_data, const int kv_token_num, const int max_seq_len, + const float rms_norm_eps, const std::string& cache_quant_type, const bool rope_3d); diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu index 16f1b223f..8c9978f25 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn.cu @@ -46,12 +46,11 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, const int kv_head_num, const int head_dim, const int max_seq_len, - const int max_enc_len_this_time, - const int max_dec_len_this_time) { + const int q_token_num, + const int k_token_num) { constexpr int kBlockM = 128; constexpr int kBlockN = 128; - const int batch_size = seq_len_encoder.dims()[0]; - + const int batch_size = cu_seq_k.dims()[0] 
- 1; Flash_mask_params params; memset(¶ms, 0, sizeof(Flash_mask_params)); @@ -63,8 +62,8 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input, params.seq_len_encoder = const_cast(seq_len_encoder.data()); params.head_num = head_num; params.kv_head_num = kv_head_num; - params.max_seq_len_q = max_enc_len_this_time; - params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time; + params.q_token_num = q_token_num; + params.k_token_num = k_token_num; params.batch_size = batch_size; params.gqa_group_size = head_num / kv_head_num; constexpr float kLog2e = 1.4426950408889634074; @@ -132,8 +131,8 @@ void FlashAttentionMask(const paddle::Tensor& q_input, const int kv_head_num, const int head_dim, const int max_seq_len, - const int max_enc_len_this_time, - const int max_dec_len_this_time) { + const int q_token_num, + const int k_token_num) { if (q_input.dtype() == paddle::DataType::FLOAT16) { using T = phi::dtype::float16; DispatchFlashAttentionMask(q_input, @@ -148,8 +147,8 @@ void FlashAttentionMask(const paddle::Tensor& q_input, kv_head_num, head_dim, max_seq_len, - max_enc_len_this_time, - max_dec_len_this_time); + q_token_num, + k_token_num); } else if (q_input.dtype() == paddle::DataType::BFLOAT16) { using T = phi::dtype::bfloat16; DispatchFlashAttentionMask(q_input, @@ -164,12 +163,12 @@ void FlashAttentionMask(const paddle::Tensor& q_input, kv_head_num, head_dim, max_seq_len, - max_enc_len_this_time, - max_dec_len_this_time); + q_token_num, + k_token_num); } } -PD_BUILD_STATIC_OP(flash_attention_mask) +PD_BUILD_STATIC_OP(flash_mask_attention) .Inputs({"q_input", "k_input", "v_input", @@ -182,8 +181,8 @@ PD_BUILD_STATIC_OP(flash_attention_mask) "kv_head_num: int", "head_dim: int", "max_seq_len: int", - "max_enc_len_this_time: int", - "max_dec_len_this_time: int"}) + "q_token_num: int", + "k_token_num: int"}) .Outputs({"out"}) .SetInplaceMap({{"attn_out", "out"}}) .SetKernelFn(PD_KERNEL(FlashAttentionMask)); diff --git a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn_kernel.hpp b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn_kernel.hpp index d07e780fb..50e0d66c6 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn_kernel.hpp +++ b/custom_ops/gpu_ops/flash_mask_attn/flash_mask_attn_kernel.hpp @@ -59,19 +59,26 @@ __global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, auto &shared_storage = *reinterpret_cast(shared_memory); - __align__(16) __shared__ int mask[kBlockM]; + __align__(16) __shared__ int mask_end[kBlockM]; + __align__(16) __shared__ int mask_start[kBlockM]; const int m_block = blockIdx.x; const int bidh = blockIdx.y; const int bidb = blockIdx.z; + if (data_params.seq_len_encoder[bidb] <= 0) { + return; + } if constexpr (NeedMask) { - const int *mask_this_batch = - data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM; + const int2 *mask_this_batch = + reinterpret_cast(data_params.mask) + + (data_params.cu_seq_q[bidb] + m_block * kBlockM); for (int i = threadIdx.x; i < kBlockM; i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) { - mask[i] = mask_this_batch[i]; + int2 mask_value = mask_this_batch[i]; + mask_start[i] = mask_value.x; + mask_end[i] = mask_value.y; } } @@ -119,7 +126,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, const int n_block_max = NeedMask - ? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN) + ? 
cute::ceil_div(mask_end[min(kBlockM - 1, real_seq - 1)], kBlockN) : min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q, kBlockN), cute::ceil_div(seq_len_k, kBlockN)); @@ -170,7 +177,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp, smem_pipe_read_v, tOrO, softmax, - mask, + mask_end, n_block_max, threadIdx.x - NumCopyThreads, m_block, @@ -207,18 +214,15 @@ void run_flash_mask(Flash_mask_params ¶ms, cudaStream_t stream) { typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments( {static_cast(params.q_ptr), - get_gmem_layout(params.max_seq_len_q * params.batch_size, - params.head_num), + get_gmem_layout(params.q_token_num, params.head_num), static_cast(params.k_ptr), - get_gmem_layout(params.max_seq_len_k * params.batch_size, - params.kv_head_num), + get_gmem_layout(params.k_token_num, params.kv_head_num), static_cast(params.v_ptr), - get_gmem_layout(params.max_seq_len_k * params.batch_size, - params.kv_head_num), + get_gmem_layout(params.k_token_num, params.kv_head_num), params.scale_softmax_log2}); int num_blocks_m = - cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM); + cutlass::ceil_div(params.q_token_num, Kernel_traits::kBlockM); num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); diff --git a/custom_ops/gpu_ops/flash_mask_attn/kernel_traits.h b/custom_ops/gpu_ops/flash_mask_attn/kernel_traits.h index 1a0e2cabd..612c1c8d2 100644 --- a/custom_ops/gpu_ops/flash_mask_attn/kernel_traits.h +++ b/custom_ops/gpu_ops/flash_mask_attn/kernel_traits.h @@ -35,8 +35,8 @@ struct Flash_mask_params { int *seq_len_encoder; int head_num; int kv_head_num; - int max_seq_len_q; - int max_seq_len_k; + int q_token_num; + int k_token_num; int batch_size; int gqa_group_size; float scale_softmax_log2; diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index cbc6152aa..1ae0ef361 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -18,6 +18,7 @@ from .attention_selecter import get_attention_backend from .base_attention_backend import AttentionBackend from .block_multihead_attn_backend import BlockAttentionBackend from .flash_attn_backend import FlashAttentionBackend +from .flash_mask_attn_backend import FlashMaskAttentionBackend from .iluvatar_attn_backend import IluvatarAttnBackend from .mla_attention_backend import MLAAttentionBackend from .moba_attention_backend import PlasAttentionBackend @@ -36,4 +37,5 @@ __all__ = [ "BlockAttentionBackend", "Attention", "PlasAttentionBackend", + "FlashMaskAttentionBackend", ] diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py new file mode 100644 index 000000000..e2b6e4fd3 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -0,0 +1,316 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional + +import paddle + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, + AttentionMetadata, +) +from fastdeploy.model_executor.layers.attention.ops import ( + flash_mask_attention, + get_block_shape_and_split_kv_block, + gqa_rope_write_cache, + init_kv_signal_per_query, + init_signal_layerwise, + open_shm_and_get_meta_signal, + pre_cache_len_concat, +) +from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output +else: + merge_prefill_decode_output = None + +import os + + +@dataclass +class FlashMaskAttentionMetadata(AttentionMetadata): + """ + FlashAttentionMetadata + """ + + rotary_embs: Optional[paddle.Tensor] = None + block_tables: Optional[paddle.Tensor] = None + + cu_seqlens_q: paddle.Tensor = None + cu_seqlens_k: paddle.Tensor = None + max_seqlen_q: int = 0 + max_seqlen_k: int = 0 + + pre_cache_batch_ids = None + pre_cache_tile_ids_per_batch = None + pre_cache_num_blocks_cpu = None + kv_token_num_cpu = None + + # pd_disaggregation + kv_signal_metadata: Optional[paddle.Tensor] = None + kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list) + + _fuse_kernel_compute_dtype: str = "bf16" + _dtype: paddle.dtype = paddle.bfloat16 + + max_len_tensor_cpu: paddle.Tensor = None + max_len_tensor_cpu_decoder: paddle.Tensor = None + + +class FlashMaskAttentionBackend(AttentionBackend): + """ + FlashAttentionBackend backend implementation + """ + + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: FlashMaskAttentionMetadata + flash_attn_func: callable = None + + def __init__( + self, + fd_config: FDConfig, + kv_num_heads: int, + num_heads: int, + head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, + ): + """ + FlashAttentionBackend __init__ + """ + super().__init__() + self.attention_metadata: FlashMaskAttentionMetadata = None + self.max_seq_len = fd_config.model_config.max_model_len + self.causal = getattr(fd_config.model_config, "causal", True) + + self.kv_num_heads = kv_num_heads + self.num_heads = num_heads + self.group_size: int = self.num_heads // self.kv_num_heads + self.head_dim = fd_config.model_config.head_dim + self.attn_outputsize_tp = self.num_heads * self.head_dim + self.block_size = fd_config.cache_config.block_size + self.num_layers: int = fd_config.model_config.num_hidden_layers + self.encoder_block_shape_q: int = encoder_block_shape_q + self.decoder_block_shape_q: int = decoder_block_shape_q + + self.speculative_method = fd_config.speculative_config.method + self.use_speculate = self.speculative_method is not None + 
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens + self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" + self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"]) + + self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode + + self.start_layer_index: int = fd_config.model_config.start_layer_index + + if fd_config.parallel_config.expert_parallel_rank is None: + fd_config.parallel_config.expert_parallel_rank = 0 + + self.rank, self.device_id = init_rank_and_device_id(fd_config) + + self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr( + fd_config.model_config, "use_3d_rope", False + ) + if fd_config.speculative_config.model_type != "main": + self.rope_3d = False + self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768")) + self.zero_seq_enc_lens_for_decode = paddle.zeros( + shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32 + ) + + def get_attntion_meta(self): + """get_attntion_meta""" + return self.attention_metadata + + def get_kv_cache_shape( + self, + max_num_blocks: int, + kv_cache_quant_type: str = None, + ): + """ + Calculate kv cache shape + """ + key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] + value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] + if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": + key_cache_shape = [ + max_num_blocks, + self.kv_num_heads, + self.block_size, + self.head_dim // 2, + ] + value_cache_shape = [ + max_num_blocks, + self.kv_num_heads, + self.block_size, + self.head_dim // 2, + ] + return key_cache_shape, value_cache_shape + + def init_attention_metadata(self, forward_meta: ForwardMeta): + metadata = FlashMaskAttentionMetadata() + metadata.cu_seqlens_q = forward_meta.cu_seqlens_q + metadata.rotary_embs = forward_meta.rotary_embs + metadata.block_tables = forward_meta.block_tables + get_block_shape_and_split_kv_block( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.decoder_num_blocks_device, + forward_meta.decoder_chunk_size_device, + forward_meta.max_len_tensor_cpu, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.group_size, + self.block_size, + ) + + ( + metadata.cu_seqlens_k, + metadata.pre_cache_batch_ids, + metadata.pre_cache_tile_ids_per_batch, + metadata.pre_cache_num_blocks_cpu, + metadata.kv_token_num_cpu, + ) = pre_cache_len_concat( + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.max_len_tensor_cpu[2], + self.block_size, + ) + + # pd_disaggregation + metadata.kv_signal_data_list = [None] * self.num_layers + if self.pd_disaggregation_mode == "per_chunk": + if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + init_kv_signal_per_query( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_decoder, + self.rank, + self.num_layers + self.num_layers_draft_model, + ) + elif self.pd_disaggregation_mode == "per_query": + metadata.kv_signal_metadata = 
open_shm_and_get_meta_signal( + self.rank, int(self.device_id), self.keep_pd_step_flag + ) + + if metadata._dtype == "bfloat16": + metadata._fuse_kernel_compute_dtype = "bf16" + elif metadata._dtype == "float16": + metadata._fuse_kernel_compute_dtype = "fp16" + elif metadata._dtype == "float32": + metadata._fuse_kernel_compute_dtype = "fp32" + + metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu + metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu) + metadata.max_len_tensor_cpu_decoder[1] = 0 + + self.attention_metadata = metadata + + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ): + metadata = self.attention_metadata + + if self.pd_disaggregation_mode == "per_query": + metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( + metadata.kv_signal_metadata, + layer.layer_id + self.start_layer_index, + ) + + if metadata.max_len_tensor_cpu[1] > 0: + res_encoder = paddle.zeros([qkv.shape[0], self.num_heads * self.head_dim], dtype=qkv.dtype) + q, k, v, _ = gqa_rope_write_cache( + qkv, + forward_meta.caches[2 * layer.layer_id], + forward_meta.caches[2 * layer.layer_id + 1], + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.rotary_embs, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.batch_id_per_token, + metadata.block_tables, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + metadata.pre_cache_batch_ids, + metadata.pre_cache_tile_ids_per_batch, + metadata.pre_cache_num_blocks_cpu, + getattr(layer, "q_norm_weight", None), + getattr(layer, "k_norm_weight", None), + getattr(layer, "cache_k_scale", None), + getattr(layer, "cache_v_scale", None), + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + metadata.kv_signal_data_list[layer.layer_id], + metadata.kv_token_num_cpu[0].item(), + self.max_seq_len, + getattr(layer, "rms_norm_eps", 1e-6), + getattr(layer, "cache_quant_type_str", "none"), + self.rope_3d, + ) + + flash_mask_attention( + q, + k, + v, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + forward_meta.seq_lens_encoder, + res_encoder, + forward_meta.attn_mask_offsets, + self.num_heads, + self.kv_num_heads, + self.head_dim, + self.max_seq_len, + q.shape[0], + k.shape[0], + ) + return res_encoder + else: + raise NotImplementedError("FlashMaskAttentionBackend is not supported for decode.") diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index caf8bcb9b..064155d2c 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -15,6 +15,7 @@ """ from .append_attention import append_attention, append_attention_with_output +from .flash_mask_attention import flash_mask_attention from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block from .gqa_rope_write_cache import gqa_rope_write_cache from .init_kv_signal_per_query import init_kv_signal_per_query @@ -31,4 +32,5 @@ __all__ = [ "gqa_rope_write_cache", "pre_cache_len_concat", "init_kv_signal_per_query", + "flash_mask_attention", ] diff --git 
a/fastdeploy/model_executor/layers/attention/ops/flash_mask_attention.py b/fastdeploy/model_executor/layers/attention/ops/flash_mask_attention.py new file mode 100644 index 000000000..a13989b78 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/flash_mask_attention.py @@ -0,0 +1,60 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + + +def flash_mask_attention( + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + cu_seqlens_k: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + attn_out: paddle.Tensor, + attn_mask_offsets: Optional[paddle.Tensor] = None, + num_heads: int = 0, + kv_num_heads: int = 0, + head_dim: int = 128, + max_seq_len: int = 0, + q_token_num: int = 0, + kv_token_num: int = 0, +): + if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import flash_mask_attention + + flash_mask_attention( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seq_lens_encoder, + attn_out, + attn_mask_offsets, + num_heads, + kv_num_heads, + head_dim, + max_seq_len, + q_token_num, + kv_token_num, + ) + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py b/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py index 9aac80df3..670fa65f3 100644 --- a/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py +++ b/fastdeploy/model_executor/layers/attention/ops/gqa_rope_write_cache.py @@ -39,6 +39,8 @@ def gqa_rope_write_cache( cache_batch_ids: paddle.Tensor, cache_tile_ids_per_batch: paddle.Tensor, cache_num_blocks: paddle.Tensor, + q_norm_weight: Optional[paddle.Tensor] = None, + k_norm_weight: Optional[paddle.Tensor] = None, cache_k_quant_scales: Optional[paddle.Tensor] = None, cache_v_quant_scales: Optional[paddle.Tensor] = None, cache_k_dequant_scales: Optional[paddle.Tensor] = None, @@ -48,6 +50,7 @@ def gqa_rope_write_cache( kv_signal_data: Optional[paddle.Tensor] = None, kv_token_num: int = 1, max_seq_len: int = 0, + rms_norm_eps: float = 1e-6, cache_quant_type: str = "none", rope_3d: bool = False, ): @@ -72,6 +75,8 @@ def gqa_rope_write_cache( cache_batch_ids, cache_tile_ids_per_batch, cache_num_blocks, + q_norm_weight, + k_norm_weight, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, @@ -81,6 +86,7 @@ def gqa_rope_write_cache( kv_signal_data, kv_token_num, max_seq_len, + rms_norm_eps, cache_quant_type, rope_3d, ) diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index 16251c1c1..9db9ebf77 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -28,6 +28,7 @@ class _Backend(enum.Enum): BLOCK_ATTN = enum.auto() PLAS_ATTN = enum.auto() HPU_ATTN = enum.auto() + FLASH_MASK_ATTN = enum.auto() class Platform: diff --git a/fastdeploy/platforms/cuda.py b/fastdeploy/platforms/cuda.py index 
9720e7ace..8d0d559fe 100644 --- a/fastdeploy/platforms/cuda.py +++ b/fastdeploy/platforms/cuda.py @@ -67,6 +67,9 @@ class CUDAPlatform(Platform): elif selected_backend == _Backend.PLAS_ATTN: logger.info("Using PLAS ATTN backend.") return "fastdeploy.model_executor.layers.attention.PlasAttentionBackend" + elif selected_backend == _Backend.FLASH_MASK_ATTN: + logger.info("Using FLASH MASK ATTN backend.") + return "fastdeploy.model_executor.layers.attention.FlashMaskAttentionBackend" else: raise ValueError( "Invalid attention backend you specified.\n" diff --git a/tests/operators/test_flash_mask_attn.py b/tests/operators/test_flash_mask_attn.py index 2ada04527..1d222c8d7 100644 --- a/tests/operators/test_flash_mask_attn.py +++ b/tests/operators/test_flash_mask_attn.py @@ -19,7 +19,7 @@ import unittest import numpy as np import paddle -from fastdeploy.model_executor.ops.gpu import flash_attention_mask +from fastdeploy.model_executor.ops.gpu import flash_mask_attention class TestFlashMaskAttention(unittest.TestCase): @@ -27,7 +27,7 @@ class TestFlashMaskAttention(unittest.TestCase): self.bsz = 1 self.num_head = 8 self.num_kv_head = 1 - self.q_seq_len = 1024 + self.q_seq_len = 888 self.k_seq_len = 1024 self.head_dim = 128 np.random.seed(self.q_seq_len) @@ -71,7 +71,7 @@ class TestFlashMaskAttention(unittest.TestCase): v_input_pad[0 : v_input.shape[0]] = v_input mask = paddle.to_tensor(mask).astype("int32") - flash_attention_mask( + flash_mask_attention( q_input, k_input, v_input_pad, @@ -88,7 +88,7 @@ class TestFlashMaskAttention(unittest.TestCase): int(k_input.shape[0]), ) - def test_flash_attention_mask(self): + def test_flash_mask_attention(self): q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim)) k_input = np.random.normal( 0, 0.5, size=(self.bsz, self.q_seq_len + self.k_seq_len, self.num_kv_head, self.head_dim) @@ -104,8 +104,8 @@ class TestFlashMaskAttention(unittest.TestCase): mask = np.array([i + 1 for i in range(0, self.q_seq_len)]) + self.k_seq_len mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len - naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask) - paddle_attn_out = paddle.zeros(naive_attn_out.shape, dtype="bfloat16") + naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask).transpose([0, 2, 1, 3]) + paddle_attn_out = paddle.zeros(q_input.shape, dtype="bfloat16") self.paddle_flash_attn_mask(q_input, k_input, v_input, paddle_attn_out, mask) max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
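Reviewer note (not part of the patch): with HEAD_DIM = 128 and kWarpSize = 32, the new PackSize = HEAD_DIM / kWarpSize = 4 and the dim3(kWarpSize, blocksize / kWarpSize) launch give each warp exactly one 128-wide head per loop iteration, which is what the per-head WelfordWarpAllReduce in GQAVariableLengthRotarySplitKernel relies on (and why the host launcher now asserts dim_head == 128). As a reference for review, below is a minimal unfused NumPy sketch of what the extended kernel computes per q/k head: RoPE over adjacent element pairs, followed by per-head RMS normalization scaled by q_norm_weight / k_norm_weight with rms_norm_eps. The function name, the cos/sin table construction, and the dummy shapes are illustrative assumptions, not code from this patch.

import numpy as np

def rope_then_head_rmsnorm(x, cos, sin, norm_weight=None, eps=1e-6):
    # x: [num_heads, head_dim]; cos/sin: [head_dim // 2] for one position.
    left, right = x[:, 0::2], x[:, 1::2]
    rotated = np.empty_like(x)
    rotated[:, 0::2] = left * cos - right * sin
    rotated[:, 1::2] = right * cos + left * sin
    if norm_weight is None:  # q_norm_weight / k_norm_weight not supplied: RoPE only
        return rotated
    # per-head mean of squares, i.e. warp_m2 / last_dim in the kernel
    variance = np.mean(rotated * rotated, axis=-1, keepdims=True)
    return rotated / np.sqrt(variance + eps) * norm_weight  # per-dim norm weight

# toy usage
head_dim, num_heads, pos = 128, 4, 7
rng = np.random.default_rng(0)
x = rng.standard_normal((num_heads, head_dim)).astype(np.float32)
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
cos, sin = np.cos(pos * inv_freq), np.sin(pos * inv_freq)
q_norm_weight = np.ones(head_dim, dtype=np.float32)
out = rope_then_head_rmsnorm(x, cos, sin, q_norm_weight, eps=1e-6)
print(out.shape)  # (4, 128)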