From 332154f504ff1498f482cfcb51e48910bbe09fc0 Mon Sep 17 00:00:00 2001
From: chen <103103266+ckl117@users.noreply.github.com>
Date: Fri, 25 Jul 2025 14:09:00 +0800
Subject: [PATCH] [feature] Support FA2 (#3009)

---
 .../layers/attention/flash_attn_backend.py | 26 ++++++++++++++++---
 1 file changed, 23 insertions(+), 3 deletions(-)

diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index d5e367fe0..be8234cf2 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -20,6 +20,7 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, List, Optional
 
 import paddle
+from paddle.nn.functional.flash_attention import flash_attn_unpadded
 
 try:
     from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
@@ -91,6 +92,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     __infer_dynamic_dims_fields__ = ["attention_metadata"]
     attention_metadata: FlashAttentionMetadata
+    flash_attn_func: callable = None
 
     def __init__(
         self,
@@ -110,7 +112,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
         self.head_dim = fd_config.model_config.head_dim
-        self.hidden_size = self.num_heads * self.head_dim
+        self.attn_outputsize_tp = self.num_heads * self.head_dim
         self.block_size = fd_config.parallel_config.block_size
         self.num_layers: int = fd_config.model_config.num_hidden_layers
 
@@ -129,6 +131,22 @@ class FlashAttentionBackend(AttentionBackend):
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
 
+        if self.flash_attn_func is None:
+            prop = paddle.device.cuda.get_device_properties()
+            cc = prop.major * 10 + prop.minor
+            is_current_sm_supported = cc >= 90
+            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
+            if is_current_sm_supported and is_paddle_supported:
+                self.flash_attn_func = flash_attention_v3_varlen
+                print("The current platform supports Flash Attention V3.")
+                self.flash_attn_kwargs = {}
+            else:
+                self.flash_attn_func = flash_attn_unpadded
+                self.flash_attn_kwargs = {"scale": self.head_dim**-0.5, "training": False}
+                print(
+                    "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
+                )
+
     def get_attntion_meta(self):
         """get_attntion_meta"""
         return self.attention_metadata
@@ -266,7 +284,8 @@ class FlashAttentionBackend(AttentionBackend):
             self.max_seq_len,
             getattr(layer, "cache_quant_type_str", "none"),
         )
-        res = flash_attention_v3_varlen(
+
+        res = self.flash_attn_func(
            q,
            k,
            v,
@@ -275,5 +294,6 @@ class FlashAttentionBackend(AttentionBackend):
             max_seqlen_q=metadata.set_max_lengths[0],
             max_seqlen_k=metadata.set_max_lengths[3],
             causal=self.causal,
-        )[0].reshape([-1, self.hidden_size])
+            **self.flash_attn_kwargs,
+        )[0].reshape([-1, self.attn_outputsize_tp])
         return res
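
Note (not part of the patch): a minimal, standalone sketch of the FA2 fallback path this change enables. It mirrors how the backend calls flash_attn_unpadded when Flash Attention V3 (SM90) is unavailable, using the same keyword arguments the patch passes (cu_seqlens, max_seqlen, scale, causal, training). The tensor shapes, dtype, and sequence lengths below are illustrative assumptions, not values taken from FastDeploy, and the snippet needs a CUDA build of Paddle with flash-attn kernels.

    # Sketch of the FA2 (flash_attn_unpadded) call used as the fallback above.
    import paddle
    from paddle.nn.functional.flash_attention import flash_attn_unpadded

    num_heads, head_dim = 8, 64
    seq_lens = [3, 5]                  # two variable-length sequences
    total_tokens = sum(seq_lens)

    # Varlen ("unpadded") layout: tokens of all sequences are concatenated,
    # so q/k/v are [total_tokens, num_heads, head_dim].
    q = paddle.randn([total_tokens, num_heads, head_dim]).astype("float16")
    k = paddle.randn([total_tokens, num_heads, head_dim]).astype("float16")
    v = paddle.randn([total_tokens, num_heads, head_dim]).astype("float16")

    # Cumulative sequence-length offsets, one entry per sequence boundary.
    cu_seqlens = paddle.to_tensor([0, 3, 8], dtype="int32")

    out = flash_attn_unpadded(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lens),
        max_seqlen_k=max(seq_lens),
        scale=head_dim**-0.5,          # FA2 needs the softmax scale passed explicitly
        causal=True,
        training=False,
    )[0].reshape([-1, num_heads * head_dim])
    print(out.shape)                   # [8, 512]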