Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 08:16:42 +08:00)
[feature] Support FA2 (#3009)
@@ -20,6 +20,7 @@ from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, List, Optional
 
 import paddle
+from paddle.nn.functional.flash_attention import flash_attn_unpadded
 
 try:
     from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
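
The only change in this hunk is the new unconditional import of `flash_attn_unpadded` (the FA2 varlen kernel); the FA3 kernel `flash_attention_v3_varlen` stays behind the existing `try:` block because not every Paddle build exports it. A minimal sketch of that probing pattern, with an illustrative `HAS_FA3` flag that is not part of the commit:

# Sketch only: probe which flash-attention entry points the installed Paddle exposes.
from paddle.nn.functional.flash_attention import flash_attn_unpadded  # FA2 varlen kernel

try:
    # FA3 varlen kernel; may be missing from older or non-SM90 Paddle builds.
    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
    HAS_FA3 = True
except ImportError:
    flash_attention_v3_varlen = None
    HAS_FA3 = False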
@@ -91,6 +92,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     __infer_dynamic_dims_fields__ = ["attention_metadata"]
     attention_metadata: FlashAttentionMetadata
+    flash_attn_func: callable = None
 
     def __init__(
         self,
@@ -110,7 +112,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
         self.head_dim = fd_config.model_config.head_dim
-        self.hidden_size = self.num_heads * self.head_dim
+        self.attn_outputsize_tp = self.num_heads * self.head_dim
         self.block_size = fd_config.parallel_config.block_size
         self.num_layers: int = fd_config.model_config.num_hidden_layers
 
@@ -129,6 +131,22 @@ class FlashAttentionBackend(AttentionBackend):
 
         self.rank, self.device_id = init_rank_and_device_id(fd_config)
 
+        if self.flash_attn_func is None:
+            prop = paddle.device.cuda.get_device_properties()
+            cc = prop.major * 10 + prop.minor
+            is_current_sm_supported = cc >= 90
+            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
+            if is_current_sm_supported and is_paddle_supported:
+                self.flash_attn_func = flash_attention_v3_varlen
+                print("The current platform supports Flash Attention V3.")
+                self.flash_attn_kwargs = {}
+            else:
+                self.flash_attn_func = flash_attn_unpadded
+                self.flash_attn_kwargs = {"scale": self.head_dim**-0.5, "training": False}
+                print(
+                    "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
+                )
+
     def get_attntion_meta(self):
         """get_attntion_meta"""
         return self.attention_metadata
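
The block added above binds the kernel once, at construction time: FA3 (`flash_attention_v3_varlen`) is chosen only when the running GPU reports compute capability 9.0+ and the installed Paddle was compiled for an SM90+ architecture; otherwise the backend falls back to FA2 (`flash_attn_unpadded`), which additionally needs the softmax scale and `training=False` passed explicitly. A standalone sketch of the same check, assuming a CUDA build of Paddle (the helper name is illustrative, not from the commit):

import paddle

def pick_flash_attention_kernel() -> str:
    # Compute capability of the current GPU, e.g. 90 on Hopper (H100/H800).
    prop = paddle.device.cuda.get_device_properties()
    cc = prop.major * 10 + prop.minor
    device_supports_fa3 = cc >= 90
    # The installed Paddle wheel must also have been compiled with an SM90+ arch.
    build_supports_fa3 = any(arch >= 90 for arch in paddle.version.cuda_archs())
    return "FA3" if device_supports_fa3 and build_supports_fa3 else "FA2"

if paddle.device.is_compiled_with_cuda():
    print(pick_flash_attention_kernel())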
@@ -266,7 +284,8 @@ class FlashAttentionBackend(AttentionBackend):
             self.max_seq_len,
             getattr(layer, "cache_quant_type_str", "none"),
         )
-        res = flash_attention_v3_varlen(
+
+        res = self.flash_attn_func(
             q,
             k,
             v,
@@ -275,5 +294,6 @@ class FlashAttentionBackend(AttentionBackend):
             max_seqlen_q=metadata.set_max_lengths[0],
             max_seqlen_k=metadata.set_max_lengths[3],
             causal=self.causal,
-        )[0].reshape([-1, self.hidden_size])
+            **self.flash_attn_kwargs,
+        )[0].reshape([-1, self.attn_outputsize_tp])
         return res
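
With the call site rewritten to go through `self.flash_attn_func`, both kernels receive the same q/k/v, max sequence length, and `causal` arguments, and anything backend-specific is splatted in via `self.flash_attn_kwargs`; the output is then reshaped to `[-1, self.attn_outputsize_tp]`, the `num_heads * head_dim` product computed in `__init__`. A minimal sketch of how those extras are chosen (the helper name is illustrative, not from the commit):

# Sketch: the backend-specific extras splatted into the attention call above.
def flash_attn_extra_kwargs(use_fa3: bool, head_dim: int) -> dict:
    if use_fa3:
        # The FA3 path in this backend passes no extra keyword arguments.
        return {}
    # The FA2 path passes the softmax scale (1 / sqrt(head_dim)) explicitly
    # and runs the kernel in inference mode (training=False).
    return {"scale": head_dim ** -0.5, "training": False}

print(flash_attn_extra_kwargs(use_fa3=False, head_dim=128))  # scale ≈ 0.0884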