【Inference Optimize】DeepSeek-v3 model inference performance optimization (#3455)

* DSK_OPT_01

* update FA3
This commit is contained in:
AIbin
2025-08-19 10:42:42 +08:00
committed by GitHub
parent c95b3395e9
commit beec24fd89
2 changed files with 49 additions and 47 deletions

View File

@@ -24,6 +24,11 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
import paddle
from paddle.nn.functional.flash_attention import flash_attn_unpadded
try:
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
except:
flash_attention_v3_varlen = None
from fastdeploy.model_executor.layers.attention.ops import (
get_block_shape_and_split_kv_block,
init_kv_signal_per_query,
@@ -92,6 +97,7 @@ class MLAAttentionBackend(AttentionBackend):
__infer_dynamic_dims_fields__ = ["attention_metadata"]
attention_metadata: MLAAttentionMetadata
flash_attn_func: callable = None
def __init__(
self,
@@ -148,6 +154,22 @@ class MLAAttentionBackend(AttentionBackend):
self.rank, self.device_id = init_rank_and_device_id(fd_config)
if self.flash_attn_func is None:
prop = paddle.device.cuda.get_device_properties()
cc = prop.major * 10 + prop.minor
is_current_sm_supported = cc >= 90
is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
if is_current_sm_supported and is_paddle_supported:
self.flash_attn_func = flash_attention_v3_varlen
print("The current platform supports Flash Attention V3.")
self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
else:
self.flash_attn_func = flash_attn_unpadded
self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False}
print(
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
)
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = MLAAttentionMetadata()
@@ -269,7 +291,7 @@ class MLAAttentionBackend(AttentionBackend):
)
# Flash注意力计算
fmha_out = flash_attn_unpadded(
fmha_out = self.flash_attn_func(
q,
k,
v,
@@ -277,9 +299,8 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
self.attn_softmax_scale,
causal=True,
training=False,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]
return fmha_out
@@ -418,7 +439,7 @@ class MLAAttentionBackend(AttentionBackend):
)
# FA
fmha_out = flash_attn_unpadded(
fmha_out = self.flash_attn_func(
q,
k,
v,
@@ -426,9 +447,8 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.cu_seqlens_k,
metadata.max_enc_len_this_time,
metadata.max_enc_len_this_time,
self.attn_softmax_scale,
causal=True,
training=False,
causal=self.causal,
**self.flash_attn_kwargs,
)[0]
return fmha_out

View File

@@ -316,30 +316,23 @@ class DeepseekV3MLAAttention(nn.Layer):
mask_encoder_batch: paddle.Tensor,
):
""" """
layernorm_out = hidden_states
fmha_out = paddle.zeros(
shape=[
layernorm_out.shape[0],
self.num_attention_heads_tp * self.v_head_dim,
],
dtype=layernorm_out.dtype,
)
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
query = self.q_a_proj(layernorm_out)
# NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
fmha_out = None
query = self.q_a_proj(hidden_states)
query = self.q_a_layernorm(query)
query = self.q_b_proj(query)
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
key_value = self.kv_b_proj(compressed_kv)
key_value = key_value.reshape(
[
@@ -371,23 +364,9 @@ class DeepseekV3MLAAttention(nn.Layer):
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
fmha_out = fmha_out + fmha_out_prefill
fmha_out = fmha_out_prefill
if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time
query = self.q_a_proj(layernorm_out)
query = self.q_a_layernorm(query)
ln_out_or_q_c = query
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
compressed_kv = self.kv_a_layernorm(compressed_kv)
query = self.q_b_proj(ln_out_or_q_c)
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
@@ -416,6 +395,9 @@ class DeepseekV3MLAAttention(nn.Layer):
.transpose([1, 0, 2])
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
)
if fmha_out is None:
fmha_out = fmha_out_decode
else:
fmha_out = fmha_out + fmha_out_decode
output = self.o_proj(fmha_out)