mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-01 14:52:33 +08:00
【Inference Optimize】DeepSeek-v3 model inference performance optimization (#3455)
* DSK_OPT_01 * update FA3
This commit is contained in:
@@ -24,6 +24,11 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
import paddle
|
||||
from paddle.nn.functional.flash_attention import flash_attn_unpadded
|
||||
|
||||
try:
|
||||
from paddle.nn.functional.flash_attention import flash_attention_v3_varlen
|
||||
except:
|
||||
flash_attention_v3_varlen = None
|
||||
|
||||
from fastdeploy.model_executor.layers.attention.ops import (
|
||||
get_block_shape_and_split_kv_block,
|
||||
init_kv_signal_per_query,
|
||||
@@ -92,6 +97,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
|
||||
__infer_dynamic_dims_fields__ = ["attention_metadata"]
|
||||
attention_metadata: MLAAttentionMetadata
|
||||
flash_attn_func: callable = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -148,6 +154,22 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
|
||||
self.rank, self.device_id = init_rank_and_device_id(fd_config)
|
||||
|
||||
if self.flash_attn_func is None:
|
||||
prop = paddle.device.cuda.get_device_properties()
|
||||
cc = prop.major * 10 + prop.minor
|
||||
is_current_sm_supported = cc >= 90
|
||||
is_paddle_supported = any(num >= 90 for num in paddle.version.cuda_archs())
|
||||
if is_current_sm_supported and is_paddle_supported:
|
||||
self.flash_attn_func = flash_attention_v3_varlen
|
||||
print("The current platform supports Flash Attention V3.")
|
||||
self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}
|
||||
else:
|
||||
self.flash_attn_func = flash_attn_unpadded
|
||||
self.flash_attn_kwargs = {"scale": self.attn_softmax_scale, "training": False}
|
||||
print(
|
||||
"The current platform does not support Flash Attention V3, so Flash Attention V2 will be used instead."
|
||||
)
|
||||
|
||||
def init_attention_metadata(self, forward_meta: ForwardMeta):
|
||||
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
|
||||
metadata = MLAAttentionMetadata()
|
||||
@@ -269,7 +291,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
# Flash注意力计算
|
||||
fmha_out = flash_attn_unpadded(
|
||||
fmha_out = self.flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@@ -277,9 +299,8 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
forward_meta.cu_seqlens_k,
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_enc_len_this_time,
|
||||
self.attn_softmax_scale,
|
||||
causal=True,
|
||||
training=False,
|
||||
causal=self.causal,
|
||||
**self.flash_attn_kwargs,
|
||||
)[0]
|
||||
|
||||
return fmha_out
|
||||
@@ -418,7 +439,7 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
# FA
|
||||
fmha_out = flash_attn_unpadded(
|
||||
fmha_out = self.flash_attn_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@@ -426,9 +447,8 @@ class MLAAttentionBackend(AttentionBackend):
|
||||
forward_meta.cu_seqlens_k,
|
||||
metadata.max_enc_len_this_time,
|
||||
metadata.max_enc_len_this_time,
|
||||
self.attn_softmax_scale,
|
||||
causal=True,
|
||||
training=False,
|
||||
causal=self.causal,
|
||||
**self.flash_attn_kwargs,
|
||||
)[0]
|
||||
|
||||
return fmha_out
|
||||
|
@@ -316,30 +316,23 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
mask_encoder_batch: paddle.Tensor,
|
||||
):
|
||||
""" """
|
||||
layernorm_out = hidden_states
|
||||
fmha_out = paddle.zeros(
|
||||
shape=[
|
||||
layernorm_out.shape[0],
|
||||
self.num_attention_heads_tp * self.v_head_dim,
|
||||
],
|
||||
dtype=layernorm_out.dtype,
|
||||
)
|
||||
|
||||
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
|
||||
query = self.q_a_proj(layernorm_out)
|
||||
# NOTE: (changwenbin) Bring out the public calculation in PD MIX to avoid repeated calculation.
|
||||
fmha_out = None
|
||||
query = self.q_a_proj(hidden_states)
|
||||
query = self.q_a_layernorm(query)
|
||||
query = self.q_b_proj(query)
|
||||
|
||||
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
|
||||
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
|
||||
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
|
||||
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
|
||||
|
||||
if forward_meta.max_len_tensor_cpu[1]: # max_enc_len_this_time
|
||||
key_value = self.kv_b_proj(compressed_kv)
|
||||
key_value = key_value.reshape(
|
||||
[
|
||||
@@ -371,23 +364,9 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
fmha_out_prefill = fmha_out_prefill.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
|
||||
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
|
||||
|
||||
fmha_out = fmha_out + fmha_out_prefill
|
||||
fmha_out = fmha_out_prefill
|
||||
|
||||
if forward_meta.max_len_tensor_cpu[2]: # max_dec_len_this_time
|
||||
query = self.q_a_proj(layernorm_out)
|
||||
query = self.q_a_layernorm(query)
|
||||
ln_out_or_q_c = query
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(layernorm_out)
|
||||
compressed_kv, key_pe = compressed_kv.split([self.kv_lora_rank, self.qk_rope_head_dim], axis=-1)
|
||||
key_pe = key_pe.reshape([-1, 1, self.qk_rope_head_dim])
|
||||
compressed_kv = self.kv_a_layernorm(compressed_kv)
|
||||
|
||||
query = self.q_b_proj(ln_out_or_q_c)
|
||||
query = query.reshape([-1, self.num_attention_heads_tp, self.qk_head_dim])
|
||||
|
||||
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
|
||||
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
|
||||
|
||||
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]), proj_type="k").transpose([1, 0, 2])
|
||||
|
||||
q_input = paddle.concat([q_nope_out, query_pe], axis=-1)
|
||||
@@ -416,6 +395,9 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
.transpose([1, 0, 2])
|
||||
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
|
||||
)
|
||||
if fmha_out is None:
|
||||
fmha_out = fmha_out_decode
|
||||
else:
|
||||
fmha_out = fmha_out + fmha_out_decode
|
||||
|
||||
output = self.o_proj(fmha_out)
|
||||
|
Reference in New Issue
Block a user