【Inference Optimize】DeepSeek-v3 model inference performance optimization (#3455)

* DSK_OPT_01

* update FA3
This commit is contained in:
AIbin
2025-08-19 10:42:42 +08:00
committed by GitHub
parent c95b3395e9
commit beec24fd89
2 changed files with 49 additions and 47 deletions

View File

@@ -24,6 +24,11 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded
try:
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except:
flash_attention_v3_varlen = None
from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block,
init_kv_signal_per_query,
@@ -92,6 +97,7 @@ class MLAAttentionBackend(AttentionBackend):
__infer_dynamic_dims_fields__ = ["attention_metadata"]
attention_metadata: MLAAttentionMetadata
flash_attn_func: callable = None
def __init__(
self,
@@ -148,6 +154,22 @@ class MLAAttentionBackend(AttentionBackend):
self.rank, self.device_id = init_rank_and_device_id(fd_config)
if self.flash_attn_func is None:
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
is_current_sm_supported = cc >= 90
is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
if is_current_sm_supported and is_paddle_supported:
self.flash_attn_func = flash_attention_v3_varlen
print("The current platform supports Flash Attention V3.")
self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
else:
self.flash_attn_func = flash_attn_unpadded
self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False}
print(
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = MLAAttentionMetadata()
@@ -269,7 +291,7 @@ class MLAAttentionBackend(AttentionBackend):
)
# Flash注意力计算
fmha_out = flash_attn_unpadded(
fmha_out = self.flash_attn_func(
q,
k,
v,
@@ -277,9 +299,8 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
self.attn_softmax_scale,
causal=True,
training=False,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]
return fmha_out
@@ -418,7 +439,7 @@ class MLAAttentionBackend(AttentionBackend):
)
# FA
fmha_out = flash_attn_unpadded(
fmha_out = self.flash_attn_func(
q,
k,
v,
@@ -426,9 +447,8 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
self.attn_softmax_scale,
causal=True,
training=False,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]
return fmha_out