Mirror of https://github.com/PaddlePaddle/FastDeploy.git
【Inference Optimize】DeepSeek-v3 model inference performance optimization (#3455)
* DSK_OPT_01
* update FA3
@@ -24,6 +24,11 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 import paddle
 from paddle.nn.functional.flash_attention import flash_attn_unpadded

+try:
+    from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
+except:
+    flash_attention_v3_varlen = None
+
 from fastdeploy.model_executor.layers.attention.ops import (
     get_block_shape_and_split_kv_block,
     init_kv_signal_per_query,
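Note: this hunk makes the Flash Attention V3 varlen kernel an optional dependency, so the module still imports on Paddle builds that do not ship it. A minimal standalone sketch of the same guarded-import pattern, assuming only the import path shown in the diff (the HAS_FA3 flag is an illustrative name, not part of the commit):

    try:
        from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
    except ImportError:
        # Older or CPU-only Paddle builds may not expose the FA3 kernel;
        # keep the symbol as None so callers can fall back at runtime.
        flash_attention_v3_varlen = None

    # Illustrative convenience flag for callers.
    HAS_FA3 = flash_attention_v3_varlen is not None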
@@ -92,6 +97,7 @@ class MLAAttentionBackend(AttentionBackend):

     __infer_dynamic_dims_fields__ = ["attention_metadata"]
     attention_metadata: MLAAttentionMetadata
+    flash_attn_func: callable = None

     def __init__(
         self,
@@ -148,6 +154,22 @@ class MLAAttentionBackend(AttentionBackend):

         self.rank, self.device_id = init_rank_and_device_id(fd_config)

+        if self.flash_attn_func is None:
+            prop = paddle.device.cuda.get_device_properties()
+            cc = prop.major * 10 + prop.minor
+            is_current_sm_supported = cc >= 90
+            is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
+            if is_current_sm_supported and is_paddle_supported:
+                self.flash_attn_func = flash_attention_v3_varlen
+                print("The current platform supports Flash Attention V3.")
+                self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
+            else:
+                self.flash_attn_func = flash_attn_unpadded
+                self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False}
+                print(
+                    "The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
+                )
+
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
         metadata = MLAAttentionMetadata()
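Note: this hunk resolves the kernel once per backend instance. FA3 is chosen only when the running GPU reports compute capability >= 90 (SM90, Hopper-class parts such as H100/H800) and the installed Paddle wheel was compiled with an SM90+ architecture; otherwise the backend falls back to the FA2 varlen kernel, whose keyword names differ (scale/training rather than softmax_scale). A standalone sketch of that selection logic, assuming a CUDA build of Paddle; the helper name and its softmax_scale parameter are illustrative:

    import paddle
    from paddle.nn.functional.flash_attention import flash_attn_unpadded

    try:  # guarded import, as in the hunk at the top of the file
        from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
    except ImportError:
        flash_attention_v3_varlen = None

    def pick_flash_attention(softmax_scale):
        """Return (kernel, extra_kwargs) mirroring the check in __init__."""
        prop = paddle.device.cuda.get_device_properties()
        cc = prop.major * 10 + prop.minor  # e.g. 90 on an SM90 (Hopper) GPU
        gpu_ok = cc >= 90                  # the device itself supports FA3
        wheel_ok = any(arch >= 90 for arch in paddle.version.cuda_archs())  # wheel ships SM90 kernels
        if gpu_ok and wheel_ok and flash_attention_v3_varlen is not None:
            # FA3 path: the scale is passed as `softmax_scale`
            # (the extra None check is a defensive guard in this sketch).
            return flash_attention_v3_varlen, {"softmax_scale": softmax_scale}
        # FA2 fallback: flash_attn_unpadded takes `scale` plus a `training` flag.
        return flash_attn_unpadded, {"scale": softmax_scale, "training": False}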
@@ -269,7 +291,7 @@ class MLAAttentionBackend(AttentionBackend):
         )

         # Flash attention computation
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
@@ -277,9 +299,8 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.cu_seqlens_k,
             metadata.max_enc_len_this_time,
             metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]

         return fmha_out
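Note: with the kernel and its backend-specific keyword arguments fixed in __init__, the call site reduces to one shape: positional q/k/v plus varlen metadata, a causal flag, and **self.flash_attn_kwargs carrying the scale (and, for FA2, training=False). The same change is repeated at a second call site in the hunks below. A hypothetical wrapper showing the unified calling convention; the function and parameter names are illustrative:

    def run_varlen_attention(backend, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
        """Hypothetical helper: FA2 and FA3 share one call site because the
        per-kernel scale/training arguments live in backend.flash_attn_kwargs."""
        out = backend.flash_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,   # cumulative sequence lengths (unpadded/varlen layout)
            max_seqlen_q, max_seqlen_k,   # max_enc_len_this_time in the backend
            causal=backend.causal,
            **backend.flash_attn_kwargs,  # {"softmax_scale": s} for FA3; {"scale": s, "training": False} for FA2
        )
        # Both kernels return a tuple whose first element is the attention output,
        # hence the [0] at the call sites in the diff.
        return out[0]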
@@ -418,7 +439,7 @@ class MLAAttentionBackend(AttentionBackend):
         )

         # FA
-        fmha_out = flash_attn_unpadded(
+        fmha_out = self.flash_attn_func(
             q,
             k,
             v,
@@ -426,9 +447,8 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.cu_seqlens_k,
             metadata.max_enc_len_this_time,
             metadata.max_enc_len_this_time,
-            self.attn_softmax_scale,
-            causal=True,
-            training=False,
+            causal=self.causal,
+            **self.flash_attn_kwargs,
         )[0]

         return fmha_out